feat: 添加基于 FastAPI 的 REST API 服务
parent
7570e1314d
commit
c05e3e58ed
@ -0,0 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Vision-OCR REST API 模块
|
||||
提供 HTTP 接口访问 OCR 功能
|
||||
"""
|
||||
|
||||
from api.main import app
|
||||
|
||||
__all__ = ["app"]
|
||||
@ -0,0 +1,157 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
依赖注入模块
|
||||
提供 FastAPI 依赖项
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import Depends, File, Form, Request, UploadFile
|
||||
|
||||
from api.exceptions import InvalidImageError, ModelNotLoadedError
|
||||
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
|
||||
from api.security import decode_base64_image, validate_uploaded_file
|
||||
from ocr.pipeline import OCRPipeline, OCRResult
|
||||
from utils.config import OCRConfig, PipelineConfig, ROIConfig
|
||||
|
||||
|
||||
def get_ocr_pipeline(request: Request) -> OCRPipeline:
    """
    Fetch the OCR pipeline stored on the application state.

    Args:
        request: Incoming FastAPI request; ``request.app.state`` is inspected.

    Returns:
        The object stored as ``app.state.ocr_pipeline``.

    Raises:
        ModelNotLoadedError: If the attribute is absent or ``None``.
    """
    loaded = getattr(request.app.state, "ocr_pipeline", None)
    if loaded is not None:
        return loaded
    raise ModelNotLoadedError()
|
||||
|
||||
|
||||
async def parse_multipart_image(
    file: UploadFile = File(..., description="图片文件"),
) -> bytes:
    """
    Read a multipart-uploaded image and run it through upload validation.

    Args:
        file: The uploaded file.

    Returns:
        The bytes returned by ``validate_uploaded_file``.
    """
    # The whole upload is read into memory before any validation happens.
    raw_bytes = await file.read()
    return validate_uploaded_file(
        content=raw_bytes,
        filename=file.filename,
        content_type=file.content_type,
    )
|
||||
|
||||
|
||||
async def parse_multipart_params(
    lang: str = Form(default="ch", description="识别语言"),
    use_gpu: bool = Form(default=False, description="是否使用 GPU"),
    drop_score: float = Form(default=0.5, ge=0.0, le=1.0, description="置信度阈值"),
    roi_x: Optional[float] = Form(default=None, ge=0.0, le=1.0, description="ROI X"),
    roi_y: Optional[float] = Form(default=None, ge=0.0, le=1.0, description="ROI Y"),
    roi_width: Optional[float] = Form(
        default=None, gt=0.0, le=1.0, description="ROI 宽度"
    ),
    roi_height: Optional[float] = Form(
        default=None, gt=0.0, le=1.0, description="ROI 高度"
    ),
    return_annotated_image: bool = Form(
        default=False, description="是否返回标注图片"
    ),
) -> OCRRequestParams:
    """
    Collect the multipart form fields into a single request-parameter object.

    Returns:
        An ``OCRRequestParams`` built from the submitted form values.
    """
    form_values = {
        "lang": lang,
        "use_gpu": use_gpu,
        "drop_score": drop_score,
        "roi_x": roi_x,
        "roi_y": roi_y,
        "roi_width": roi_width,
        "roi_height": roi_height,
        "return_annotated_image": return_annotated_image,
    }
    return OCRRequestParams(**form_values)
|
||||
|
||||
|
||||
def decode_image_bytes(content: bytes) -> np.ndarray:
    """
    Decode raw image bytes into an OpenCV image.

    Args:
        content: Encoded image bytes.

    Returns:
        The array produced by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``.

    Raises:
        InvalidImageError: If decoding raises or yields no image. When an
            underlying exception caused the failure it is chained as
            ``__cause__``.
    """
    # Only the calls that can fail live inside the try block, so the
    # "decoder returned None" case no longer has to be re-raised from
    # within a catch-all handler.
    try:
        buffer = np.frombuffer(content, np.uint8)
        image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    except Exception as exc:
        # Boundary conversion: any decoder failure becomes an API-level
        # error, keeping the original exception for diagnostics.
        raise InvalidImageError(f"图片解码失败: {str(exc)}") from exc

    if image is None:
        raise InvalidImageError("图片解码失败")
    return image
|
||||
|
||||
|
||||
def encode_image_base64(image: np.ndarray, format: str = ".jpg") -> str:
    """
    Encode an image with OpenCV and return it as a Base64 string.

    Args:
        image: OpenCV image.
        format: File-extension style format passed to ``cv2.imencode``
            (for example '.jpg' or '.png').

    Returns:
        The Base64 text of the encoded image, or an empty string when
        ``cv2.imencode`` reports failure.
    """
    ok, buffer = cv2.imencode(format, image)
    return base64.b64encode(buffer.tobytes()).decode("utf-8") if ok else ""
|
||||
|
||||
|
||||
def build_pipeline_config(roi: Optional[ROIParams]) -> PipelineConfig:
    """
    Build a pipeline configuration from optional ROI parameters.

    Args:
        roi: ROI parameters, or ``None`` to disable ROI.

    Returns:
        A ``PipelineConfig`` whose ROI section is enabled only when ``roi``
        is given.
    """
    if roi is not None:
        roi_config = ROIConfig(
            enabled=True,
            x_ratio=roi.x,
            y_ratio=roi.y,
            width_ratio=roi.width,
            height_ratio=roi.height,
        )
    else:
        roi_config = ROIConfig(enabled=False)
    return PipelineConfig(roi=roi_config)
|
||||
@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自定义异常模块
|
||||
定义 API 层的异常类型
|
||||
"""
|
||||
|
||||
|
||||
class OCRAPIException(Exception):
    """Base class for errors raised by the OCR API layer."""

    def __init__(self, message: str, status_code: int = 500):
        # Hand the message to Exception so str(exc) and exc.args carry it.
        super().__init__(message)
        self.message = message
        self.status_code = status_code
