# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

# 316 lines
# 9.7 KiB
# Python

# -*- coding: utf-8 -*-
"""
OCR API 测试
"""
import io
import pytest
from fastapi.testclient import TestClient
class TestOCRRecognizeMultipart:
    """Tests for the OCR recognition endpoint using multipart/form-data uploads."""

    # Route exercised by every test in this class.
    _ENDPOINT = "/api/v1/ocr/recognize"

    def test_recognize_success(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """OCR recognition - happy path with language and drop-score fields."""
        upload = ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")
        form_fields = {"lang": "ch", "drop_score": "0.5"}

        resp = test_client.post(
            self._ENDPOINT,
            files={"file": upload},
            data=form_fields,
        )

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is True
        assert payload["data"] is not None
        # The result object must expose these three keys.
        for key in ("text_count", "text_blocks", "processing_time_ms"):
            assert key in payload["data"]

    def test_recognize_with_roi(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """OCR recognition - request carries ROI form fields."""
        upload = ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")
        form_fields = {
            "lang": "ch",
            "roi_x": "0.1",
            "roi_y": "0.1",
            "roi_width": "0.8",
            "roi_height": "0.8",
        }

        resp = test_client.post(
            self._ENDPOINT,
            files={"file": upload},
            data=form_fields,
        )

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is True

    def test_recognize_with_annotated_image(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """OCR recognition - request asks for the annotated image."""
        upload = ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")

        resp = test_client.post(
            self._ENDPOINT,
            files={"file": upload},
            data={"return_annotated_image": "true"},
        )

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is True
        # Note: the annotated image is only returned when something was recognized.
        result = payload["data"]
        if result["text_count"] > 0:
            assert result["annotated_image_base64"] is not None

    def test_recognize_png_image(
        self,
        test_client: TestClient,
        sample_png_bytes: bytes,
    ):
        """OCR recognition - PNG input."""
        upload = ("test.png", io.BytesIO(sample_png_bytes), "image/png")

        resp = test_client.post(self._ENDPOINT, files={"file": upload})

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is True

    def test_recognize_invalid_file(
        self,
        test_client: TestClient,
        invalid_file_bytes: bytes,
    ):
        """OCR recognition - invalid (plain-text) file upload."""
        upload = ("test.txt", io.BytesIO(invalid_file_bytes), "text/plain")

        resp = test_client.post(self._ENDPOINT, files={"file": upload})

        # The test expects HTTP 200 with the failure reported in the JSON body.
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is False
        assert payload["error"] is not None

    def test_recognize_no_file(self, test_client: TestClient):
        """OCR recognition - no file supplied."""
        resp = test_client.post(self._ENDPOINT)
        assert resp.status_code == 422  # Validation Error
class TestOCRRecognizeBase64:
    """Tests for the OCR recognition endpoint using Base64 JSON bodies."""

    # Route exercised by every test in this class.
    _ENDPOINT = "/api/v1/ocr/recognize/base64"

    def test_recognize_base64_success(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """OCR recognition (Base64) - happy path."""
        request_body = {
            "image_base64": sample_image_base64,
            "lang": "ch",
            "drop_score": 0.5,
        }

        resp = test_client.post(self._ENDPOINT, json=request_body)

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is True
        assert payload["data"] is not None

    def test_recognize_base64_with_data_url(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """OCR recognition (Base64) - image supplied as a data URL."""
        request_body = {
            "image_base64": f"data:image/jpeg;base64,{sample_image_base64}",
        }

        resp = test_client.post(self._ENDPOINT, json=request_body)

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is True

    def test_recognize_base64_with_roi(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """OCR recognition (Base64) - request carries an ROI object."""
        region = {"x": 0.1, "y": 0.1, "width": 0.8, "height": 0.8}
        request_body = {"image_base64": sample_image_base64, "roi": region}

        resp = test_client.post(self._ENDPOINT, json=request_body)

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is True

    def test_recognize_base64_invalid(self, test_client: TestClient):
        """OCR recognition (Base64) - string that is not valid Base64."""
        resp = test_client.post(
            self._ENDPOINT,
            json={"image_base64": "not-valid-base64!!!"},
        )

        # The test expects HTTP 200 with the failure reported in the JSON body.
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is False
        assert payload["error"] is not None

    def test_recognize_base64_missing_field(self, test_client: TestClient):
        """OCR recognition (Base64) - required field missing."""
        # image_base64 is deliberately left out of the body.
        request_body = {"lang": "ch"}

        resp = test_client.post(self._ENDPOINT, json=request_body)

        assert resp.status_code == 422  # Validation Error
class TestExpressMultipart:
    """Tests for the express waybill parsing endpoint (multipart/form-data)."""

    def test_express_success(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """Express waybill parsing - happy path."""
        response = test_client.post(
            "/api/v1/ocr/express",
            files={"file": ("express.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["data"] is not None
        assert "express_info" in data["data"]
        assert "merged_text" in data["data"]
        assert "processing_time_ms" in data["data"]

    def test_express_info_structure(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """Express waybill parsing - shape of the ``express_info`` object."""
        response = test_client.post(
            "/api/v1/ocr/express",
            files={"file": ("express.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")},
        )
        assert response.status_code == 200
        data = response.json()
        # Fail with a clear assertion (not a TypeError on None) when the
        # request did not succeed.
        assert data["success"] is True
        assert data["data"] is not None
        express_info = data["data"]["express_info"]

        # Verify the top-level structure is complete.
        for key in (
            "tracking_number",
            "sender",
            "receiver",
            "courier_company",
            "confidence",
        ):
            assert key in express_info

        # Verify the sender/receiver structure. The original only checked
        # sender although its comment named both parties.
        # NOTE(review): assumes receiver mirrors sender's fields - confirm
        # against the response schema.
        for party in ("sender", "receiver"):
            for field in ("name", "phone", "address"):
                assert field in express_info[party]
class TestExpressBase64:
    """Tests for the express waybill parsing endpoint using Base64 JSON bodies."""

    def test_express_base64_success(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """Express waybill parsing (Base64) - happy path."""
        request_body = {"image_base64": sample_image_base64}

        resp = test_client.post("/api/v1/ocr/express/base64", json=request_body)

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is True
        result = payload["data"]
        assert result is not None
        assert "express_info" in result
class TestSecurityValidation:
    """Security validation tests for file uploads."""

    # Route exercised by every test in this class.
    _ENDPOINT = "/api/v1/ocr/recognize"

    def test_file_size_limit(self, test_client: TestClient):
        """Upload an oversized file and expect a size-related error message."""
        # Build an oversized fake file (11 MB).
        oversized = b"x" * (11 * 1024 * 1024)
        upload = ("large.jpg", io.BytesIO(oversized), "image/jpeg")

        resp = test_client.post(self._ENDPOINT, files={"file": upload})

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is False
        # The message may mention size in Chinese or in English.
        message = payload["error"]["message"]
        assert "大小" in message or "size" in message.lower()

    def test_invalid_extension(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """Upload image bytes under a disallowed file extension."""
        upload = (
            "test.exe",
            io.BytesIO(sample_image_bytes),
            "application/octet-stream",
        )

        resp = test_client.post(self._ENDPOINT, files={"file": upload})

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is False

    def test_magic_bytes_validation(self, test_client: TestClient):
        """Upload a file whose content does not match its image extension."""
        # A fake jpg: correct extension, but the content is not an image.
        bogus_jpeg = b"This is not a real JPEG file"
        upload = ("fake.jpg", io.BytesIO(bogus_jpeg), "image/jpeg")

        resp = test_client.post(self._ENDPOINT, files={"file": upload})

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is False