feat: 添加基于 FastAPI 的 REST API 服务

master
蒋尚宏 4 weeks ago
parent 7570e1314d
commit c05e3e58ed

@ -1,6 +1,6 @@
# Vision-OCR: 图片 OCR 识别系统
基于 PaddleOCR 的图片 OCR 识别系统,支持单张图片、批量图片和目录扫描,提供文本检测、识别、方向分类,输出结构化识别结果。
基于 PaddleOCR 的图片 OCR 识别系统,支持单张图片、批量图片和目录扫描,提供文本检测、识别、方向分类,输出结构化识别结果。同时提供 REST API 服务,支持 HTTP 接口调用。
## 功能特性
@ -11,6 +11,7 @@
- **可视化展示**: 在图片上绘制文本框和识别结果
- **结果导出**: 支持 JSON 结果导出和标注图片保存
- **ROI 裁剪**: 支持只识别图片指定区域
- **REST API**: 提供 HTTP 接口,支持文件上传和 Base64 两种方式
- **模块化设计**: 图片加载与 OCR 逻辑完全解耦,便于扩展
- **全本地运行**: 不依赖任何云服务
@ -18,6 +19,18 @@
```
vision-ocr/
├── api/ # REST API 模块
│ ├── __init__.py
│ ├── main.py # FastAPI 应用入口
│ ├── dependencies.py # 依赖注入
│ ├── exceptions.py # 自定义异常
│ ├── security.py # 安全验证
│ ├── routes/ # 路由模块
│ │ ├── health.py # 健康检查端点
│ │ └── ocr.py # OCR 识别端点
│ └── schemas/ # 数据模型
│ ├── request.py # 请求模型
│ └── response.py # 响应模型
├── input/ # 图片输入模块
│ ├── __init__.py
│ └── loader.py # 图片加载器
@ -32,8 +45,12 @@ vision-ocr/
├── utils/ # 工具模块
│ ├── __init__.py
│ └── config.py # 配置管理
├── models/ # 模型文件目录(运行 download_models.py 后生成)
├── main.py # 主入口
├── tests/ # 测试模块
│ ├── conftest.py # pytest 配置
│ ├── test_api_health.py # 健康检查测试
│ └── test_api_ocr.py # OCR API 测试
├── models/ # 模型文件目录
├── main.py # CLI 主入口
├── download_models.py # 模型下载脚本
├── requirements.txt # 依赖清单
└── README.md
@ -540,6 +557,236 @@ pipeline.add_postprocessor(filter_short_text)
4. **提高置信度阈值**: 使用 `--drop-score` 过滤低质量结果
5. **批量处理**: 使用目录模式批量处理多张图片
## REST API 服务
除了命令行工具,本项目还提供基于 FastAPI 的 REST API 服务。
### 启动服务
```bash
# 开发模式 (支持热重载)
uvicorn api.main:app --reload --host 0.0.0.0 --port 8000
# 生产模式 (多进程)
uvicorn api.main:app --host 0.0.0.0 --port 8000 --workers 4
# 或直接运行
python -m api.main
```
启动后访问:
- API 文档 (Swagger UI): http://localhost:8000/docs
- API 文档 (ReDoc): http://localhost:8000/redoc
- 健康检查: http://localhost:8000/api/v1/health
### API 端点
| 方法 | 端点 | 功能 | 输入格式 |
|------|------|------|---------|
| `GET` | `/api/v1/health` | 健康检查 | - |
| `GET` | `/api/v1/health/ready` | 就绪检查 | - |
| `POST` | `/api/v1/ocr/recognize` | OCR 识别 | multipart/form-data |
| `POST` | `/api/v1/ocr/recognize/base64` | OCR 识别 | JSON (Base64) |
| `POST` | `/api/v1/ocr/express` | 快递单解析 | multipart/form-data |
| `POST` | `/api/v1/ocr/express/base64` | 快递单解析 | JSON (Base64) |
### 使用示例
#### 方式一: 文件上传 (multipart/form-data)
```bash
# OCR 识别
curl -X POST "http://localhost:8000/api/v1/ocr/recognize" \
-F "file=@image.jpg" \
-F "lang=ch" \
-F "drop_score=0.5"
# 快递单解析
curl -X POST "http://localhost:8000/api/v1/ocr/express" \
-F "file=@express.jpg"
# 带 ROI 参数
curl -X POST "http://localhost:8000/api/v1/ocr/recognize" \
-F "file=@image.jpg" \
-F "roi_x=0.1" \
-F "roi_y=0.1" \
-F "roi_width=0.8" \
-F "roi_height=0.8"
# 返回标注图片
curl -X POST "http://localhost:8000/api/v1/ocr/recognize" \
-F "file=@image.jpg" \
-F "return_annotated_image=true"
```
#### 方式二: Base64 JSON
```bash
# OCR 识别
curl -X POST "http://localhost:8000/api/v1/ocr/recognize/base64" \
-H "Content-Type: application/json" \
-d '{
"image_base64": "'"$(base64 -w 0 image.jpg)"'",
"lang": "ch",
"drop_score": 0.5
}'
# 快递单解析
curl -X POST "http://localhost:8000/api/v1/ocr/express/base64" \
-H "Content-Type: application/json" \
-d '{
"image_base64": "'"$(base64 -w 0 express.jpg)"'"
}'
# 带 ROI 参数
curl -X POST "http://localhost:8000/api/v1/ocr/recognize/base64" \
-H "Content-Type: application/json" \
-d '{
"image_base64": "...",
"roi": {"x": 0.1, "y": 0.1, "width": 0.8, "height": 0.8}
}'
```
#### Python 调用示例
```python
import requests
import base64
# 方式一: 文件上传
with open("image.jpg", "rb") as f:
response = requests.post(
"http://localhost:8000/api/v1/ocr/recognize",
files={"file": ("image.jpg", f, "image/jpeg")},
data={"lang": "ch", "drop_score": 0.5}
)
result = response.json()
# 方式二: Base64
with open("image.jpg", "rb") as f:
image_base64 = base64.b64encode(f.read()).decode()
response = requests.post(
"http://localhost:8000/api/v1/ocr/recognize/base64",
json={
"image_base64": image_base64,
"lang": "ch"
}
)
result = response.json()
# 处理结果
if result["success"]:
for block in result["data"]["text_blocks"]:
print(f"文本: {block['text']}, 置信度: {block['confidence']}")
```
### API 响应格式
#### OCR 识别响应
```json
{
"success": true,
"data": {
"processing_time_ms": 45.2,
"text_count": 3,
"average_confidence": 0.92,
"roi_applied": false,
"roi_rect": null,
"text_blocks": [
{
"text": "识别的文本",
"confidence": 0.95,
"bbox": [[100, 50], [200, 50], [200, 80], [100, 80]],
"bbox_with_offset": [[100, 50], [200, 50], [200, 80], [100, 80]],
"center": [150, 65],
"width": 100,
"height": 30
}
],
"annotated_image_base64": null
},
"error": null
}
```
#### 快递单解析响应
```json
{
"success": true,
"data": {
"processing_time_ms": 52.3,
"express_info": {
"tracking_number": "SF1234567890",
"sender": {
"name": "张三",
"phone": "13800138000",
"address": "北京市朝阳区xxx路"
},
"receiver": {
"name": "李四",
"phone": "13900139000",
"address": "上海市浦东新区xxx路"
},
"courier_company": "顺丰速运",
"confidence": 0.95,
"extra_fields": {},
"raw_text": "..."
},
"merged_text": "顺丰速运\n运单号SF1234567890\n...",
"annotated_image_base64": null
},
"error": null
}
```
#### 错误响应
```json
{
"success": false,
"data": null,
"error": {
"code": "InvalidImageError",
"message": "无效的图片文件"
}
}
```
### API 请求参数
| 参数 | 类型 | 默认值 | 说明 |
|------|------|--------|------|
| `file` | File | - | 图片文件 (multipart 模式) |
| `image_base64` | string | - | Base64 编码图片 (JSON 模式) |
| `lang` | string | "ch" | 识别语言 |
| `use_gpu` | bool | false | 是否使用 GPU |
| `drop_score` | float | 0.5 | 置信度阈值 (0.0~1.0) |
| `roi_x` | float | - | ROI 左上角 X (0.0~1.0) |
| `roi_y` | float | - | ROI 左上角 Y (0.0~1.0) |
| `roi_width` | float | - | ROI 宽度 (0.0~1.0) |
| `roi_height` | float | - | ROI 高度 (0.0~1.0) |
| `return_annotated_image` | bool | false | 是否返回标注图片 |
### 安全限制
- 最大文件大小: 10MB
- 支持的格式: jpg, jpeg, png, bmp, webp, tiff
- 文件验证: 扩展名 + MIME 类型 + 文件魔数
### 运行测试
```bash
# 运行所有 API 测试
pytest tests/ -v
# 运行特定测试
pytest tests/test_api_health.py -v
pytest tests/test_api_ocr.py -v
```
## 常见问题
### Q: Windows 中文用户名导致模型加载失败?

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
"""
Vision-OCR REST API 模块
提供 HTTP 接口访问 OCR 功能
"""
from api.main import app
__all__ = ["app"]

@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
"""
依赖注入模块
提供 FastAPI 依赖项
"""
import base64
from typing import Optional
import cv2
import numpy as np
from fastapi import Depends, File, Form, Request, UploadFile
from api.exceptions import InvalidImageError, ModelNotLoadedError
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
from api.security import decode_base64_image, validate_uploaded_file
from ocr.pipeline import OCRPipeline, OCRResult
from utils.config import OCRConfig, PipelineConfig, ROIConfig
def get_ocr_pipeline(request: Request) -> OCRPipeline:
"""
获取 OCR Pipeline 实例
Args:
request: FastAPI 请求对象
Returns:
OCRPipeline 实例
Raises:
ModelNotLoadedError: 模型未加载
"""
pipeline = getattr(request.app.state, "ocr_pipeline", None)
if pipeline is None:
raise ModelNotLoadedError()
return pipeline
async def parse_multipart_image(
file: UploadFile = File(..., description="图片文件"),
) -> bytes:
"""
解析 multipart 上传的图片文件
Args:
file: 上传的文件
Returns:
验证后的图片字节数据
"""
content = await file.read()
return validate_uploaded_file(
content=content,
filename=file.filename,
content_type=file.content_type,
)
async def parse_multipart_params(
lang: str = Form(default="ch", description="识别语言"),
use_gpu: bool = Form(default=False, description="是否使用 GPU"),
drop_score: float = Form(default=0.5, ge=0.0, le=1.0, description="置信度阈值"),
roi_x: Optional[float] = Form(default=None, ge=0.0, le=1.0, description="ROI X"),
roi_y: Optional[float] = Form(default=None, ge=0.0, le=1.0, description="ROI Y"),
roi_width: Optional[float] = Form(
default=None, gt=0.0, le=1.0, description="ROI 宽度"
),
roi_height: Optional[float] = Form(
default=None, gt=0.0, le=1.0, description="ROI 高度"
),
return_annotated_image: bool = Form(
default=False, description="是否返回标注图片"
),
) -> OCRRequestParams:
"""
解析 multipart 表单参数
Returns:
OCR 请求参数对象
"""
return OCRRequestParams(
lang=lang,
use_gpu=use_gpu,
drop_score=drop_score,
roi_x=roi_x,
roi_y=roi_y,
roi_width=roi_width,
roi_height=roi_height,
return_annotated_image=return_annotated_image,
)
def decode_image_bytes(content: bytes) -> np.ndarray:
"""
将图片字节解码为 numpy 数组
Args:
content: 图片字节数据
Returns:
OpenCV 图像 (BGR 格式)
Raises:
InvalidImageError: 图片解码失败
"""
try:
nparr = np.frombuffer(content, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if image is None:
raise InvalidImageError("图片解码失败")
return image
except Exception as e:
if isinstance(e, InvalidImageError):
raise
raise InvalidImageError(f"图片解码失败: {str(e)}")
def encode_image_base64(image: np.ndarray, format: str = ".jpg") -> str:
"""
将图像编码为 Base64 字符串
Args:
image: OpenCV 图像
format: 图片格式 ( '.jpg', '.png')
Returns:
Base64 编码的图片字符串
"""
success, encoded = cv2.imencode(format, image)
if not success:
return ""
return base64.b64encode(encoded.tobytes()).decode("utf-8")
def build_pipeline_config(roi: Optional[ROIParams]) -> PipelineConfig:
"""
根据 ROI 参数构建管道配置
Args:
roi: ROI 参数
Returns:
管道配置对象
"""
if roi is None:
return PipelineConfig(roi=ROIConfig(enabled=False))
return PipelineConfig(
roi=ROIConfig(
enabled=True,
x_ratio=roi.x,
y_ratio=roi.y,
width_ratio=roi.width,
height_ratio=roi.height,
)
)

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
"""
自定义异常模块
定义 API 层的异常类型
"""
class OCRAPIException(Exception):
"""OCR API 基础异常"""
def __init__(self, message: str, status_code: int = 500):
self.message = message
self.status_code = status_code
super().__init__(self.message)
class InvalidImageError(OCRAPIException):
"""无效的图片文件"""
def __init__(self, message: str = "无效的图片文件"):
super().__init__(message, status_code=400)
class FileTooLargeError(OCRAPIException):
"""文件过大"""
def __init__(self, message: str = "文件大小超过限制"):
super().__init__(message, status_code=413)
class UnsupportedFormatError(OCRAPIException):
"""不支持的文件格式"""
def __init__(self, message: str = "不支持的文件格式"):
super().__init__(message, status_code=415)
class OCRProcessingError(OCRAPIException):
"""OCR 处理错误"""
def __init__(self, message: str = "OCR 处理失败"):
super().__init__(message, status_code=500)
class ModelNotLoadedError(OCRAPIException):
"""模型未加载"""
def __init__(self, message: str = "OCR 模型尚未加载完成"):
super().__init__(message, status_code=503)

@ -0,0 +1,172 @@
# -*- coding: utf-8 -*-
"""
Vision-OCR REST API 主入口
基于 FastAPI OCR 服务
"""
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
# 设置项目根目录和模型路径
_PROJECT_ROOT = Path(__file__).parent.parent
_MODELS_DIR = _PROJECT_ROOT / "models"
_MODELS_DIR.mkdir(exist_ok=True)
os.environ["PADDLEOCR_HOME"] = str(_MODELS_DIR)
# 将项目根目录添加到 Python 路径
sys.path.insert(0, str(_PROJECT_ROOT))
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from api.exceptions import OCRAPIException
from api.routes import health_router, ocr_router
from ocr.pipeline import OCRPipeline
from utils.config import OCRConfig, PipelineConfig
def _get_ocr_config() -> OCRConfig:
"""
获取 OCR 配置
检查模型是否存在于项目目录如果存在则使用本地模型
"""
det_model_dir = str(_MODELS_DIR / "ch_PP-OCRv4_det_infer")
rec_model_dir = str(_MODELS_DIR / "ch_PP-OCRv4_rec_infer")
cls_model_dir = str(_MODELS_DIR / "ch_ppocr_mobile_v2.0_cls_infer")
# 检查模型是否已下载
models_exist = (
Path(det_model_dir).exists()
and Path(rec_model_dir).exists()
and Path(cls_model_dir).exists()
)
return OCRConfig(
lang="ch",
use_angle_cls=True,
use_gpu=False,
drop_score=0.5,
det_model_dir=det_model_dir if models_exist else None,
rec_model_dir=rec_model_dir if models_exist else None,
cls_model_dir=cls_model_dir if models_exist else None,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
应用生命周期管理
启动时预加载 OCR 模型关闭时清理资源
"""
# 启动: 加载 OCR 模型
print("[INFO] 正在初始化 OCR API 服务...")
ocr_config = _get_ocr_config()
if ocr_config.det_model_dir is None:
print("[WARN] 模型未在项目目录中找到")
print("[WARN] 对于 Windows 中文用户名用户,请先运行:")
print("[WARN] python download_models.py")
print("[INFO] 回退到默认 PaddleOCR 模型路径...")
pipeline_config = PipelineConfig()
print("[INFO] 正在加载 OCR 模型...")
pipeline = OCRPipeline(ocr_config, pipeline_config)
pipeline.initialize()
app.state.ocr_pipeline = pipeline
app.state.model_loaded = True
print("[INFO] OCR 模型加载完成,服务已就绪")
yield
# 关闭: 清理资源
print("[INFO] 正在关闭 OCR API 服务...")
app.state.model_loaded = False
print("[INFO] 服务已关闭")
# 创建 FastAPI 应用
app = FastAPI(
title="Vision-OCR API",
description="基于 PaddleOCR 的图片 OCR 识别服务,支持通用文字识别和快递单解析",
version="1.0.0",
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
)
# 配置 CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境应限制具体域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 全局异常处理
@app.exception_handler(OCRAPIException)
async def ocr_exception_handler(request: Request, exc: OCRAPIException):
"""处理 OCR API 自定义异常"""
return JSONResponse(
status_code=exc.status_code,
content={
"success": False,
"data": None,
"error": {
"code": type(exc).__name__,
"message": exc.message,
},
},
)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""处理未捕获的异常"""
return JSONResponse(
status_code=500,
content={
"success": False,
"data": None,
"error": {
"code": "InternalServerError",
"message": f"服务器内部错误: {str(exc)}",
},
},
)
# 注册路由
app.include_router(health_router, prefix="/api/v1")
app.include_router(ocr_router, prefix="/api/v1")
@app.get("/", tags=["根路径"])
async def root():
"""API 根路径,返回基本信息"""
return {
"name": "Vision-OCR API",
"version": "1.0.0",
"docs": "/docs",
"health": "/api/v1/health",
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"api.main:app",
host="0.0.0.0",
port=8000,
reload=True,
)

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
"""
路由模块
定义 API 端点
"""
from api.routes.health import router as health_router
from api.routes.ocr import router as ocr_router
__all__ = ["health_router", "ocr_router"]

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
"""
健康检查路由
提供服务状态检查端点
"""
from fastapi import APIRouter, Request
from api.schemas.response import HealthResponse
router = APIRouter(prefix="/health", tags=["健康检查"])
# API 版本
API_VERSION = "1.0.0"
@router.get(
"",
response_model=HealthResponse,
summary="健康检查",
description="检查服务是否正常运行",
)
async def health_check(request: Request) -> HealthResponse:
"""
基础健康检查
返回服务状态和模型加载状态
"""
model_loaded = getattr(request.app.state, "model_loaded", False)
return HealthResponse(
status="healthy" if model_loaded else "unhealthy",
model_loaded=model_loaded,
version=API_VERSION,
)
@router.get(
"/ready",
response_model=HealthResponse,
summary="就绪检查",
description="检查服务是否已准备好处理请求 (模型已加载)",
)
async def readiness_check(request: Request) -> HealthResponse:
"""
就绪检查
只有当模型加载完成后才返回 healthy
"""
model_loaded = getattr(request.app.state, "model_loaded", False)
return HealthResponse(
status="healthy" if model_loaded else "unhealthy",
model_loaded=model_loaded,
version=API_VERSION,
)

@ -0,0 +1,360 @@
# -*- coding: utf-8 -*-
"""
OCR 路由模块
提供 OCR 识别和快递单解析端点
"""
from typing import Optional
from fastapi import APIRouter, Depends, File, Request, UploadFile
from api.dependencies import (
build_pipeline_config,
decode_image_bytes,
encode_image_base64,
get_ocr_pipeline,
parse_multipart_image,
parse_multipart_params,
)
from api.exceptions import OCRProcessingError
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
from api.schemas.response import (
ErrorDetail,
ExpressInfoData,
ExpressPersonData,
ExpressResponse,
ExpressResultData,
OCRResponse,
OCRResultData,
TextBlockData,
)
from api.security import decode_base64_image
from ocr.pipeline import OCRPipeline, OCRResult
from utils.config import OCRConfig, PipelineConfig, ROIConfig
from visualize.draw import OCRVisualizer
from utils.config import VisualizeConfig
router = APIRouter(prefix="/ocr", tags=["OCR 识别"])
def _convert_ocr_result_to_response(
result: OCRResult,
annotated_image_base64: Optional[str] = None,
) -> OCRResultData:
"""
OCRResult 转换为响应数据模型
Args:
result: OCR 处理结果
annotated_image_base64: 标注图片的 Base64 编码
Returns:
响应数据模型
"""
text_blocks = [
TextBlockData(
text=block.text,
confidence=block.confidence,
bbox=block.bbox,
bbox_with_offset=block.bbox_with_offset,
center=list(block.center),
width=block.width,
height=block.height,
)
for block in result.text_blocks
]
return OCRResultData(
processing_time_ms=result.processing_time_ms,
text_count=result.text_count,
average_confidence=result.average_confidence,
roi_applied=result.roi_applied,
roi_rect=list(result.roi_rect) if result.roi_rect else None,
text_blocks=text_blocks,
annotated_image_base64=annotated_image_base64,
)
def _process_ocr(
image_bytes: bytes,
pipeline: OCRPipeline,
roi: Optional[ROIParams] = None,
return_annotated_image: bool = False,
) -> tuple[OCRResult, Optional[str]]:
"""
执行 OCR 处理
Args:
image_bytes: 图片字节数据
pipeline: OCR 管道
roi: ROI 参数
return_annotated_image: 是否返回标注图片
Returns:
(OCR 结果, 标注图片 Base64)
"""
# 解码图片
image = decode_image_bytes(image_bytes)
# 构建管道配置
pipeline_config = build_pipeline_config(roi)
# 临时更新管道配置
original_config = pipeline._pipeline_config
pipeline._pipeline_config = pipeline_config
try:
# 执行 OCR
result = pipeline.process(image)
except Exception as e:
raise OCRProcessingError(f"OCR 处理失败: {str(e)}")
finally:
# 恢复原始配置
pipeline._pipeline_config = original_config
# 生成标注图片
annotated_image_base64 = None
if return_annotated_image and result.text_count > 0:
visualizer = OCRVisualizer(VisualizeConfig())
annotated = visualizer.draw_result(image, result)
annotated_image_base64 = encode_image_base64(annotated)
return result, annotated_image_base64
# ============================================================
# multipart/form-data 端点
# ============================================================
@router.post(
"/recognize",
response_model=OCRResponse,
summary="OCR 识别 (文件上传)",
description="上传图片文件进行 OCR 识别,支持 jpg/png/bmp/webp 格式",
)
async def recognize_multipart(
request: Request,
file: UploadFile = File(..., description="图片文件"),
params: OCRRequestParams = Depends(parse_multipart_params),
pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
"""
OCR 识别端点 (multipart/form-data)
上传图片文件进行文字识别
"""
try:
# 验证并读取文件
image_bytes = await parse_multipart_image(file)
# 执行 OCR
result, annotated_base64 = _process_ocr(
image_bytes=image_bytes,
pipeline=pipeline,
roi=params.get_roi(),
return_annotated_image=params.return_annotated_image,
)
# 构建响应
return OCRResponse(
success=True,
data=_convert_ocr_result_to_response(result, annotated_base64),
)
except Exception as e:
return OCRResponse(
success=False,
error=ErrorDetail(
code=type(e).__name__,
message=str(e),
),
)
@router.post(
"/express",
response_model=ExpressResponse,
summary="快递单解析 (文件上传)",
description="上传快递单图片进行 OCR 识别并解析结构化信息",
)
async def express_multipart(
request: Request,
file: UploadFile = File(..., description="快递单图片"),
params: OCRRequestParams = Depends(parse_multipart_params),
pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
"""
快递单解析端点 (multipart/form-data)
上传快递单图片自动识别并提取结构化信息
"""
try:
# 验证并读取文件
image_bytes = await parse_multipart_image(file)
# 执行 OCR
result, annotated_base64 = _process_ocr(
image_bytes=image_bytes,
pipeline=pipeline,
roi=params.get_roi(),
return_annotated_image=params.return_annotated_image,
)
# 解析快递单信息
express_info = result.parse_express()
merged_text = result.merge_text()
# 构建响应
return ExpressResponse(
success=True,
data=ExpressResultData(
processing_time_ms=result.processing_time_ms,
express_info=ExpressInfoData(
tracking_number=express_info.tracking_number,
sender=ExpressPersonData(
name=express_info.sender_name,
phone=express_info.sender_phone,
address=express_info.sender_address,
),
receiver=ExpressPersonData(
name=express_info.receiver_name,
phone=express_info.receiver_phone,
address=express_info.receiver_address,
),
courier_company=express_info.courier_company,
confidence=express_info.confidence,
extra_fields=express_info.extra_fields,
raw_text=express_info.raw_text,
),
merged_text=merged_text,
annotated_image_base64=annotated_base64,
),
)
except Exception as e:
return ExpressResponse(
success=False,
error=ErrorDetail(
code=type(e).__name__,
message=str(e),
),
)
# ============================================================
# JSON (Base64) 端点
# ============================================================
@router.post(
"/recognize/base64",
response_model=OCRResponse,
summary="OCR 识别 (Base64)",
description="提交 Base64 编码的图片进行 OCR 识别",
)
async def recognize_base64(
request: Request,
body: OCRRequestBase64,
pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
"""
OCR 识别端点 (JSON Base64)
提交 Base64 编码的图片进行文字识别
"""
try:
# 解码并验证 Base64 图片
image_bytes = decode_base64_image(body.image_base64)
# 执行 OCR
result, annotated_base64 = _process_ocr(
image_bytes=image_bytes,
pipeline=pipeline,
roi=body.roi,
return_annotated_image=body.return_annotated_image,
)
# 构建响应
return OCRResponse(
success=True,
data=_convert_ocr_result_to_response(result, annotated_base64),
)
except Exception as e:
return OCRResponse(
success=False,
error=ErrorDetail(
code=type(e).__name__,
message=str(e),
),
)
@router.post(
"/express/base64",
response_model=ExpressResponse,
summary="快递单解析 (Base64)",
description="提交 Base64 编码的快递单图片进行解析",
)
async def express_base64(
request: Request,
body: OCRRequestBase64,
pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
"""
快递单解析端点 (JSON Base64)
提交 Base64 编码的快递单图片自动识别并提取结构化信息
"""
try:
# 解码并验证 Base64 图片
image_bytes = decode_base64_image(body.image_base64)
# 执行 OCR
result, annotated_base64 = _process_ocr(
image_bytes=image_bytes,
pipeline=pipeline,
roi=body.roi,
return_annotated_image=body.return_annotated_image,
)
# 解析快递单信息
express_info = result.parse_express()
merged_text = result.merge_text()
# 构建响应
return ExpressResponse(
success=True,
data=ExpressResultData(
processing_time_ms=result.processing_time_ms,
express_info=ExpressInfoData(
tracking_number=express_info.tracking_number,
sender=ExpressPersonData(
name=express_info.sender_name,
phone=express_info.sender_phone,
address=express_info.sender_address,
),
receiver=ExpressPersonData(
name=express_info.receiver_name,
phone=express_info.receiver_phone,
address=express_info.receiver_address,
),
courier_company=express_info.courier_company,
confidence=express_info.confidence,
extra_fields=express_info.extra_fields,
raw_text=express_info.raw_text,
),
merged_text=merged_text,
annotated_image_base64=annotated_base64,
),
)
except Exception as e:
return ExpressResponse(
success=False,
error=ErrorDetail(
code=type(e).__name__,
message=str(e),
),
)

@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
"""
Pydantic 数据模型模块
定义 API 请求和响应的数据结构
"""
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
from api.schemas.response import (
ErrorDetail,
ExpressInfoData,
ExpressResponse,
HealthResponse,
OCRResponse,
OCRResultData,
TextBlockData,
)
__all__ = [
# Request
"OCRRequestBase64",
"OCRRequestParams",
"ROIParams",
# Response
"OCRResponse",
"OCRResultData",
"TextBlockData",
"ExpressResponse",
"ExpressInfoData",
"HealthResponse",
"ErrorDetail",
]

@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-
"""
请求数据模型
定义 API 请求的数据结构
"""
from typing import Optional
from pydantic import BaseModel, Field
class ROIParams(BaseModel):
"""
ROI (感兴趣区域) 参数
使用归一化坐标 (0.0 ~ 1.0)
"""
x: float = Field(
default=0.0,
ge=0.0,
le=1.0,
description="ROI 左上角 X 坐标 (归一化)",
)
y: float = Field(
default=0.0,
ge=0.0,
le=1.0,
description="ROI 左上角 Y 坐标 (归一化)",
)
width: float = Field(
default=1.0,
gt=0.0,
le=1.0,
description="ROI 宽度 (归一化)",
)
height: float = Field(
default=1.0,
gt=0.0,
le=1.0,
description="ROI 高度 (归一化)",
)
class OCRRequestParams(BaseModel):
"""
OCR 请求参数 (用于 multipart/form-data)
"""
lang: str = Field(
default="ch",
description="识别语言,支持 'ch'(中文), 'en'(英文) 等",
)
use_gpu: bool = Field(
default=False,
description="是否使用 GPU 加速",
)
drop_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="置信度阈值,低于此值的结果将被过滤",
)
roi_x: Optional[float] = Field(
default=None,
ge=0.0,
le=1.0,
description="ROI 左上角 X 坐标 (归一化)",
)
roi_y: Optional[float] = Field(
default=None,
ge=0.0,
le=1.0,
description="ROI 左上角 Y 坐标 (归一化)",
)
roi_width: Optional[float] = Field(
default=None,
gt=0.0,
le=1.0,
description="ROI 宽度 (归一化)",
)
roi_height: Optional[float] = Field(
default=None,
gt=0.0,
le=1.0,
description="ROI 高度 (归一化)",
)
return_annotated_image: bool = Field(
default=False,
description="是否返回标注后的图片 (Base64)",
)
def has_roi(self) -> bool:
"""检查是否设置了 ROI 参数"""
return all(
v is not None
for v in [self.roi_x, self.roi_y, self.roi_width, self.roi_height]
)
def get_roi(self) -> Optional[ROIParams]:
"""获取 ROI 参数对象"""
if not self.has_roi():
return None
return ROIParams(
x=self.roi_x,
y=self.roi_y,
width=self.roi_width,
height=self.roi_height,
)
class OCRRequestBase64(BaseModel):
"""
OCR 请求 (Base64 JSON 格式)
"""
image_base64: str = Field(
...,
description="Base64 编码的图片数据,支持 Data URL 格式",
)
lang: str = Field(
default="ch",
description="识别语言,支持 'ch'(中文), 'en'(英文) 等",
)
use_gpu: bool = Field(
default=False,
description="是否使用 GPU 加速",
)
drop_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="置信度阈值,低于此值的结果将被过滤",
)
roi: Optional[ROIParams] = Field(
default=None,
description="ROI 区域参数",
)
return_annotated_image: bool = Field(
default=False,
description="是否返回标注后的图片 (Base64)",
)

@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
"""
响应数据模型
定义 API 响应的数据结构
"""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class ErrorDetail(BaseModel):
"""错误详情"""
code: str = Field(..., description="错误代码")
message: str = Field(..., description="错误信息")
details: Optional[Dict[str, Any]] = Field(
default=None,
description="额外的错误详情",
)
class TextBlockData(BaseModel):
"""文本块数据"""
text: str = Field(..., description="识别出的文本内容")
confidence: float = Field(..., description="置信度 (0.0 ~ 1.0)")
bbox: List[List[float]] = Field(
...,
description="边界框 4 个顶点坐标 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]",
)
bbox_with_offset: List[List[float]] = Field(
...,
description="带偏移的边界框坐标 (已还原到原图坐标系)",
)
center: List[float] = Field(..., description="文本块中心点坐标 [x, y]")
width: float = Field(..., description="文本块宽度 (像素)")
height: float = Field(..., description="文本块高度 (像素)")
class OCRResultData(BaseModel):
"""OCR 识别结果数据"""
processing_time_ms: float = Field(..., description="处理耗时 (毫秒)")
text_count: int = Field(..., description="识别出的文本块数量")
average_confidence: float = Field(
...,
description="所有文本块的平均置信度",
)
roi_applied: bool = Field(..., description="是否应用了 ROI 裁剪")
roi_rect: Optional[List[int]] = Field(
default=None,
description="ROI 矩形区域 [x, y, width, height]",
)
text_blocks: List[TextBlockData] = Field(
default_factory=list,
description="识别出的文本块列表",
)
annotated_image_base64: Optional[str] = Field(
default=None,
description="标注后的图片 (Base64 编码)",
)
class OCRResponse(BaseModel):
"""OCR 识别响应"""
success: bool = Field(..., description="请求是否成功")
data: Optional[OCRResultData] = Field(
default=None,
description="识别结果数据",
)
error: Optional[ErrorDetail] = Field(
default=None,
description="错误信息 (仅在失败时存在)",
)
class ExpressPersonData(BaseModel):
"""快递单人员信息"""
name: Optional[str] = Field(default=None, description="姓名")
phone: Optional[str] = Field(default=None, description="电话")
address: Optional[str] = Field(default=None, description="地址")
class ExpressInfoData(BaseModel):
"""快递单结构化信息"""
tracking_number: Optional[str] = Field(default=None, description="运单号")
sender: ExpressPersonData = Field(
default_factory=ExpressPersonData,
description="寄件人信息",
)
receiver: ExpressPersonData = Field(
default_factory=ExpressPersonData,
description="收件人信息",
)
courier_company: Optional[str] = Field(
default=None,
description="快递公司名称",
)
confidence: float = Field(default=0.0, description="平均置信度")
extra_fields: Dict[str, str] = Field(
default_factory=dict,
description="其他识别到的额外字段",
)
raw_text: str = Field(default="", description="原始合并文本")
class ExpressResultData(BaseModel):
"""快递单解析结果数据"""
processing_time_ms: float = Field(..., description="处理耗时 (毫秒)")
express_info: ExpressInfoData = Field(..., description="快递单结构化信息")
merged_text: str = Field(..., description="智能合并后的完整文本")
annotated_image_base64: Optional[str] = Field(
default=None,
description="标注后的图片 (Base64 编码)",
)
class ExpressResponse(BaseModel):
"""快递单解析响应"""
success: bool = Field(..., description="请求是否成功")
data: Optional[ExpressResultData] = Field(
default=None,
description="解析结果数据",
)
error: Optional[ErrorDetail] = Field(
default=None,
description="错误信息 (仅在失败时存在)",
)
class HealthResponse(BaseModel):
"""健康检查响应"""
status: str = Field(..., description="服务状态: 'healthy''unhealthy'")
model_loaded: bool = Field(..., description="OCR 模型是否已加载")
version: str = Field(..., description="API 版本")

@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
"""
安全模块
提供文件验证大小限制等安全功能
"""
import base64
from pathlib import Path
from typing import Optional
from api.exceptions import (
FileTooLargeError,
InvalidImageError,
UnsupportedFormatError,
)
# 安全配置常量
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tiff", ".tif"}
ALLOWED_MIME_TYPES = {
"image/jpeg",
"image/png",
"image/bmp",
"image/webp",
"image/tiff",
}
# 图片文件魔数 (Magic Bytes)
IMAGE_MAGIC_BYTES = {
b"\xff\xd8\xff": "jpeg", # JPEG
b"\x89PNG\r\n\x1a\n": "png", # PNG
b"BM": "bmp", # BMP
b"RIFF": "webp", # WebP (需要进一步检查)
b"II*\x00": "tiff", # TIFF (Little Endian)
b"MM\x00*": "tiff", # TIFF (Big Endian)
}
def validate_file_extension(filename: Optional[str]) -> bool:
"""
验证文件扩展名
Args:
filename: 文件名
Returns:
是否为允许的扩展名
"""
if not filename:
return False
ext = Path(filename).suffix.lower()
return ext in ALLOWED_EXTENSIONS
def validate_file_size(content: bytes) -> bool:
"""
验证文件大小
Args:
content: 文件内容
Returns:
是否在允许的大小范围内
"""
return len(content) <= MAX_FILE_SIZE
def validate_image_magic_bytes(content: bytes) -> bool:
"""
验证图片文件魔数
Args:
content: 文件内容
Returns:
是否为有效的图片文件
"""
if len(content) < 8:
return False
for magic, _ in IMAGE_MAGIC_BYTES.items():
if content.startswith(magic):
# WebP 需要额外检查
if magic == b"RIFF" and len(content) >= 12:
if content[8:12] != b"WEBP":
continue
return True
return False
def validate_mime_type(content_type: Optional[str]) -> bool:
"""
验证 MIME 类型
Args:
content_type: MIME 类型
Returns:
是否为允许的 MIME 类型
"""
if not content_type:
return False
# 处理带参数的 MIME 类型,如 "image/jpeg; charset=utf-8"
mime = content_type.split(";")[0].strip().lower()
return mime in ALLOWED_MIME_TYPES
def validate_uploaded_file(
content: bytes,
filename: Optional[str] = None,
content_type: Optional[str] = None,
) -> bytes:
"""
综合验证上传的文件
Args:
content: 文件内容
filename: 文件名
content_type: MIME 类型
Returns:
验证通过的文件内容
Raises:
FileTooLargeError: 文件过大
UnsupportedFormatError: 不支持的格式
InvalidImageError: 无效的图片
"""
# 验证文件大小
if not validate_file_size(content):
raise FileTooLargeError(
f"文件大小超过限制,最大允许 {MAX_FILE_SIZE // 1024 // 1024}MB"
)
# 验证扩展名 (如果提供)
if filename and not validate_file_extension(filename):
raise UnsupportedFormatError(
f"不支持的文件格式,允许的格式: {', '.join(ALLOWED_EXTENSIONS)}"
)
# 验证 MIME 类型 (如果提供)
if content_type and not validate_mime_type(content_type):
raise UnsupportedFormatError(f"不支持的 MIME 类型: {content_type}")
# 验证文件魔数
if not validate_image_magic_bytes(content):
raise InvalidImageError("文件内容不是有效的图片格式")
return content
def decode_base64_image(base64_string: str) -> bytes:
"""
解码 Base64 图片
Args:
base64_string: Base64 编码的图片字符串
Returns:
解码后的图片字节
Raises:
InvalidImageError: Base64 解码失败或不是有效图片
"""
# 移除可能的 Data URL 前缀
if "," in base64_string:
base64_string = base64_string.split(",", 1)[1]
# 移除空白字符
base64_string = base64_string.strip()
try:
content = base64.b64decode(base64_string)
except Exception as e:
raise InvalidImageError(f"Base64 解码失败: {str(e)}")
# 验证解码后的内容
return validate_uploaded_file(content)

@ -8,3 +8,14 @@ numpy>=1.24.0,<2.0
# Optional dependencies (for Chinese font rendering)
Pillow>=10.0.0
# REST API dependencies
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
python-multipart>=0.0.6
pydantic>=2.5.0
# Testing dependencies
pytest>=7.4.0
pytest-asyncio>=0.23.0
httpx>=0.26.0

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
"""
Vision-OCR API 测试模块
"""

@ -0,0 +1,134 @@
# -*- coding: utf-8 -*-
"""
pytest 配置和共享 fixtures
"""
import base64
import os
import sys
from pathlib import Path
from typing import Generator
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from fastapi.testclient import TestClient
# 将项目根目录添加到 Python 路径
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
# 设置模型路径环境变量
os.environ["PADDLEOCR_HOME"] = str(PROJECT_ROOT / "models")
@pytest.fixture(scope="session")
def mock_ocr_pipeline():
"""
创建模拟的 OCR Pipeline
避免在测试中加载真实的 OCR 模型
"""
from ocr.engine import TextBlock
from ocr.pipeline import OCRResult
mock_pipeline = MagicMock()
# 模拟 OCR 结果
mock_text_blocks = [
TextBlock(
text="测试文本1",
confidence=0.95,
bbox=[[10, 10], [100, 10], [100, 30], [10, 30]],
bbox_offset=(0, 0),
),
TextBlock(
text="测试文本2",
confidence=0.88,
bbox=[[10, 40], [150, 40], [150, 60], [10, 60]],
bbox_offset=(0, 0),
),
]
mock_result = OCRResult(
image_index=1,
image_path=None,
timestamp=1704672000.0,
processing_time_ms=45.6,
text_blocks=mock_text_blocks,
roi_applied=False,
roi_rect=None,
)
mock_pipeline.process.return_value = mock_result
mock_pipeline.initialize.return_value = None
mock_pipeline._pipeline_config = MagicMock()
return mock_pipeline
@pytest.fixture(scope="session")
def test_client(mock_ocr_pipeline) -> Generator[TestClient, None, None]:
"""
创建测试客户端
使用模拟的 OCR Pipeline 避免加载真实模型
"""
# 延迟导入以确保环境变量已设置
from api.main import app
# 设置模拟的 pipeline
app.state.ocr_pipeline = mock_ocr_pipeline
app.state.model_loaded = True
with TestClient(app) as client:
yield client
@pytest.fixture
def sample_image_bytes() -> bytes:
"""
创建测试用的图片字节数据
生成一个简单的 100x100 白色 JPEG 图片
"""
import cv2
# 创建白色图片
image = np.ones((100, 100, 3), dtype=np.uint8) * 255
# 编码为 JPEG
success, encoded = cv2.imencode(".jpg", image)
assert success, "图片编码失败"
return encoded.tobytes()
@pytest.fixture
def sample_image_base64(sample_image_bytes) -> str:
"""
创建测试用的 Base64 编码图片
"""
return base64.b64encode(sample_image_bytes).decode("utf-8")
@pytest.fixture
def sample_png_bytes() -> bytes:
"""
创建测试用的 PNG 图片字节数据
"""
import cv2
image = np.ones((100, 100, 3), dtype=np.uint8) * 255
success, encoded = cv2.imencode(".png", image)
assert success, "PNG 编码失败"
return encoded.tobytes()
@pytest.fixture
def invalid_file_bytes() -> bytes:
"""
创建无效的文件字节数据 (非图片)
"""
return b"This is not an image file content"

@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
"""
健康检查 API 测试
"""
import pytest
from fastapi.testclient import TestClient
class TestHealthEndpoints:
"""健康检查端点测试"""
def test_health_check_success(self, test_client: TestClient):
"""测试健康检查端点 - 正常情况"""
response = test_client.get("/api/v1/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert data["model_loaded"] is True
assert "version" in data
def test_readiness_check_success(self, test_client: TestClient):
"""测试就绪检查端点 - 正常情况"""
response = test_client.get("/api/v1/health/ready")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert data["model_loaded"] is True
def test_root_endpoint(self, test_client: TestClient):
"""测试根路径端点"""
response = test_client.get("/")
assert response.status_code == 200
data = response.json()
assert data["name"] == "Vision-OCR API"
assert "version" in data
assert "docs" in data
class TestHealthEndpointsUnhealthy:
"""健康检查端点测试 - 模型未加载情况"""
def test_health_check_model_not_loaded(self, test_client: TestClient):
"""测试健康检查 - 模型未加载"""
# 临时设置模型未加载状态
original_state = test_client.app.state.model_loaded
test_client.app.state.model_loaded = False
try:
response = test_client.get("/api/v1/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "unhealthy"
assert data["model_loaded"] is False
finally:
# 恢复原始状态
test_client.app.state.model_loaded = original_state
def test_readiness_check_model_not_loaded(self, test_client: TestClient):
"""测试就绪检查 - 模型未加载"""
original_state = test_client.app.state.model_loaded
test_client.app.state.model_loaded = False
try:
response = test_client.get("/api/v1/health/ready")
assert response.status_code == 200
data = response.json()
assert data["status"] == "unhealthy"
assert data["model_loaded"] is False
finally:
test_client.app.state.model_loaded = original_state

@ -0,0 +1,315 @@
# -*- coding: utf-8 -*-
"""
OCR API 测试
"""
import io
import pytest
from fastapi.testclient import TestClient
class TestOCRRecognizeMultipart:
"""OCR 识别端点测试 (multipart/form-data)"""
def test_recognize_success(
self,
test_client: TestClient,
sample_image_bytes: bytes,
):
"""测试 OCR 识别 - 正常情况"""
response = test_client.post(
"/api/v1/ocr/recognize",
files={"file": ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")},
data={"lang": "ch", "drop_score": "0.5"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"] is not None
assert "text_count" in data["data"]
assert "text_blocks" in data["data"]
assert "processing_time_ms" in data["data"]
def test_recognize_with_roi(
self,
test_client: TestClient,
sample_image_bytes: bytes,
):
"""测试 OCR 识别 - 带 ROI 参数"""
response = test_client.post(
"/api/v1/ocr/recognize",
files={"file": ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")},
data={
"lang": "ch",
"roi_x": "0.1",
"roi_y": "0.1",
"roi_width": "0.8",
"roi_height": "0.8",
},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
def test_recognize_with_annotated_image(
self,
test_client: TestClient,
sample_image_bytes: bytes,
):
"""测试 OCR 识别 - 返回标注图片"""
response = test_client.post(
"/api/v1/ocr/recognize",
files={"file": ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")},
data={"return_annotated_image": "true"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# 注意: 标注图片只有在有识别结果时才返回
if data["data"]["text_count"] > 0:
assert data["data"]["annotated_image_base64"] is not None
def test_recognize_png_image(
self,
test_client: TestClient,
sample_png_bytes: bytes,
):
"""测试 OCR 识别 - PNG 格式"""
response = test_client.post(
"/api/v1/ocr/recognize",
files={"file": ("test.png", io.BytesIO(sample_png_bytes), "image/png")},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
def test_recognize_invalid_file(
self,
test_client: TestClient,
invalid_file_bytes: bytes,
):
"""测试 OCR 识别 - 无效文件"""
response = test_client.post(
"/api/v1/ocr/recognize",
files={
"file": (
"test.txt",
io.BytesIO(invalid_file_bytes),
"text/plain",
)
},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
assert data["error"] is not None
def test_recognize_no_file(self, test_client: TestClient):
"""测试 OCR 识别 - 未提供文件"""
response = test_client.post("/api/v1/ocr/recognize")
assert response.status_code == 422 # Validation Error
class TestOCRRecognizeBase64:
"""OCR 识别端点测试 (Base64 JSON)"""
def test_recognize_base64_success(
self,
test_client: TestClient,
sample_image_base64: str,
):
"""测试 OCR 识别 (Base64) - 正常情况"""
response = test_client.post(
"/api/v1/ocr/recognize/base64",
json={
"image_base64": sample_image_base64,
"lang": "ch",
"drop_score": 0.5,
},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"] is not None
def test_recognize_base64_with_data_url(
self,
test_client: TestClient,
sample_image_base64: str,
):
"""测试 OCR 识别 (Base64) - Data URL 格式"""
data_url = f"data:image/jpeg;base64,{sample_image_base64}"
response = test_client.post(
"/api/v1/ocr/recognize/base64",
json={"image_base64": data_url},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
def test_recognize_base64_with_roi(
self,
test_client: TestClient,
sample_image_base64: str,
):
"""测试 OCR 识别 (Base64) - 带 ROI 参数"""
response = test_client.post(
"/api/v1/ocr/recognize/base64",
json={
"image_base64": sample_image_base64,
"roi": {"x": 0.1, "y": 0.1, "width": 0.8, "height": 0.8},
},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
def test_recognize_base64_invalid(self, test_client: TestClient):
"""测试 OCR 识别 (Base64) - 无效 Base64"""
response = test_client.post(
"/api/v1/ocr/recognize/base64",
json={"image_base64": "not-valid-base64!!!"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
assert data["error"] is not None
def test_recognize_base64_missing_field(self, test_client: TestClient):
"""测试 OCR 识别 (Base64) - 缺少必填字段"""
response = test_client.post(
"/api/v1/ocr/recognize/base64",
json={"lang": "ch"}, # 缺少 image_base64
)
assert response.status_code == 422 # Validation Error
class TestExpressMultipart:
"""快递单解析端点测试 (multipart/form-data)"""
def test_express_success(
self,
test_client: TestClient,
sample_image_bytes: bytes,
):
"""测试快递单解析 - 正常情况"""
response = test_client.post(
"/api/v1/ocr/express",
files={"file": ("express.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"] is not None
assert "express_info" in data["data"]
assert "merged_text" in data["data"]
assert "processing_time_ms" in data["data"]
def test_express_info_structure(
self,
test_client: TestClient,
sample_image_bytes: bytes,
):
"""测试快递单解析 - 响应结构"""
response = test_client.post(
"/api/v1/ocr/express",
files={"file": ("express.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")},
)
assert response.status_code == 200
data = response.json()
express_info = data["data"]["express_info"]
# 验证结构完整性
assert "tracking_number" in express_info
assert "sender" in express_info
assert "receiver" in express_info
assert "courier_company" in express_info
assert "confidence" in express_info
# 验证 sender/receiver 结构
assert "name" in express_info["sender"]
assert "phone" in express_info["sender"]
assert "address" in express_info["sender"]
class TestExpressBase64:
"""快递单解析端点测试 (Base64 JSON)"""
def test_express_base64_success(
self,
test_client: TestClient,
sample_image_base64: str,
):
"""测试快递单解析 (Base64) - 正常情况"""
response = test_client.post(
"/api/v1/ocr/express/base64",
json={"image_base64": sample_image_base64},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"] is not None
assert "express_info" in data["data"]
class TestSecurityValidation:
"""安全验证测试"""
def test_file_size_limit(self, test_client: TestClient):
"""测试文件大小限制"""
# 创建一个超大的假文件 (11MB)
large_content = b"x" * (11 * 1024 * 1024)
response = test_client.post(
"/api/v1/ocr/recognize",
files={"file": ("large.jpg", io.BytesIO(large_content), "image/jpeg")},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
assert "大小" in data["error"]["message"] or "size" in data["error"]["message"].lower()
def test_invalid_extension(
self,
test_client: TestClient,
sample_image_bytes: bytes,
):
"""测试无效文件扩展名"""
response = test_client.post(
"/api/v1/ocr/recognize",
files={"file": ("test.exe", io.BytesIO(sample_image_bytes), "application/octet-stream")},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
def test_magic_bytes_validation(self, test_client: TestClient):
"""测试文件魔数验证"""
# 创建一个假的 jpg 文件 (扩展名正确但内容不是图片)
fake_jpg = b"This is not a real JPEG file"
response = test_client.post(
"/api/v1/ocr/recognize",
files={"file": ("fake.jpg", io.BytesIO(fake_jpg), "image/jpeg")},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
Loading…
Cancel
Save