You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
316 lines
9.7 KiB
Python
316 lines
9.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
OCR API 测试
|
|
"""
|
|
|
|
import io
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
class TestOCRRecognizeMultipart:
    """Tests for the OCR recognition endpoint (multipart/form-data uploads)."""

    # Endpoint exercised by every test in this class.
    RECOGNIZE_URL = "/api/v1/ocr/recognize"

    def test_recognize_success(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """A JPEG upload with lang/drop_score succeeds and returns the result keys."""
        upload = ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")
        form_fields = {"lang": "ch", "drop_score": "0.5"}

        response = test_client.post(
            self.RECOGNIZE_URL,
            files={"file": upload},
            data=form_fields,
        )

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is True
        assert body["data"] is not None
        for expected_key in ("text_count", "text_blocks", "processing_time_ms"):
            assert expected_key in body["data"]

    def test_recognize_with_roi(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """A JPEG upload accompanied by ROI form fields succeeds."""
        upload = ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")
        form_fields = {
            "lang": "ch",
            "roi_x": "0.1",
            "roi_y": "0.1",
            "roi_width": "0.8",
            "roi_height": "0.8",
        }

        response = test_client.post(
            self.RECOGNIZE_URL,
            files={"file": upload},
            data=form_fields,
        )

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is True

    def test_recognize_with_annotated_image(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """Requesting an annotated image succeeds; the image is checked when text was found."""
        upload = ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")

        response = test_client.post(
            self.RECOGNIZE_URL,
            files={"file": upload},
            data={"return_annotated_image": "true"},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is True
        # Note: the annotated image is only returned when there are recognition results.
        result = body["data"]
        if result["text_count"] > 0:
            assert result["annotated_image_base64"] is not None

    def test_recognize_png_image(
        self,
        test_client: TestClient,
        sample_png_bytes: bytes,
    ):
        """A PNG upload without extra form fields succeeds."""
        upload = ("test.png", io.BytesIO(sample_png_bytes), "image/png")

        response = test_client.post(
            self.RECOGNIZE_URL,
            files={"file": upload},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is True

    def test_recognize_invalid_file(
        self,
        test_client: TestClient,
        invalid_file_bytes: bytes,
    ):
        """An invalid (text/plain) upload returns HTTP 200 with success=False and an error."""
        upload = ("test.txt", io.BytesIO(invalid_file_bytes), "text/plain")

        response = test_client.post(
            self.RECOGNIZE_URL,
            files={"file": upload},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is False
        assert body["error"] is not None

    def test_recognize_no_file(self, test_client: TestClient):
        """Posting without any file is expected to yield HTTP 422."""
        response = test_client.post(self.RECOGNIZE_URL)

        assert response.status_code == 422  # Validation Error
|
|
|
|
|
|
class TestOCRRecognizeBase64:
    """Tests for the OCR recognition endpoint (Base64 JSON body)."""

    # Endpoint exercised by every test in this class.
    BASE64_URL = "/api/v1/ocr/recognize/base64"

    def test_recognize_base64_success(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """A Base64 image with lang/drop_score succeeds and returns data."""
        payload = {
            "image_base64": sample_image_base64,
            "lang": "ch",
            "drop_score": 0.5,
        }

        response = test_client.post(self.BASE64_URL, json=payload)

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is True
        assert body["data"] is not None

    def test_recognize_base64_with_data_url(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """A Base64 image wrapped in a ``data:`` URL prefix succeeds."""
        prefixed_image = f"data:image/jpeg;base64,{sample_image_base64}"
        payload = {"image_base64": prefixed_image}

        response = test_client.post(self.BASE64_URL, json=payload)

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is True

    def test_recognize_base64_with_roi(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """A Base64 image sent together with a nested ROI object succeeds."""
        region = {"x": 0.1, "y": 0.1, "width": 0.8, "height": 0.8}
        payload = {
            "image_base64": sample_image_base64,
            "roi": region,
        }

        response = test_client.post(self.BASE64_URL, json=payload)

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is True

    def test_recognize_base64_invalid(self, test_client: TestClient):
        """An undecodable Base64 string returns HTTP 200 with success=False and an error."""
        payload = {"image_base64": "not-valid-base64!!!"}

        response = test_client.post(self.BASE64_URL, json=payload)

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is False
        assert body["error"] is not None

    def test_recognize_base64_missing_field(self, test_client: TestClient):
        """Omitting the required ``image_base64`` field is expected to yield HTTP 422."""
        # image_base64 is intentionally left out of the payload.
        payload = {"lang": "ch"}

        response = test_client.post(self.BASE64_URL, json=payload)

        assert response.status_code == 422  # Validation Error
|
|
|
|
|
|
class TestExpressMultipart:
    """Tests for the express waybill parsing endpoint (multipart/form-data uploads)."""

    @staticmethod
    def _post_express_image(test_client: TestClient, image_bytes: bytes):
        """Upload ``image_bytes`` as ``express.jpg`` to the express endpoint and return the response."""
        return test_client.post(
            "/api/v1/ocr/express",
            files={"file": ("express.jpg", io.BytesIO(image_bytes), "image/jpeg")},
        )

    def test_express_success(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """A JPEG upload succeeds and the payload carries the top-level result keys."""
        response = self._post_express_image(test_client, sample_image_bytes)

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["data"] is not None
        assert "express_info" in data["data"]
        assert "merged_text" in data["data"]
        assert "processing_time_ms" in data["data"]

    def test_express_info_structure(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """The ``express_info`` object exposes all expected fields for both parties."""
        response = self._post_express_image(test_client, sample_image_bytes)

        assert response.status_code == 200
        data = response.json()
        # Assert success first so a failed response is reported as such,
        # not as a TypeError/KeyError when indexing data["data"] below.
        assert data["success"] is True
        express_info = data["data"]["express_info"]

        # Verify the structure is complete.
        for field in (
            "tracking_number",
            "sender",
            "receiver",
            "courier_company",
            "confidence",
        ):
            assert field in express_info, f"missing express_info field: {field}"

        # Verify the sender/receiver structure. Previously only the sender was
        # checked even though both were meant to be covered.
        # NOTE(review): assumes receiver shares the sender's name/phone/address
        # schema — confirm against the response model.
        for party in ("sender", "receiver"):
            for field in ("name", "phone", "address"):
                assert field in express_info[party], f"missing {party} field: {field}"
|
|
|
|
|
|
class TestExpressBase64:
    """Tests for the express waybill parsing endpoint (Base64 JSON body)."""

    def test_express_base64_success(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """A Base64 image succeeds and the returned data contains ``express_info``."""
        payload = {"image_base64": sample_image_base64}

        response = test_client.post("/api/v1/ocr/express/base64", json=payload)

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is True
        result = body["data"]
        assert result is not None
        assert "express_info" in result
|
|
|
|
|
|
class TestSecurityValidation:
    """Security validation tests for uploads to the OCR recognition endpoint."""

    # Endpoint exercised by every test in this class.
    RECOGNIZE_URL = "/api/v1/ocr/recognize"

    def test_file_size_limit(self, test_client: TestClient):
        """An 11 MiB upload is rejected with an error message that mentions size."""
        # Build an oversized fake file (11MB).
        oversized_payload = b"x" * (11 * 1024 * 1024)
        upload = ("large.jpg", io.BytesIO(oversized_payload), "image/jpeg")

        response = test_client.post(
            self.RECOGNIZE_URL,
            files={"file": upload},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is False
        message = body["error"]["message"]
        # Accept either the Chinese or the English wording for "size".
        assert "大小" in message or "size" in message.lower()

    def test_invalid_extension(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """Image bytes uploaded under a ``.exe`` filename are rejected."""
        upload = (
            "test.exe",
            io.BytesIO(sample_image_bytes),
            "application/octet-stream",
        )

        response = test_client.post(
            self.RECOGNIZE_URL,
            files={"file": upload},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is False

    def test_magic_bytes_validation(self, test_client: TestClient):
        """A ``.jpg`` upload whose content is not an image is rejected."""
        # Fake jpg file: the extension is correct but the content is not an image.
        bogus_jpeg = b"This is not a real JPEG file"
        upload = ("fake.jpg", io.BytesIO(bogus_jpeg), "image/jpeg")

        response = test_client.post(
            self.RECOGNIZE_URL,
            files={"file": upload},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is False
|