|
||||
|
||||
|
||||
class InvalidImageError(OCRAPIException):
    """The supplied image is invalid (status code 400)."""

    def __init__(self, message: str = "无效的图片文件"):
        super().__init__(message, 400)
|
||||
|
||||
|
||||
class FileTooLargeError(OCRAPIException):
    """The file exceeds the size limit (status code 413)."""

    def __init__(self, message: str = "文件大小超过限制"):
        super().__init__(message, 413)
|
||||
|
||||
|
||||
class UnsupportedFormatError(OCRAPIException):
    """The file format is not supported (status code 415)."""

    def __init__(self, message: str = "不支持的文件格式"):
        super().__init__(message, 415)
|
||||
|
||||
|
||||
class OCRProcessingError(OCRAPIException):
    """OCR processing failed (status code 500)."""

    def __init__(self, message: str = "OCR 处理失败"):
        super().__init__(message, 500)
|
||||
|
||||
|
||||
class ModelNotLoadedError(OCRAPIException):
    """The OCR model is not loaded yet (status code 503)."""

    def __init__(self, message: str = "OCR 模型尚未加载完成"):
        super().__init__(message, 503)
|
||||
@ -0,0 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
路由模块
|
||||
定义 API 端点
|
||||
"""
|
||||
|
||||
from api.routes.health import router as health_router
|
||||
from api.routes.ocr import router as ocr_router
|
||||
|
||||
__all__ = ["health_router", "ocr_router"]
|
||||
@ -0,0 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
健康检查路由
|
||||
提供服务状态检查端点
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
from api.schemas.response import HealthResponse
|
||||
|
||||
# Router for the health endpoints; every route below is prefixed with /health.
router = APIRouter(
    prefix="/health",
    tags=["健康检查"],
)

# Version string reported in health responses.
API_VERSION = "1.0.0"
|
||||
|
||||
|
||||
@router.get(
    "",
    response_model=HealthResponse,
    summary="健康检查",
    description="检查服务是否正常运行",
)
async def health_check(request: Request) -> HealthResponse:
    """
    Basic health check.

    Reports the service status together with the model-loaded flag read
    from ``app.state.model_loaded`` (treated as False when unset).
    """
    is_loaded = getattr(request.app.state, "model_loaded", False)
    status_text = "healthy" if is_loaded else "unhealthy"
    return HealthResponse(
        status=status_text,
        model_loaded=is_loaded,
        version=API_VERSION,
    )
|
||||
|
||||
|
||||
@router.get(
    "/ready",
    response_model=HealthResponse,
    summary="就绪检查",
    description="检查服务是否已准备好处理请求 (模型已加载)",
)
async def readiness_check(request: Request) -> HealthResponse:
    """
    Readiness check.

    The reported status is "healthy" only when ``app.state.model_loaded``
    is truthy; otherwise it is "unhealthy".
    """
    is_loaded = getattr(request.app.state, "model_loaded", False)
    status_text = "healthy" if is_loaded else "unhealthy"
    return HealthResponse(
        status=status_text,
        model_loaded=is_loaded,
        version=API_VERSION,
    )
|
||||
@ -0,0 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Pydantic 数据模型模块
|
||||
定义 API 请求和响应的数据结构
|
||||
"""
|
||||
|
||||
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
|
||||
from api.schemas.response import (
|
||||
ErrorDetail,
|
||||
ExpressInfoData,
|
||||
ExpressResponse,
|
||||
HealthResponse,
|
||||
OCRResponse,
|
||||
OCRResultData,
|
||||
TextBlockData,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Request
|
||||
"OCRRequestBase64",
|
||||
"OCRRequestParams",
|
||||
"ROIParams",
|
||||
# Response
|
||||
"OCRResponse",
|
||||
"OCRResultData",
|
||||
"TextBlockData",
|
||||
"ExpressResponse",
|
||||
"ExpressInfoData",
|
||||
"HealthResponse",
|
||||
"ErrorDetail",
|
||||
]
|
||||
@ -0,0 +1,141 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
请求数据模型
|
||||
定义 API 请求的数据结构
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ROIParams(BaseModel):
    """
    ROI (感兴趣区域) 参数
    使用归一化坐标 (0.0 ~ 1.0)
    """

    # Region-of-interest rectangle in normalized coordinates (0.0 to 1.0).
    # The class docstring is kept verbatim: pydantic exposes it as the
    # schema description, so it is runtime-visible text.

    # Top-left corner.
    x: float = Field(default=0.0, ge=0.0, le=1.0, description="ROI 左上角 X 坐标 (归一化)")
    y: float = Field(default=0.0, ge=0.0, le=1.0, description="ROI 左上角 Y 坐标 (归一化)")

    # Extent; must be strictly positive.
    width: float = Field(default=1.0, gt=0.0, le=1.0, description="ROI 宽度 (归一化)")
    height: float = Field(default=1.0, gt=0.0, le=1.0, description="ROI 高度 (归一化)")
|
||||
|
||||
|
||||
class OCRRequestParams(BaseModel):
    """
    OCR 请求参数 (用于 multipart/form-data)
    """

    # Flat parameter set for multipart/form-data requests. The class
    # docstring is kept verbatim because pydantic exposes it as the schema
    # description.

    lang: str = Field(default="ch", description="识别语言,支持 'ch'(中文), 'en'(英文) 等")
    use_gpu: bool = Field(default=False, description="是否使用 GPU 加速")
    drop_score: float = Field(
        default=0.5, ge=0.0, le=1.0, description="置信度阈值,低于此值的结果将被过滤"
    )

    # ROI is supplied as four separate optional fields; see has_roi/get_roi.
    roi_x: Optional[float] = Field(
        default=None, ge=0.0, le=1.0, description="ROI 左上角 X 坐标 (归一化)"
    )
    roi_y: Optional[float] = Field(
        default=None, ge=0.0, le=1.0, description="ROI 左上角 Y 坐标 (归一化)"
    )
    roi_width: Optional[float] = Field(
        default=None, gt=0.0, le=1.0, description="ROI 宽度 (归一化)"
    )
    roi_height: Optional[float] = Field(
        default=None, gt=0.0, le=1.0, description="ROI 高度 (归一化)"
    )

    return_annotated_image: bool = Field(
        default=False, description="是否返回标注后的图片 (Base64)"
    )

    def has_roi(self) -> bool:
        """Return True only when all four ROI fields are set."""
        roi_fields = (self.roi_x, self.roi_y, self.roi_width, self.roi_height)
        return not any(value is None for value in roi_fields)

    def get_roi(self) -> Optional[ROIParams]:
        """Return the ROI as a ``ROIParams`` object, or None if incomplete."""
        if self.has_roi():
            return ROIParams(
                x=self.roi_x,
                y=self.roi_y,
                width=self.roi_width,
                height=self.roi_height,
            )
        return None
|
||||
|
||||
|
||||
class OCRRequestBase64(BaseModel):
    """
    OCR 请求 (Base64 JSON 格式)
    """

    # JSON request body carrying the image as Base64 text. The class
    # docstring is kept verbatim because pydantic exposes it as the schema
    # description.

    # Required field.
    image_base64: str = Field(..., description="Base64 编码的图片数据,支持 Data URL 格式")

    lang: str = Field(default="ch", description="识别语言,支持 'ch'(中文), 'en'(英文) 等")
    use_gpu: bool = Field(default=False, description="是否使用 GPU 加速")
    drop_score: float = Field(
        default=0.5, ge=0.0, le=1.0, description="置信度阈值,低于此值的结果将被过滤"
    )

    # Unlike the multipart variant, the ROI arrives as one nested object.
    roi: Optional[ROIParams] = Field(default=None, description="ROI 区域参数")

    return_annotated_image: bool = Field(
        default=False, description="是否返回标注后的图片 (Base64)"
    )
|
||||
@ -0,0 +1,142 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
响应数据模型
|
||||
定义 API 响应的数据结构
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ErrorDetail(BaseModel):
    """错误详情"""

    # Error payload: a code, a message and optional extra details.
    # (Class docstring kept verbatim; pydantic uses it as schema description.)
    code: str = Field(..., description="错误代码")
    message: str = Field(..., description="错误信息")
    details: Optional[Dict[str, Any]] = Field(default=None, description="额外的错误详情")
|
||||
|
||||
|
||||
class TextBlockData(BaseModel):
    """文本块数据"""

    # One recognized text block. All fields are required.
    # (Class docstring kept verbatim; pydantic uses it as schema description.)
    text: str = Field(..., description="识别出的文本内容")
    confidence: float = Field(..., description="置信度 (0.0 ~ 1.0)")

    # Bounding boxes are lists of [x, y] vertices.
    bbox: List[List[float]] = Field(
        ..., description="边界框 4 个顶点坐标 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]"
    )
    bbox_with_offset: List[List[float]] = Field(
        ..., description="带偏移的边界框坐标 (已还原到原图坐标系)"
    )

    center: List[float] = Field(..., description="文本块中心点坐标 [x, y]")
    width: float = Field(..., description="文本块宽度 (像素)")
    height: float = Field(..., description="文本块高度 (像素)")
|
||||
|
||||
|
||||
class OCRResultData(BaseModel):
    """OCR 识别结果数据"""

    # Result payload of an OCR recognition request.
    # (Class docstring kept verbatim; pydantic uses it as schema description.)
    processing_time_ms: float = Field(..., description="处理耗时 (毫秒)")
    text_count: int = Field(..., description="识别出的文本块数量")
    average_confidence: float = Field(..., description="所有文本块的平均置信度")

    roi_applied: bool = Field(..., description="是否应用了 ROI 裁剪")
    roi_rect: Optional[List[int]] = Field(
        default=None, description="ROI 矩形区域 [x, y, width, height]"
    )

    # Defaults to a fresh empty list per instance.
    text_blocks: List[TextBlockData] = Field(
        default_factory=list, description="识别出的文本块列表"
    )
    annotated_image_base64: Optional[str] = Field(
        default=None, description="标注后的图片 (Base64 编码)"
    )
|
||||
|
||||
|
||||
class OCRResponse(BaseModel):
    """OCR 识别响应"""

    # Response envelope: a success flag plus optional data and error parts.
    # (Class docstring kept verbatim; pydantic uses it as schema description.)
    success: bool = Field(..., description="请求是否成功")
    data: Optional[OCRResultData] = Field(default=None, description="识别结果数据")
    error: Optional[ErrorDetail] = Field(
        default=None, description="错误信息 (仅在失败时存在)"
    )
|
||||
|
||||
|
||||
class ExpressPersonData(BaseModel):
    """快递单人员信息"""

    # A person on an express waybill; every field is optional.
    # (Class docstring kept verbatim; pydantic uses it as schema description.)
    name: Optional[str] = Field(default=None, description="姓名")
    phone: Optional[str] = Field(default=None, description="电话")
    address: Optional[str] = Field(default=None, description="地址")
|
||||
|
||||
|
||||
class ExpressInfoData(BaseModel):
    """快递单结构化信息"""

    # Structured information extracted from an express waybill.
    # (Class docstring kept verbatim; pydantic uses it as schema description.)
    tracking_number: Optional[str] = Field(default=None, description="运单号")

    # Sender and receiver default to empty ExpressPersonData instances.
    sender: ExpressPersonData = Field(
        default_factory=ExpressPersonData, description="寄件人信息"
    )
    receiver: ExpressPersonData = Field(
        default_factory=ExpressPersonData, description="收件人信息"
    )

    courier_company: Optional[str] = Field(default=None, description="快递公司名称")
    confidence: float = Field(default=0.0, description="平均置信度")
    extra_fields: Dict[str, str] = Field(
        default_factory=dict, description="其他识别到的额外字段"
    )
    raw_text: str = Field(default="", description="原始合并文本")
|
||||
|
||||
|
||||
class ExpressResultData(BaseModel):
    """快递单解析结果数据"""

    # Result payload of an express-waybill parsing request.
    # (Class docstring kept verbatim; pydantic uses it as schema description.)
    processing_time_ms: float = Field(..., description="处理耗时 (毫秒)")
    express_info: ExpressInfoData = Field(..., description="快递单结构化信息")
    merged_text: str = Field(..., description="智能合并后的完整文本")
    annotated_image_base64: Optional[str] = Field(
        default=None, description="标注后的图片 (Base64 编码)"
    )
|
||||
|
||||
|
||||
class ExpressResponse(BaseModel):
    """快递单解析响应"""

    # Response envelope: a success flag plus optional data and error parts.
    # (Class docstring kept verbatim; pydantic uses it as schema description.)
    success: bool = Field(..., description="请求是否成功")
    data: Optional[ExpressResultData] = Field(default=None, description="解析结果数据")
    error: Optional[ErrorDetail] = Field(
        default=None, description="错误信息 (仅在失败时存在)"
    )
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """健康检查响应"""

    # Payload returned by the health endpoints. All fields are required.
    # (Class docstring kept verbatim; pydantic uses it as schema description.)
    status: str = Field(
        ...,
        description="服务状态: 'healthy' 或 'unhealthy'",
    )
    model_loaded: bool = Field(
        ...,
        description="OCR 模型是否已加载",
    )
    version: str = Field(
        ...,
        description="API 版本",
    )
|
||||
@ -0,0 +1,180 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
安全模块
|
||||
提供文件验证、大小限制等安全功能
|
||||
"""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from api.exceptions import (
|
||||
FileTooLargeError,
|
||||
InvalidImageError,
|
||||
UnsupportedFormatError,
|
||||
)
|
||||
|
||||
|
||||
# Security configuration constants.
MAX_FILE_SIZE = 10 * 1024**2  # 10 MiB, in bytes

# Accepted filename suffixes (lower-case, with the leading dot).
ALLOWED_EXTENSIONS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".bmp",
    ".webp",
    ".tiff",
    ".tif",
}

# Accepted MIME types (lower-case, without parameters).
ALLOWED_MIME_TYPES = {"image/jpeg", "image/png", "image/bmp", "image/webp", "image/tiff"}

# Leading "magic" byte signatures mapped to a format label.
IMAGE_MAGIC_BYTES = {
    b"\xff\xd8\xff": "jpeg",
    b"\x89PNG\r\n\x1a\n": "png",
    b"BM": "bmp",
    # RIFF alone does not identify WebP; a further check is required.
    b"RIFF": "webp",
    b"II*\x00": "tiff",  # little-endian TIFF
    b"MM\x00*": "tiff",  # big-endian TIFF
}
|
||||
|
||||
|
||||
def validate_file_extension(filename: Optional[str]) -> bool:
    """
    Check whether a filename carries an allowed extension.

    Args:
        filename: The filename to check; may be None or empty.

    Returns:
        True if the lower-cased suffix is in ``ALLOWED_EXTENSIONS``;
        False otherwise, including when no filename is given.
    """
    if filename:
        return Path(filename).suffix.lower() in ALLOWED_EXTENSIONS
    return False
|
||||
|
||||
|
||||
def validate_file_size(content: bytes) -> bool:
    """
    Check whether the content fits within the size limit.

    Args:
        content: File content.

    Returns:
        True if ``len(content)`` does not exceed ``MAX_FILE_SIZE``.
    """
    size_in_bytes = len(content)
    return not size_in_bytes > MAX_FILE_SIZE
|
||||
|
||||
|
||||
def validate_image_magic_bytes(content: bytes) -> bool:
    """
    Check whether the content starts with a known image signature.

    Args:
        content: File content.

    Returns:
        True if the content begins with one of the signatures in
        ``IMAGE_MAGIC_BYTES``. A ``RIFF`` prefix is accepted only when
        bytes 8-11 are ``WEBP``. Content shorter than 8 bytes is rejected.
    """
    if len(content) < 8:
        return False

    for magic in IMAGE_MAGIC_BYTES:
        if not content.startswith(magic):
            continue
        # RIFF is a generic container, so the form type at offset 8 must
        # say WEBP. Slicing a too-short buffer yields fewer than 4 bytes,
        # which also fails the comparison, so truncated RIFF data of 8-11
        # bytes is rejected instead of slipping through unchecked.
        if magic == b"RIFF" and content[8:12] != b"WEBP":
            continue
        return True

    return False
|
||||
|
||||
|
||||
def validate_mime_type(content_type: Optional[str]) -> bool:
    """
    Check whether a MIME type is allowed.

    Args:
        content_type: The MIME type; may be None or empty.

    Returns:
        True if the type, with parameters removed, is in
        ``ALLOWED_MIME_TYPES``; False otherwise, including when no type is
        given.
    """
    if content_type:
        # Drop parameters such as "; charset=utf-8" before comparing.
        essence, _, _ = content_type.partition(";")
        return essence.strip().lower() in ALLOWED_MIME_TYPES
    return False
|
||||
|
||||
|
||||
def validate_uploaded_file(
    content: bytes,
    filename: Optional[str] = None,
    content_type: Optional[str] = None,
) -> bytes:
    """
    Run all upload checks on a file and return its content unchanged.

    Checks run in this order: size, extension (only if a filename is
    given), MIME type (only if a content type is given), magic bytes.

    Args:
        content: File content.
        filename: Optional filename.
        content_type: Optional MIME type.

    Returns:
        ``content``, once every check has passed.

    Raises:
        FileTooLargeError: The content exceeds ``MAX_FILE_SIZE``.
        UnsupportedFormatError: The extension or MIME type is not allowed.
        InvalidImageError: The content has no known image signature.
    """
    if not validate_file_size(content):
        raise FileTooLargeError(
            f"文件大小超过限制,最大允许 {MAX_FILE_SIZE // 1024 // 1024}MB"
        )

    if filename and not validate_file_extension(filename):
        # ALLOWED_EXTENSIONS is a set, whose iteration order can differ
        # between runs; sorting keeps the message stable for clients,
        # logs and tests.
        allowed = ", ".join(sorted(ALLOWED_EXTENSIONS))
        raise UnsupportedFormatError(f"不支持的文件格式,允许的格式: {allowed}")

    if content_type and not validate_mime_type(content_type):
        raise UnsupportedFormatError(f"不支持的 MIME 类型: {content_type}")

    if not validate_image_magic_bytes(content):
        raise InvalidImageError("文件内容不是有效的图片格式")

    return content
|
||||
|
||||
|
||||
def decode_base64_image(base64_string: str) -> bytes:
    """
    Decode a Base64 image string and validate the result.

    Args:
        base64_string: Base64 text, optionally with a Data URL prefix;
            everything up to the first comma is discarded.

    Returns:
        The decoded bytes, as returned by ``validate_uploaded_file``.

    Raises:
        InvalidImageError: Decoding failed (the underlying error is
            chained as ``__cause__``) or the bytes are not a valid image.
        FileTooLargeError: The decoded content exceeds the size limit.
    """
    # Strip a possible Data URL prefix such as "data:image/jpeg;base64,".
    _, comma, after_comma = base64_string.partition(",")
    if comma:
        base64_string = after_comma

    base64_string = base64_string.strip()

    try:
        content = base64.b64decode(base64_string)
    except (ValueError, TypeError) as e:
        # binascii.Error is a ValueError subclass, so malformed padding and
        # non-ASCII text are both covered; TypeError covers non-string
        # input. The original error is chained for diagnostics.
        raise InvalidImageError(f"Base64 解码失败: {str(e)}") from e

    return validate_uploaded_file(content)
|
||||
@ -0,0 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Vision-OCR API 测试模块
|
||||
"""
|
||||
@ -0,0 +1,134 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
pytest 配置和共享 fixtures
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Put the project root (two directory levels above this file) at the front
# of the import path.
PROJECT_ROOT = Path(__file__).parent.parent
_root_str = str(PROJECT_ROOT)
sys.path.insert(0, _root_str)

# Point the PADDLEOCR_HOME environment variable at <project root>/models.
_models_dir = PROJECT_ROOT / "models"
os.environ["PADDLEOCR_HOME"] = str(_models_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def mock_ocr_pipeline():
    """
    Build a mocked OCR pipeline.

    Lets tests run without loading a real OCR model.
    """
    from ocr.engine import TextBlock
    from ocr.pipeline import OCRResult

    # (text, confidence, bbox) for each canned text block.
    block_specs = [
        ("测试文本1", 0.95, [[10, 10], [100, 10], [100, 30], [10, 30]]),
        ("测试文本2", 0.88, [[10, 40], [150, 40], [150, 60], [10, 60]]),
    ]
    canned_blocks = [
        TextBlock(text=text, confidence=score, bbox=box, bbox_offset=(0, 0))
        for text, score, box in block_specs
    ]

    canned_result = OCRResult(
        image_index=1,
        image_path=None,
        timestamp=1704672000.0,
        processing_time_ms=45.6,
        text_blocks=canned_blocks,
        roi_applied=False,
        roi_rect=None,
    )

    pipeline_double = MagicMock()
    pipeline_double.process.return_value = canned_result
    pipeline_double.initialize.return_value = None
    pipeline_double._pipeline_config = MagicMock()
    return pipeline_double
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def test_client(mock_ocr_pipeline) -> Generator[TestClient, None, None]:
    """
    Provide a test client for the application.

    The mocked OCR pipeline is installed on the app state so that no real
    model is loaded.
    """
    # Imported here, not at module top, so the environment variable set at
    # import time of this conftest is already in place.
    from api.main import app

    state = app.state
    state.ocr_pipeline = mock_ocr_pipeline
    state.model_loaded = True

    with TestClient(app) as client:
        yield client
|
||||
|
||||
|
||||
@pytest.fixture
def sample_image_bytes() -> bytes:
    """
    Return the bytes of a small test image.

    The image is a plain white 100x100 picture encoded as JPEG.
    """
    import cv2

    # All-white 3-channel uint8 image.
    white = np.full((100, 100, 3), 255, dtype=np.uint8)

    ok, jpeg_buffer = cv2.imencode(".jpg", white)
    assert ok, "图片编码失败"

    return jpeg_buffer.tobytes()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_image_base64(sample_image_bytes) -> str:
    """Return the sample JPEG bytes encoded as a Base64 string."""
    encoded = base64.b64encode(sample_image_bytes)
    return encoded.decode("utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
def sample_png_bytes() -> bytes:
    """Return the bytes of a plain white 100x100 image encoded as PNG."""
    import cv2

    white = np.full((100, 100, 3), 255, dtype=np.uint8)
    ok, png_buffer = cv2.imencode(".png", white)
    assert ok, "PNG 编码失败"

    return png_buffer.tobytes()
|
||||
|
||||
|
||||
@pytest.fixture
def invalid_file_bytes() -> bytes:
    """Return bytes that are not an image (plain ASCII text)."""
    not_an_image = b"This is not an image file content"
    return not_an_image
|
||||
@ -0,0 +1,76 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
健康检查 API 测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestHealthEndpoints:
    """Tests for the health endpoints when the model is loaded."""

    def test_health_check_success(self, test_client: TestClient):
        """The health endpoint reports a healthy, loaded service."""
        response = test_client.get("/api/v1/health")
        assert response.status_code == 200

        payload = response.json()
        assert payload["status"] == "healthy"
        assert payload["model_loaded"] is True
        assert "version" in payload

    def test_readiness_check_success(self, test_client: TestClient):
        """The readiness endpoint reports a healthy, loaded service."""
        response = test_client.get("/api/v1/health/ready")
        assert response.status_code == 200

        payload = response.json()
        assert payload["status"] == "healthy"
        assert payload["model_loaded"] is True

    def test_root_endpoint(self, test_client: TestClient):
        """The root path returns the API name, version and docs entry."""
        response = test_client.get("/")
        assert response.status_code == 200

        payload = response.json()
        assert payload["name"] == "Vision-OCR API"
        assert "version" in payload
        assert "docs" in payload
|
||||
|
||||
|
||||
class TestHealthEndpointsUnhealthy:
    """Tests for the health endpoints when the model is not loaded."""

    def test_health_check_model_not_loaded(self, test_client: TestClient):
        """The health endpoint reports unhealthy while the model is unloaded."""
        app_state = test_client.app.state
        # Temporarily flip the flag; it is restored in the finally clause.
        previous = app_state.model_loaded
        app_state.model_loaded = False

        try:
            response = test_client.get("/api/v1/health")

            assert response.status_code == 200
            payload = response.json()
            assert payload["status"] == "unhealthy"
            assert payload["model_loaded"] is False
        finally:
            test_client.app.state.model_loaded = previous

    def test_readiness_check_model_not_loaded(self, test_client: TestClient):
        """The readiness endpoint reports unhealthy while the model is unloaded."""
        app_state = test_client.app.state
        previous = app_state.model_loaded
        app_state.model_loaded = False

        try:
            response = test_client.get("/api/v1/health/ready")

            assert response.status_code == 200
            payload = response.json()
            assert payload["status"] == "unhealthy"
            assert payload["model_loaded"] is False
        finally:
            test_client.app.state.model_loaded = previous
|
||||
@ -0,0 +1,315 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
OCR API 测试
|
||||
"""
|
||||
|
||||
import io
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestOCRRecognizeMultipart:
    """Tests for the OCR recognition endpoint (multipart/form-data)."""

    def test_recognize_success(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """A valid JPEG upload yields a successful result payload."""
        upload = ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")
        response = test_client.post(
            "/api/v1/ocr/recognize",
            files={"file": upload},
            data={"lang": "ch", "drop_score": "0.5"},
        )

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True
        assert payload["data"] is not None
        for expected_key in ("text_count", "text_blocks", "processing_time_ms"):
            assert expected_key in payload["data"]

    def test_recognize_with_roi(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """Recognition succeeds when all four ROI fields are supplied."""
        upload = ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")
        form = {
            "lang": "ch",
            "roi_x": "0.1",
            "roi_y": "0.1",
            "roi_width": "0.8",
            "roi_height": "0.8",
        }
        response = test_client.post(
            "/api/v1/ocr/recognize",
            files={"file": upload},
            data=form,
        )

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True

    def test_recognize_with_annotated_image(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """Requesting an annotated image returns one when text was found."""
        upload = ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")
        response = test_client.post(
            "/api/v1/ocr/recognize",
            files={"file": upload},
            data={"return_annotated_image": "true"},
        )

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True
        # The annotated image is only checked when there are recognition
        # results.
        if payload["data"]["text_count"] > 0:
            assert payload["data"]["annotated_image_base64"] is not None

    def test_recognize_png_image(
        self,
        test_client: TestClient,
        sample_png_bytes: bytes,
    ):
        """A PNG upload is accepted."""
        upload = ("test.png", io.BytesIO(sample_png_bytes), "image/png")
        response = test_client.post(
            "/api/v1/ocr/recognize",
            files={"file": upload},
        )

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True

    def test_recognize_invalid_file(
        self,
        test_client: TestClient,
        invalid_file_bytes: bytes,
    ):
        """A non-image upload is answered with success=False and an error."""
        upload = ("test.txt", io.BytesIO(invalid_file_bytes), "text/plain")
        response = test_client.post(
            "/api/v1/ocr/recognize",
            files={"file": upload},
        )

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is False
        assert payload["error"] is not None

    def test_recognize_no_file(self, test_client: TestClient):
        """A request without a file fails validation."""
        response = test_client.post("/api/v1/ocr/recognize")

        assert response.status_code == 422  # Validation Error
|
||||
|
||||
|
||||
class TestOCRRecognizeBase64:
    """Tests for the OCR recognition endpoint (Base64 JSON)."""

    def test_recognize_base64_success(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """A valid Base64 image yields a successful result payload."""
        request_body = {
            "image_base64": sample_image_base64,
            "lang": "ch",
            "drop_score": 0.5,
        }
        response = test_client.post("/api/v1/ocr/recognize/base64", json=request_body)

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True
        assert payload["data"] is not None

    def test_recognize_base64_with_data_url(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """A Data URL wrapped Base64 image is accepted."""
        data_url = f"data:image/jpeg;base64,{sample_image_base64}"
        response = test_client.post(
            "/api/v1/ocr/recognize/base64",
            json={"image_base64": data_url},
        )

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True

    def test_recognize_base64_with_roi(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """Recognition succeeds when a nested ROI object is supplied."""
        request_body = {
            "image_base64": sample_image_base64,
            "roi": {"x": 0.1, "y": 0.1, "width": 0.8, "height": 0.8},
        }
        response = test_client.post("/api/v1/ocr/recognize/base64", json=request_body)

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True

    def test_recognize_base64_invalid(self, test_client: TestClient):
        """Malformed Base64 is answered with success=False and an error."""
        response = test_client.post(
            "/api/v1/ocr/recognize/base64",
            json={"image_base64": "not-valid-base64!!!"},
        )

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is False
        assert payload["error"] is not None

    def test_recognize_base64_missing_field(self, test_client: TestClient):
        """Omitting the required image_base64 field fails validation."""
        body_without_image = {"lang": "ch"}
        response = test_client.post(
            "/api/v1/ocr/recognize/base64",
            json=body_without_image,
        )

        assert response.status_code == 422  # Validation Error
|
||||
|
||||
|
||||
class TestExpressMultipart:
    """Tests for the express-waybill endpoint (multipart/form-data)."""

    def test_express_success(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """A valid upload yields a successful waybill payload."""
        upload = ("express.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")
        response = test_client.post("/api/v1/ocr/express", files={"file": upload})

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True
        assert payload["data"] is not None
        for expected_key in ("express_info", "merged_text", "processing_time_ms"):
            assert expected_key in payload["data"]

    def test_express_info_structure(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """The express_info object exposes the expected keys."""
        upload = ("express.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")
        response = test_client.post("/api/v1/ocr/express", files={"file": upload})

        assert response.status_code == 200
        payload = response.json()
        express_info = payload["data"]["express_info"]

        # Top-level structure.
        top_level_keys = (
            "tracking_number",
            "sender",
            "receiver",
            "courier_company",
            "confidence",
        )
        for expected_key in top_level_keys:
            assert expected_key in express_info

        # Structure of the sender entry.
        sender = express_info["sender"]
        for expected_key in ("name", "phone", "address"):
            assert expected_key in sender
|
||||
|
||||
|
||||
class TestExpressBase64:
    """Tests for the express-waybill endpoint (Base64 JSON)."""

    def test_express_base64_success(
        self,
        test_client: TestClient,
        sample_image_base64: str,
    ):
        """A valid Base64 image yields a successful waybill payload."""
        request_body = {"image_base64": sample_image_base64}
        response = test_client.post("/api/v1/ocr/express/base64", json=request_body)

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True
        assert payload["data"] is not None
        assert "express_info" in payload["data"]
|
||||
|
||||
|
||||
class TestSecurityValidation:
    """Tests for upload security validation."""

    def test_file_size_limit(self, test_client: TestClient):
        """An 11 MB upload is rejected with a size-related error message."""
        oversized = b"x" * (11 * 1024 * 1024)
        upload = ("large.jpg", io.BytesIO(oversized), "image/jpeg")

        response = test_client.post("/api/v1/ocr/recognize", files={"file": upload})

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is False
        error_message = payload["error"]["message"]
        assert "大小" in error_message or "size" in error_message.lower()

    def test_invalid_extension(
        self,
        test_client: TestClient,
        sample_image_bytes: bytes,
    ):
        """A real image uploaded under a .exe name is rejected."""
        upload = (
            "test.exe",
            io.BytesIO(sample_image_bytes),
            "application/octet-stream",
        )
        response = test_client.post("/api/v1/ocr/recognize", files={"file": upload})

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is False

    def test_magic_bytes_validation(self, test_client: TestClient):
        """Text content with a .jpg name and JPEG MIME type is rejected."""
        # Extension and MIME type look right, but the content is not an image.
        fake_jpg = b"This is not a real JPEG file"
        upload = ("fake.jpg", io.BytesIO(fake_jpg), "image/jpeg")

        response = test_client.post("/api/v1/ocr/recognize", files={"file": upload})

        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is False
|
||||
Loading…
Reference in New Issue