From c05e3e58eddb08cc9aed913d529c2a198e66c76e Mon Sep 17 00:00:00 2001 From: Harden <1915702192@qq.com> Date: Thu, 8 Jan 2026 11:26:17 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=9F=BA=E4=BA=8E=20?= =?UTF-8?q?FastAPI=20=E7=9A=84=20REST=20API=20=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 253 ++++++++++++++++++++++++++- api/__init__.py | 9 + api/dependencies.py | 157 +++++++++++++++++ api/exceptions.py | 49 ++++++ api/main.py | 172 +++++++++++++++++++ api/routes/__init__.py | 10 ++ api/routes/health.py | 54 ++++++ api/routes/ocr.py | 360 +++++++++++++++++++++++++++++++++++++++ api/schemas/__init__.py | 31 ++++ api/schemas/request.py | 141 +++++++++++++++ api/schemas/response.py | 142 +++++++++++++++ api/security.py | 180 ++++++++++++++++++++ requirements.txt | 11 ++ tests/__init__.py | 4 + tests/conftest.py | 134 +++++++++++++++ tests/test_api_health.py | 76 +++++++++ tests/test_api_ocr.py | 315 ++++++++++++++++++++++++++++++++++ 17 files changed, 2095 insertions(+), 3 deletions(-) create mode 100644 api/__init__.py create mode 100644 api/dependencies.py create mode 100644 api/exceptions.py create mode 100644 api/main.py create mode 100644 api/routes/__init__.py create mode 100644 api/routes/health.py create mode 100644 api/routes/ocr.py create mode 100644 api/schemas/__init__.py create mode 100644 api/schemas/request.py create mode 100644 api/schemas/response.py create mode 100644 api/security.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_api_health.py create mode 100644 tests/test_api_ocr.py diff --git a/README.md b/README.md index e11e492..49fca96 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Vision-OCR: 图片 OCR 识别系统 -基于 PaddleOCR 的图片 OCR 识别系统,支持单张图片、批量图片和目录扫描,提供文本检测、识别、方向分类,输出结构化识别结果。 +基于 PaddleOCR 的图片 OCR 识别系统,支持单张图片、批量图片和目录扫描,提供文本检测、识别、方向分类,输出结构化识别结果。同时提供 REST API 服务,支持 HTTP 接口调用。 ## 功能特性 @@ -11,6 +11,7 @@ - **可视化展示**: 在图片上绘制文本框和识别结果 - **结果导出**: 支持 JSON 结果导出和标注图片保存 - **ROI 裁剪**: 支持只识别图片指定区域 +- **REST API**: 提供 HTTP 接口,支持文件上传和 Base64 两种方式 - **模块化设计**: 图片加载与 OCR 逻辑完全解耦,便于扩展 - **全本地运行**: 不依赖任何云服务 @@ -18,6 +19,18 @@ ``` vision-ocr/ +├── api/ # REST API 模块 +│ ├── __init__.py +│ ├── main.py # FastAPI 应用入口 +│ ├── dependencies.py # 依赖注入 +│ ├── exceptions.py # 自定义异常 +│ ├── security.py # 安全验证 +│ ├── routes/ # 路由模块 +│ │ ├── health.py # 健康检查端点 +│ │ └── ocr.py # OCR 识别端点 +│ └── schemas/ # 数据模型 +│ ├── request.py # 请求模型 +│ └── response.py # 响应模型 ├── input/ # 图片输入模块 │ ├── __init__.py │ └── loader.py # 图片加载器 @@ -32,8 +45,12 @@ vision-ocr/ ├── utils/ # 工具模块 │ ├── __init__.py │ └── config.py # 配置管理 -├── models/ # 模型文件目录(运行 download_models.py 后生成) -├── main.py # 主入口 +├── tests/ # 测试模块 +│ ├── conftest.py # pytest 配置 +│ ├── test_api_health.py # 健康检查测试 +│ └── test_api_ocr.py # OCR API 测试 +├── models/ # 模型文件目录 +├── main.py # CLI 主入口 ├── download_models.py # 模型下载脚本 ├── requirements.txt # 依赖清单 └── README.md @@ -540,6 +557,236 @@ pipeline.add_postprocessor(filter_short_text) 4. **提高置信度阈值**: 使用 `--drop-score` 过滤低质量结果 5. **批量处理**: 使用目录模式批量处理多张图片 +## REST API 服务 + +除了命令行工具,本项目还提供基于 FastAPI 的 REST API 服务。 + +### 启动服务 + +```bash +# 开发模式 (支持热重载) +uvicorn api.main:app --reload --host 0.0.0.0 --port 8000 + +# 生产模式 (多进程) +uvicorn api.main:app --host 0.0.0.0 --port 8000 --workers 4 + +# 或直接运行 +python -m api.main +``` + +启动后访问: +- API 文档 (Swagger UI): http://localhost:8000/docs +- API 文档 (ReDoc): http://localhost:8000/redoc +- 健康检查: http://localhost:8000/api/v1/health + +### API 端点 + +| 方法 | 端点 | 功能 | 输入格式 | +|------|------|------|---------| +| `GET` | `/api/v1/health` | 健康检查 | - | +| `GET` | `/api/v1/health/ready` | 就绪检查 | - | +| `POST` | `/api/v1/ocr/recognize` | OCR 识别 | multipart/form-data | +| `POST` | `/api/v1/ocr/recognize/base64` | OCR 识别 | JSON (Base64) | +| `POST` | `/api/v1/ocr/express` | 快递单解析 | multipart/form-data | +| `POST` | `/api/v1/ocr/express/base64` | 快递单解析 | JSON (Base64) | + +### 使用示例 + +#### 方式一: 文件上传 (multipart/form-data) + +```bash +# OCR 识别 +curl -X POST "http://localhost:8000/api/v1/ocr/recognize" \ + -F "file=@image.jpg" \ + -F "lang=ch" \ + -F "drop_score=0.5" + +# 快递单解析 +curl -X POST "http://localhost:8000/api/v1/ocr/express" \ + -F "file=@express.jpg" + +# 带 ROI 参数 +curl -X POST "http://localhost:8000/api/v1/ocr/recognize" \ + -F "file=@image.jpg" \ + -F "roi_x=0.1" \ + -F "roi_y=0.1" \ + -F "roi_width=0.8" \ + -F "roi_height=0.8" + +# 返回标注图片 +curl -X POST "http://localhost:8000/api/v1/ocr/recognize" \ + -F "file=@image.jpg" \ + -F "return_annotated_image=true" +``` + +#### 方式二: Base64 JSON + +```bash +# OCR 识别 +curl -X POST "http://localhost:8000/api/v1/ocr/recognize/base64" \ + -H "Content-Type: application/json" \ + -d '{ + "image_base64": "'"$(base64 -w 0 image.jpg)"'", + "lang": "ch", + "drop_score": 0.5 + }' + +# 快递单解析 +curl -X POST "http://localhost:8000/api/v1/ocr/express/base64" \ + -H "Content-Type: application/json" \ + -d '{ + "image_base64": "'"$(base64 -w 0 express.jpg)"'" + }' + +# 带 ROI 参数 +curl -X POST "http://localhost:8000/api/v1/ocr/recognize/base64" \ + -H "Content-Type: application/json" \ + -d '{ + "image_base64": "...", + "roi": {"x": 0.1, "y": 0.1, "width": 0.8, "height": 0.8} + }' +``` + +#### Python 调用示例 + +```python +import requests +import base64 + +# 方式一: 文件上传 +with open("image.jpg", "rb") as f: + response = requests.post( + "http://localhost:8000/api/v1/ocr/recognize", + files={"file": ("image.jpg", f, "image/jpeg")}, + data={"lang": "ch", "drop_score": 0.5} + ) +result = response.json() + +# 方式二: Base64 +with open("image.jpg", "rb") as f: + image_base64 = base64.b64encode(f.read()).decode() + +response = requests.post( + "http://localhost:8000/api/v1/ocr/recognize/base64", + json={ + "image_base64": image_base64, + "lang": "ch" + } +) +result = response.json() + +# 处理结果 +if result["success"]: + for block in result["data"]["text_blocks"]: + print(f"文本: {block['text']}, 置信度: {block['confidence']}") +``` + +### API 响应格式 + +#### OCR 识别响应 + +```json +{ + "success": true, + "data": { + "processing_time_ms": 45.2, + "text_count": 3, + "average_confidence": 0.92, + "roi_applied": false, + "roi_rect": null, + "text_blocks": [ + { + "text": "识别的文本", + "confidence": 0.95, + "bbox": [[100, 50], [200, 50], [200, 80], [100, 80]], + "bbox_with_offset": [[100, 50], [200, 50], [200, 80], [100, 80]], + "center": [150, 65], + "width": 100, + "height": 30 + } + ], + "annotated_image_base64": null + }, + "error": null +} +``` + +#### 快递单解析响应 + +```json +{ + "success": true, + "data": { + "processing_time_ms": 52.3, + "express_info": { + "tracking_number": "SF1234567890", + "sender": { + "name": "张三", + "phone": "13800138000", + "address": "北京市朝阳区xxx路" + }, + "receiver": { + "name": "李四", + "phone": "13900139000", + "address": "上海市浦东新区xxx路" + }, + "courier_company": "顺丰速运", + "confidence": 0.95, + "extra_fields": {}, + "raw_text": "..." + }, + "merged_text": "顺丰速运\n运单号:SF1234567890\n...", + "annotated_image_base64": null + }, + "error": null +} +``` + +#### 错误响应 + +```json +{ + "success": false, + "data": null, + "error": { + "code": "InvalidImageError", + "message": "无效的图片文件" + } +} +``` + +### API 请求参数 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `file` | File | - | 图片文件 (multipart 模式) | +| `image_base64` | string | - | Base64 编码图片 (JSON 模式) | +| `lang` | string | "ch" | 识别语言 | +| `use_gpu` | bool | false | 是否使用 GPU | +| `drop_score` | float | 0.5 | 置信度阈值 (0.0~1.0) | +| `roi_x` | float | - | ROI 左上角 X (0.0~1.0) | +| `roi_y` | float | - | ROI 左上角 Y (0.0~1.0) | +| `roi_width` | float | - | ROI 宽度 (0.0~1.0) | +| `roi_height` | float | - | ROI 高度 (0.0~1.0) | +| `return_annotated_image` | bool | false | 是否返回标注图片 | + +### 安全限制 + +- 最大文件大小: 10MB +- 支持的格式: jpg, jpeg, png, bmp, webp, tiff +- 文件验证: 扩展名 + MIME 类型 + 文件魔数 + +### 运行测试 + +```bash +# 运行所有 API 测试 +pytest tests/ -v + +# 运行特定测试 +pytest tests/test_api_health.py -v +pytest tests/test_api_ocr.py -v +``` + ## 常见问题 ### Q: Windows 中文用户名导致模型加载失败? diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..96f020c --- /dev/null +++ b/api/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +""" +Vision-OCR REST API 模块 +提供 HTTP 接口访问 OCR 功能 +""" + +from api.main import app + +__all__ = ["app"] diff --git a/api/dependencies.py b/api/dependencies.py new file mode 100644 index 0000000..88614b4 --- /dev/null +++ b/api/dependencies.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +""" +依赖注入模块 +提供 FastAPI 依赖项 +""" + +import base64 +from typing import Optional + +import cv2 +import numpy as np +from fastapi import Depends, File, Form, Request, UploadFile + +from api.exceptions import InvalidImageError, ModelNotLoadedError +from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams +from api.security import decode_base64_image, validate_uploaded_file +from ocr.pipeline import OCRPipeline, OCRResult +from utils.config import OCRConfig, PipelineConfig, ROIConfig + + +def get_ocr_pipeline(request: Request) -> OCRPipeline: + """ + 获取 OCR Pipeline 实例 + + Args: + request: FastAPI 请求对象 + + Returns: + OCRPipeline 实例 + + Raises: + ModelNotLoadedError: 模型未加载 + """ + pipeline = getattr(request.app.state, "ocr_pipeline", None) + if pipeline is None: + raise ModelNotLoadedError() + return pipeline + + +async def parse_multipart_image( + file: UploadFile = File(..., description="图片文件"), +) -> bytes: + """ + 解析 multipart 上传的图片文件 + + Args: + file: 上传的文件 + + Returns: + 验证后的图片字节数据 + """ + content = await file.read() + return validate_uploaded_file( + content=content, + filename=file.filename, + content_type=file.content_type, + ) + + +async def parse_multipart_params( + lang: str = Form(default="ch", description="识别语言"), + use_gpu: bool = Form(default=False, description="是否使用 GPU"), + drop_score: float = Form(default=0.5, ge=0.0, le=1.0, description="置信度阈值"), + roi_x: Optional[float] = Form(default=None, ge=0.0, le=1.0, description="ROI X"), + roi_y: Optional[float] = Form(default=None, ge=0.0, le=1.0, description="ROI Y"), + roi_width: Optional[float] = Form( + default=None, gt=0.0, le=1.0, description="ROI 宽度" + ), + roi_height: Optional[float] = Form( + default=None, gt=0.0, le=1.0, description="ROI 高度" + ), + return_annotated_image: bool = Form( + default=False, description="是否返回标注图片" + ), +) -> OCRRequestParams: + """ + 解析 multipart 表单参数 + + Returns: + OCR 请求参数对象 + """ + return OCRRequestParams( + lang=lang, + use_gpu=use_gpu, + drop_score=drop_score, + roi_x=roi_x, + roi_y=roi_y, + roi_width=roi_width, + roi_height=roi_height, + return_annotated_image=return_annotated_image, + ) + + +def decode_image_bytes(content: bytes) -> np.ndarray: + """ + 将图片字节解码为 numpy 数组 + + Args: + content: 图片字节数据 + + Returns: + OpenCV 图像 (BGR 格式) + + Raises: + InvalidImageError: 图片解码失败 + """ + try: + nparr = np.frombuffer(content, np.uint8) + image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + if image is None: + raise InvalidImageError("图片解码失败") + return image + except Exception as e: + if isinstance(e, InvalidImageError): + raise + raise InvalidImageError(f"图片解码失败: {str(e)}") + + +def encode_image_base64(image: np.ndarray, format: str = ".jpg") -> str: + """ + 将图像编码为 Base64 字符串 + + Args: + image: OpenCV 图像 + format: 图片格式 (如 '.jpg', '.png') + + Returns: + Base64 编码的图片字符串 + """ + success, encoded = cv2.imencode(format, image) + if not success: + return "" + return base64.b64encode(encoded.tobytes()).decode("utf-8") + + +def build_pipeline_config(roi: Optional[ROIParams]) -> PipelineConfig: + """ + 根据 ROI 参数构建管道配置 + + Args: + roi: ROI 参数 + + Returns: + 管道配置对象 + """ + if roi is None: + return PipelineConfig(roi=ROIConfig(enabled=False)) + + return PipelineConfig( + roi=ROIConfig( + enabled=True, + x_ratio=roi.x, + y_ratio=roi.y, + width_ratio=roi.width, + height_ratio=roi.height, + ) + ) diff --git a/api/exceptions.py b/api/exceptions.py new file mode 100644 index 0000000..6996c42 --- /dev/null +++ b/api/exceptions.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +""" +自定义异常模块 +定义 API 层的异常类型 +""" + + +class OCRAPIException(Exception): + """OCR API 基础异常""" + + def __init__(self, message: str, status_code: int = 500): + self.message = message + self.status_code = status_code + super().__init__(self.message) + + +class InvalidImageError(OCRAPIException): + """无效的图片文件""" + + def __init__(self, message: str = "无效的图片文件"): + super().__init__(message, status_code=400) + + +class FileTooLargeError(OCRAPIException): + """文件过大""" + + def __init__(self, message: str = "文件大小超过限制"): + super().__init__(message, status_code=413) + + +class UnsupportedFormatError(OCRAPIException): + """不支持的文件格式""" + + def __init__(self, message: str = "不支持的文件格式"): + super().__init__(message, status_code=415) + + +class OCRProcessingError(OCRAPIException): + """OCR 处理错误""" + + def __init__(self, message: str = "OCR 处理失败"): + super().__init__(message, status_code=500) + + +class ModelNotLoadedError(OCRAPIException): + """模型未加载""" + + def __init__(self, message: str = "OCR 模型尚未加载完成"): + super().__init__(message, status_code=503) diff --git a/api/main.py b/api/main.py new file mode 100644 index 0000000..6ba1f02 --- /dev/null +++ b/api/main.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +""" +Vision-OCR REST API 主入口 +基于 FastAPI 的 OCR 服务 +""" + +import os +import sys +from contextlib import asynccontextmanager +from pathlib import Path + +# 设置项目根目录和模型路径 +_PROJECT_ROOT = Path(__file__).parent.parent +_MODELS_DIR = _PROJECT_ROOT / "models" +_MODELS_DIR.mkdir(exist_ok=True) +os.environ["PADDLEOCR_HOME"] = str(_MODELS_DIR) + +# 将项目根目录添加到 Python 路径 +sys.path.insert(0, str(_PROJECT_ROOT)) + +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from api.exceptions import OCRAPIException +from api.routes import health_router, ocr_router +from ocr.pipeline import OCRPipeline +from utils.config import OCRConfig, PipelineConfig + + +def _get_ocr_config() -> OCRConfig: + """ + 获取 OCR 配置 + + 检查模型是否存在于项目目录,如果存在则使用本地模型 + """ + det_model_dir = str(_MODELS_DIR / "ch_PP-OCRv4_det_infer") + rec_model_dir = str(_MODELS_DIR / "ch_PP-OCRv4_rec_infer") + cls_model_dir = str(_MODELS_DIR / "ch_ppocr_mobile_v2.0_cls_infer") + + # 检查模型是否已下载 + models_exist = ( + Path(det_model_dir).exists() + and Path(rec_model_dir).exists() + and Path(cls_model_dir).exists() + ) + + return OCRConfig( + lang="ch", + use_angle_cls=True, + use_gpu=False, + drop_score=0.5, + det_model_dir=det_model_dir if models_exist else None, + rec_model_dir=rec_model_dir if models_exist else None, + cls_model_dir=cls_model_dir if models_exist else None, + ) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + 应用生命周期管理 + + 启动时预加载 OCR 模型,关闭时清理资源 + """ + # 启动: 加载 OCR 模型 + print("[INFO] 正在初始化 OCR API 服务...") + + ocr_config = _get_ocr_config() + if ocr_config.det_model_dir is None: + print("[WARN] 模型未在项目目录中找到") + print("[WARN] 对于 Windows 中文用户名用户,请先运行:") + print("[WARN] python download_models.py") + print("[INFO] 回退到默认 PaddleOCR 模型路径...") + + pipeline_config = PipelineConfig() + + print("[INFO] 正在加载 OCR 模型...") + pipeline = OCRPipeline(ocr_config, pipeline_config) + pipeline.initialize() + + app.state.ocr_pipeline = pipeline + app.state.model_loaded = True + print("[INFO] OCR 模型加载完成,服务已就绪") + + yield + + # 关闭: 清理资源 + print("[INFO] 正在关闭 OCR API 服务...") + app.state.model_loaded = False + print("[INFO] 服务已关闭") + + +# 创建 FastAPI 应用 +app = FastAPI( + title="Vision-OCR API", + description="基于 PaddleOCR 的图片 OCR 识别服务,支持通用文字识别和快递单解析", + version="1.0.0", + lifespan=lifespan, + docs_url="/docs", + redoc_url="/redoc", + openapi_url="/openapi.json", +) + +# 配置 CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 生产环境应限制具体域名 + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# 全局异常处理 +@app.exception_handler(OCRAPIException) +async def ocr_exception_handler(request: Request, exc: OCRAPIException): + """处理 OCR API 自定义异常""" + return JSONResponse( + status_code=exc.status_code, + content={ + "success": False, + "data": None, + "error": { + "code": type(exc).__name__, + "message": exc.message, + }, + }, + ) + + +@app.exception_handler(Exception) +async def general_exception_handler(request: Request, exc: Exception): + """处理未捕获的异常""" + return JSONResponse( + status_code=500, + content={ + "success": False, + "data": None, + "error": { + "code": "InternalServerError", + "message": f"服务器内部错误: {str(exc)}", + }, + }, + ) + + +# 注册路由 +app.include_router(health_router, prefix="/api/v1") +app.include_router(ocr_router, prefix="/api/v1") + + +@app.get("/", tags=["根路径"]) +async def root(): + """API 根路径,返回基本信息""" + return { + "name": "Vision-OCR API", + "version": "1.0.0", + "docs": "/docs", + "health": "/api/v1/health", + } + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run( + "api.main:app", + host="0.0.0.0", + port=8000, + reload=True, + ) diff --git a/api/routes/__init__.py b/api/routes/__init__.py new file mode 100644 index 0000000..138d881 --- /dev/null +++ b/api/routes/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +""" +路由模块 +定义 API 端点 +""" + +from api.routes.health import router as health_router +from api.routes.ocr import router as ocr_router + +__all__ = ["health_router", "ocr_router"] diff --git a/api/routes/health.py b/api/routes/health.py new file mode 100644 index 0000000..ebb9119 --- /dev/null +++ b/api/routes/health.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +""" +健康检查路由 +提供服务状态检查端点 +""" + +from fastapi import APIRouter, Request + +from api.schemas.response import HealthResponse + +router = APIRouter(prefix="/health", tags=["健康检查"]) + +# API 版本 +API_VERSION = "1.0.0" + + +@router.get( + "", + response_model=HealthResponse, + summary="健康检查", + description="检查服务是否正常运行", +) +async def health_check(request: Request) -> HealthResponse: + """ + 基础健康检查 + 返回服务状态和模型加载状态 + """ + model_loaded = getattr(request.app.state, "model_loaded", False) + + return HealthResponse( + status="healthy" if model_loaded else "unhealthy", + model_loaded=model_loaded, + version=API_VERSION, + ) + + +@router.get( + "/ready", + response_model=HealthResponse, + summary="就绪检查", + description="检查服务是否已准备好处理请求 (模型已加载)", +) +async def readiness_check(request: Request) -> HealthResponse: + """ + 就绪检查 + 只有当模型加载完成后才返回 healthy + """ + model_loaded = getattr(request.app.state, "model_loaded", False) + + return HealthResponse( + status="healthy" if model_loaded else "unhealthy", + model_loaded=model_loaded, + version=API_VERSION, + ) diff --git a/api/routes/ocr.py b/api/routes/ocr.py new file mode 100644 index 0000000..199d0da --- /dev/null +++ b/api/routes/ocr.py @@ -0,0 +1,360 @@ +# -*- coding: utf-8 -*- +""" +OCR 路由模块 +提供 OCR 识别和快递单解析端点 +""" + +from typing import Optional + +from fastapi import APIRouter, Depends, File, Request, UploadFile + +from api.dependencies import ( + build_pipeline_config, + decode_image_bytes, + encode_image_base64, + get_ocr_pipeline, + parse_multipart_image, + parse_multipart_params, +) +from api.exceptions import OCRProcessingError +from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams +from api.schemas.response import ( + ErrorDetail, + ExpressInfoData, + ExpressPersonData, + ExpressResponse, + ExpressResultData, + OCRResponse, + OCRResultData, + TextBlockData, +) +from api.security import decode_base64_image +from ocr.pipeline import OCRPipeline, OCRResult +from utils.config import OCRConfig, PipelineConfig, ROIConfig +from visualize.draw import OCRVisualizer +from utils.config import VisualizeConfig + +router = APIRouter(prefix="/ocr", tags=["OCR 识别"]) + + +def _convert_ocr_result_to_response( + result: OCRResult, + annotated_image_base64: Optional[str] = None, +) -> OCRResultData: + """ + 将 OCRResult 转换为响应数据模型 + + Args: + result: OCR 处理结果 + annotated_image_base64: 标注图片的 Base64 编码 + + Returns: + 响应数据模型 + """ + text_blocks = [ + TextBlockData( + text=block.text, + confidence=block.confidence, + bbox=block.bbox, + bbox_with_offset=block.bbox_with_offset, + center=list(block.center), + width=block.width, + height=block.height, + ) + for block in result.text_blocks + ] + + return OCRResultData( + processing_time_ms=result.processing_time_ms, + text_count=result.text_count, + average_confidence=result.average_confidence, + roi_applied=result.roi_applied, + roi_rect=list(result.roi_rect) if result.roi_rect else None, + text_blocks=text_blocks, + annotated_image_base64=annotated_image_base64, + ) + + +def _process_ocr( + image_bytes: bytes, + pipeline: OCRPipeline, + roi: Optional[ROIParams] = None, + return_annotated_image: bool = False, +) -> tuple[OCRResult, Optional[str]]: + """ + 执行 OCR 处理 + + Args: + image_bytes: 图片字节数据 + pipeline: OCR 管道 + roi: ROI 参数 + return_annotated_image: 是否返回标注图片 + + Returns: + (OCR 结果, 标注图片 Base64) + """ + # 解码图片 + image = decode_image_bytes(image_bytes) + + # 构建管道配置 + pipeline_config = build_pipeline_config(roi) + + # 临时更新管道配置 + original_config = pipeline._pipeline_config + pipeline._pipeline_config = pipeline_config + + try: + # 执行 OCR + result = pipeline.process(image) + except Exception as e: + raise OCRProcessingError(f"OCR 处理失败: {str(e)}") + finally: + # 恢复原始配置 + pipeline._pipeline_config = original_config + + # 生成标注图片 + annotated_image_base64 = None + if return_annotated_image and result.text_count > 0: + visualizer = OCRVisualizer(VisualizeConfig()) + annotated = visualizer.draw_result(image, result) + annotated_image_base64 = encode_image_base64(annotated) + + return result, annotated_image_base64 + + +# ============================================================ +# multipart/form-data 端点 +# ============================================================ + + +@router.post( + "/recognize", + response_model=OCRResponse, + summary="OCR 识别 (文件上传)", + description="上传图片文件进行 OCR 识别,支持 jpg/png/bmp/webp 格式", +) +async def recognize_multipart( + request: Request, + file: UploadFile = File(..., description="图片文件"), + params: OCRRequestParams = Depends(parse_multipart_params), + pipeline: OCRPipeline = Depends(get_ocr_pipeline), +) -> OCRResponse: + """ + OCR 识别端点 (multipart/form-data) + + 上传图片文件进行文字识别 + """ + try: + # 验证并读取文件 + image_bytes = await parse_multipart_image(file) + + # 执行 OCR + result, annotated_base64 = _process_ocr( + image_bytes=image_bytes, + pipeline=pipeline, + roi=params.get_roi(), + return_annotated_image=params.return_annotated_image, + ) + + # 构建响应 + return OCRResponse( + success=True, + data=_convert_ocr_result_to_response(result, annotated_base64), + ) + + except Exception as e: + return OCRResponse( + success=False, + error=ErrorDetail( + code=type(e).__name__, + message=str(e), + ), + ) + + +@router.post( + "/express", + response_model=ExpressResponse, + summary="快递单解析 (文件上传)", + description="上传快递单图片进行 OCR 识别并解析结构化信息", +) +async def express_multipart( + request: Request, + file: UploadFile = File(..., description="快递单图片"), + params: OCRRequestParams = Depends(parse_multipart_params), + pipeline: OCRPipeline = Depends(get_ocr_pipeline), +) -> ExpressResponse: + """ + 快递单解析端点 (multipart/form-data) + + 上传快递单图片,自动识别并提取结构化信息 + """ + try: + # 验证并读取文件 + image_bytes = await parse_multipart_image(file) + + # 执行 OCR + result, annotated_base64 = _process_ocr( + image_bytes=image_bytes, + pipeline=pipeline, + roi=params.get_roi(), + return_annotated_image=params.return_annotated_image, + ) + + # 解析快递单信息 + express_info = result.parse_express() + merged_text = result.merge_text() + + # 构建响应 + return ExpressResponse( + success=True, + data=ExpressResultData( + processing_time_ms=result.processing_time_ms, + express_info=ExpressInfoData( + tracking_number=express_info.tracking_number, + sender=ExpressPersonData( + name=express_info.sender_name, + phone=express_info.sender_phone, + address=express_info.sender_address, + ), + receiver=ExpressPersonData( + name=express_info.receiver_name, + phone=express_info.receiver_phone, + address=express_info.receiver_address, + ), + courier_company=express_info.courier_company, + confidence=express_info.confidence, + extra_fields=express_info.extra_fields, + raw_text=express_info.raw_text, + ), + merged_text=merged_text, + annotated_image_base64=annotated_base64, + ), + ) + + except Exception as e: + return ExpressResponse( + success=False, + error=ErrorDetail( + code=type(e).__name__, + message=str(e), + ), + ) + + +# ============================================================ +# JSON (Base64) 端点 +# ============================================================ + + +@router.post( + "/recognize/base64", + response_model=OCRResponse, + summary="OCR 识别 (Base64)", + description="提交 Base64 编码的图片进行 OCR 识别", +) +async def recognize_base64( + request: Request, + body: OCRRequestBase64, + pipeline: OCRPipeline = Depends(get_ocr_pipeline), +) -> OCRResponse: + """ + OCR 识别端点 (JSON Base64) + + 提交 Base64 编码的图片进行文字识别 + """ + try: + # 解码并验证 Base64 图片 + image_bytes = decode_base64_image(body.image_base64) + + # 执行 OCR + result, annotated_base64 = _process_ocr( + image_bytes=image_bytes, + pipeline=pipeline, + roi=body.roi, + return_annotated_image=body.return_annotated_image, + ) + + # 构建响应 + return OCRResponse( + success=True, + data=_convert_ocr_result_to_response(result, annotated_base64), + ) + + except Exception as e: + return OCRResponse( + success=False, + error=ErrorDetail( + code=type(e).__name__, + message=str(e), + ), + ) + + +@router.post( + "/express/base64", + response_model=ExpressResponse, + summary="快递单解析 (Base64)", + description="提交 Base64 编码的快递单图片进行解析", +) +async def express_base64( + request: Request, + body: OCRRequestBase64, + pipeline: OCRPipeline = Depends(get_ocr_pipeline), +) -> ExpressResponse: + """ + 快递单解析端点 (JSON Base64) + + 提交 Base64 编码的快递单图片,自动识别并提取结构化信息 + """ + try: + # 解码并验证 Base64 图片 + image_bytes = decode_base64_image(body.image_base64) + + # 执行 OCR + result, annotated_base64 = _process_ocr( + image_bytes=image_bytes, + pipeline=pipeline, + roi=body.roi, + return_annotated_image=body.return_annotated_image, + ) + + # 解析快递单信息 + express_info = result.parse_express() + merged_text = result.merge_text() + + # 构建响应 + return ExpressResponse( + success=True, + data=ExpressResultData( + processing_time_ms=result.processing_time_ms, + express_info=ExpressInfoData( + tracking_number=express_info.tracking_number, + sender=ExpressPersonData( + name=express_info.sender_name, + phone=express_info.sender_phone, + address=express_info.sender_address, + ), + receiver=ExpressPersonData( + name=express_info.receiver_name, + phone=express_info.receiver_phone, + address=express_info.receiver_address, + ), + courier_company=express_info.courier_company, + confidence=express_info.confidence, + extra_fields=express_info.extra_fields, + raw_text=express_info.raw_text, + ), + merged_text=merged_text, + annotated_image_base64=annotated_base64, + ), + ) + + except Exception as e: + return ExpressResponse( + success=False, + error=ErrorDetail( + code=type(e).__name__, + message=str(e), + ), + ) diff --git a/api/schemas/__init__.py b/api/schemas/__init__.py new file mode 100644 index 0000000..8e4bfd5 --- /dev/null +++ b/api/schemas/__init__.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +""" +Pydantic 数据模型模块 +定义 API 请求和响应的数据结构 +""" + +from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams +from api.schemas.response import ( + ErrorDetail, + ExpressInfoData, + ExpressResponse, + HealthResponse, + OCRResponse, + OCRResultData, + TextBlockData, +) + +__all__ = [ + # Request + "OCRRequestBase64", + "OCRRequestParams", + "ROIParams", + # Response + "OCRResponse", + "OCRResultData", + "TextBlockData", + "ExpressResponse", + "ExpressInfoData", + "HealthResponse", + "ErrorDetail", +] diff --git a/api/schemas/request.py b/api/schemas/request.py new file mode 100644 index 0000000..2508b4c --- /dev/null +++ b/api/schemas/request.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +""" +请求数据模型 +定义 API 请求的数据结构 +""" + +from typing import Optional + +from pydantic import BaseModel, Field + + +class ROIParams(BaseModel): + """ + ROI (感兴趣区域) 参数 + 使用归一化坐标 (0.0 ~ 1.0) + """ + + x: float = Field( + default=0.0, + ge=0.0, + le=1.0, + description="ROI 左上角 X 坐标 (归一化)", + ) + y: float = Field( + default=0.0, + ge=0.0, + le=1.0, + description="ROI 左上角 Y 坐标 (归一化)", + ) + width: float = Field( + default=1.0, + gt=0.0, + le=1.0, + description="ROI 宽度 (归一化)", + ) + height: float = Field( + default=1.0, + gt=0.0, + le=1.0, + description="ROI 高度 (归一化)", + ) + + +class OCRRequestParams(BaseModel): + """ + OCR 请求参数 (用于 multipart/form-data) + """ + + lang: str = Field( + default="ch", + description="识别语言,支持 'ch'(中文), 'en'(英文) 等", + ) + use_gpu: bool = Field( + default=False, + description="是否使用 GPU 加速", + ) + drop_score: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="置信度阈值,低于此值的结果将被过滤", + ) + roi_x: Optional[float] = Field( + default=None, + ge=0.0, + le=1.0, + description="ROI 左上角 X 坐标 (归一化)", + ) + roi_y: Optional[float] = Field( + default=None, + ge=0.0, + le=1.0, + description="ROI 左上角 Y 坐标 (归一化)", + ) + roi_width: Optional[float] = Field( + default=None, + gt=0.0, + le=1.0, + description="ROI 宽度 (归一化)", + ) + roi_height: Optional[float] = Field( + default=None, + gt=0.0, + le=1.0, + description="ROI 高度 (归一化)", + ) + return_annotated_image: bool = Field( + default=False, + description="是否返回标注后的图片 (Base64)", + ) + + def has_roi(self) -> bool: + """检查是否设置了 ROI 参数""" + return all( + v is not None + for v in [self.roi_x, self.roi_y, self.roi_width, self.roi_height] + ) + + def get_roi(self) -> Optional[ROIParams]: + """获取 ROI 参数对象""" + if not self.has_roi(): + return None + return ROIParams( + x=self.roi_x, + y=self.roi_y, + width=self.roi_width, + height=self.roi_height, + ) + + +class OCRRequestBase64(BaseModel): + """ + OCR 请求 (Base64 JSON 格式) + """ + + image_base64: str = Field( + ..., + description="Base64 编码的图片数据,支持 Data URL 格式", + ) + lang: str = Field( + default="ch", + description="识别语言,支持 'ch'(中文), 'en'(英文) 等", + ) + use_gpu: bool = Field( + default=False, + description="是否使用 GPU 加速", + ) + drop_score: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="置信度阈值,低于此值的结果将被过滤", + ) + roi: Optional[ROIParams] = Field( + default=None, + description="ROI 区域参数", + ) + return_annotated_image: bool = Field( + default=False, + description="是否返回标注后的图片 (Base64)", + ) diff --git a/api/schemas/response.py b/api/schemas/response.py new file mode 100644 index 0000000..a505476 --- /dev/null +++ b/api/schemas/response.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +""" +响应数据模型 +定义 API 响应的数据结构 +""" + +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class ErrorDetail(BaseModel): + """错误详情""" + + code: str = Field(..., description="错误代码") + message: str = Field(..., description="错误信息") + details: Optional[Dict[str, Any]] = Field( + default=None, + description="额外的错误详情", + ) + + +class TextBlockData(BaseModel): + """文本块数据""" + + text: str = Field(..., description="识别出的文本内容") + confidence: float = Field(..., description="置信度 (0.0 ~ 1.0)") + bbox: List[List[float]] = Field( + ..., + description="边界框 4 个顶点坐标 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]", + ) + bbox_with_offset: List[List[float]] = Field( + ..., + description="带偏移的边界框坐标 (已还原到原图坐标系)", + ) + center: List[float] = Field(..., description="文本块中心点坐标 [x, y]") + width: float = Field(..., description="文本块宽度 (像素)") + height: float = Field(..., description="文本块高度 (像素)") + + +class OCRResultData(BaseModel): + """OCR 识别结果数据""" + + processing_time_ms: float = Field(..., description="处理耗时 (毫秒)") + text_count: int = Field(..., description="识别出的文本块数量") + average_confidence: float = Field( + ..., + description="所有文本块的平均置信度", + ) + roi_applied: bool = Field(..., description="是否应用了 ROI 裁剪") + roi_rect: Optional[List[int]] = Field( + default=None, + description="ROI 矩形区域 [x, y, width, height]", + ) + text_blocks: List[TextBlockData] = Field( + default_factory=list, + description="识别出的文本块列表", + ) + annotated_image_base64: Optional[str] = Field( + default=None, + description="标注后的图片 (Base64 编码)", + ) + + +class OCRResponse(BaseModel): + """OCR 识别响应""" + + success: bool = Field(..., description="请求是否成功") + data: Optional[OCRResultData] = Field( + default=None, + description="识别结果数据", + ) + error: Optional[ErrorDetail] = Field( + default=None, + description="错误信息 (仅在失败时存在)", + ) + + +class ExpressPersonData(BaseModel): + """快递单人员信息""" + + name: Optional[str] = Field(default=None, description="姓名") + phone: Optional[str] = Field(default=None, description="电话") + address: Optional[str] = Field(default=None, description="地址") + + +class ExpressInfoData(BaseModel): + """快递单结构化信息""" + + tracking_number: Optional[str] = Field(default=None, description="运单号") + sender: ExpressPersonData = Field( + default_factory=ExpressPersonData, + description="寄件人信息", + ) + receiver: ExpressPersonData = Field( + default_factory=ExpressPersonData, + description="收件人信息", + ) + courier_company: Optional[str] = Field( + default=None, + description="快递公司名称", + ) + confidence: float = Field(default=0.0, description="平均置信度") + extra_fields: Dict[str, str] = Field( + default_factory=dict, + description="其他识别到的额外字段", + ) + raw_text: str = Field(default="", description="原始合并文本") + + +class ExpressResultData(BaseModel): + """快递单解析结果数据""" + + processing_time_ms: float = Field(..., description="处理耗时 (毫秒)") + express_info: ExpressInfoData = Field(..., description="快递单结构化信息") + merged_text: str = Field(..., description="智能合并后的完整文本") + annotated_image_base64: Optional[str] = Field( + default=None, + description="标注后的图片 (Base64 编码)", + ) + + +class ExpressResponse(BaseModel): + """快递单解析响应""" + + success: bool = Field(..., description="请求是否成功") + data: Optional[ExpressResultData] = Field( + default=None, + description="解析结果数据", + ) + error: Optional[ErrorDetail] = Field( + default=None, + description="错误信息 (仅在失败时存在)", + ) + + +class HealthResponse(BaseModel): + """健康检查响应""" + + status: str = Field(..., description="服务状态: 'healthy' 或 'unhealthy'") + model_loaded: bool = Field(..., description="OCR 模型是否已加载") + version: str = Field(..., description="API 版本") diff --git a/api/security.py b/api/security.py new file mode 100644 index 0000000..d1e7127 --- /dev/null +++ b/api/security.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +""" +安全模块 +提供文件验证、大小限制等安全功能 +""" + +import base64 +from pathlib import Path +from typing import Optional + +from api.exceptions import ( + FileTooLargeError, + InvalidImageError, + UnsupportedFormatError, +) + + +# 安全配置常量 +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB +ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tiff", ".tif"} +ALLOWED_MIME_TYPES = { + "image/jpeg", + "image/png", + "image/bmp", + "image/webp", + "image/tiff", +} + +# 图片文件魔数 (Magic Bytes) +IMAGE_MAGIC_BYTES = { + b"\xff\xd8\xff": "jpeg", # JPEG + b"\x89PNG\r\n\x1a\n": "png", # PNG + b"BM": "bmp", # BMP + b"RIFF": "webp", # WebP (需要进一步检查) + b"II*\x00": "tiff", # TIFF (Little Endian) + b"MM\x00*": "tiff", # TIFF (Big Endian) +} + + +def validate_file_extension(filename: Optional[str]) -> bool: + """ + 验证文件扩展名 + + Args: + filename: 文件名 + + Returns: + 是否为允许的扩展名 + """ + if not filename: + return False + ext = Path(filename).suffix.lower() + return ext in ALLOWED_EXTENSIONS + + +def validate_file_size(content: bytes) -> bool: + """ + 验证文件大小 + + Args: + content: 文件内容 + + Returns: + 是否在允许的大小范围内 + """ + return len(content) <= MAX_FILE_SIZE + + +def validate_image_magic_bytes(content: bytes) -> bool: + """ + 验证图片文件魔数 + + Args: + content: 文件内容 + + Returns: + 是否为有效的图片文件 + """ + if len(content) < 8: + return False + + for magic, _ in IMAGE_MAGIC_BYTES.items(): + if content.startswith(magic): + # WebP 需要额外检查 + if magic == b"RIFF" and len(content) >= 12: + if content[8:12] != b"WEBP": + continue + return True + + return False + + +def validate_mime_type(content_type: Optional[str]) -> bool: + """ + 验证 MIME 类型 + + Args: + content_type: MIME 类型 + + Returns: + 是否为允许的 MIME 类型 + """ + if not content_type: + return False + # 处理带参数的 MIME 类型,如 "image/jpeg; charset=utf-8" + mime = content_type.split(";")[0].strip().lower() + return mime in ALLOWED_MIME_TYPES + + +def validate_uploaded_file( + content: bytes, + filename: Optional[str] = None, + content_type: Optional[str] = None, +) -> bytes: + """ + 综合验证上传的文件 + + Args: + content: 文件内容 + filename: 文件名 + content_type: MIME 类型 + + Returns: + 验证通过的文件内容 + + Raises: + FileTooLargeError: 文件过大 + UnsupportedFormatError: 不支持的格式 + InvalidImageError: 无效的图片 + """ + # 验证文件大小 + if not validate_file_size(content): + raise FileTooLargeError( + f"文件大小超过限制,最大允许 {MAX_FILE_SIZE // 1024 // 1024}MB" + ) + + # 验证扩展名 (如果提供) + if filename and not validate_file_extension(filename): + raise UnsupportedFormatError( + f"不支持的文件格式,允许的格式: {', '.join(ALLOWED_EXTENSIONS)}" + ) + + # 验证 MIME 类型 (如果提供) + if content_type and not validate_mime_type(content_type): + raise UnsupportedFormatError(f"不支持的 MIME 类型: {content_type}") + + # 验证文件魔数 + if not validate_image_magic_bytes(content): + raise InvalidImageError("文件内容不是有效的图片格式") + + return content + + +def decode_base64_image(base64_string: str) -> bytes: + """ + 解码 Base64 图片 + + Args: + base64_string: Base64 编码的图片字符串 + + Returns: + 解码后的图片字节 + + Raises: + InvalidImageError: Base64 解码失败或不是有效图片 + """ + # 移除可能的 Data URL 前缀 + if "," in base64_string: + base64_string = base64_string.split(",", 1)[1] + + # 移除空白字符 + base64_string = base64_string.strip() + + try: + content = base64.b64decode(base64_string) + except Exception as e: + raise InvalidImageError(f"Base64 解码失败: {str(e)}") + + # 验证解码后的内容 + return validate_uploaded_file(content) diff --git a/requirements.txt b/requirements.txt index cf485d4..0cf7318 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,14 @@ numpy>=1.24.0,<2.0 # Optional dependencies (for Chinese font rendering) Pillow>=10.0.0 + +# REST API dependencies +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 +python-multipart>=0.0.6 +pydantic>=2.5.0 + +# Testing dependencies +pytest>=7.4.0 +pytest-asyncio>=0.23.0 +httpx>=0.26.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e0319ea --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +""" +Vision-OCR API 测试模块 +""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..062f5e0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +""" +pytest 配置和共享 fixtures +""" + +import base64 +import os +import sys +from pathlib import Path +from typing import Generator +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +from fastapi.testclient import TestClient + +# 将项目根目录添加到 Python 路径 +PROJECT_ROOT = Path(__file__).parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +# 设置模型路径环境变量 +os.environ["PADDLEOCR_HOME"] = str(PROJECT_ROOT / "models") + + +@pytest.fixture(scope="session") +def mock_ocr_pipeline(): + """ + 创建模拟的 OCR Pipeline + + 避免在测试中加载真实的 OCR 模型 + """ + from ocr.engine import TextBlock + from ocr.pipeline import OCRResult + + mock_pipeline = MagicMock() + + # 模拟 OCR 结果 + mock_text_blocks = [ + TextBlock( + text="测试文本1", + confidence=0.95, + bbox=[[10, 10], [100, 10], [100, 30], [10, 30]], + bbox_offset=(0, 0), + ), + TextBlock( + text="测试文本2", + confidence=0.88, + bbox=[[10, 40], [150, 40], [150, 60], [10, 60]], + bbox_offset=(0, 0), + ), + ] + + mock_result = OCRResult( + image_index=1, + image_path=None, + timestamp=1704672000.0, + processing_time_ms=45.6, + text_blocks=mock_text_blocks, + roi_applied=False, + roi_rect=None, + ) + + mock_pipeline.process.return_value = mock_result + mock_pipeline.initialize.return_value = None + mock_pipeline._pipeline_config = MagicMock() + + return mock_pipeline + + +@pytest.fixture(scope="session") +def test_client(mock_ocr_pipeline) -> Generator[TestClient, None, None]: + """ + 创建测试客户端 + + 使用模拟的 OCR Pipeline 避免加载真实模型 + """ + # 延迟导入以确保环境变量已设置 + from api.main import app + + # 设置模拟的 pipeline + app.state.ocr_pipeline = mock_ocr_pipeline + app.state.model_loaded = True + + with TestClient(app) as client: + yield client + + +@pytest.fixture +def sample_image_bytes() -> bytes: + """ + 创建测试用的图片字节数据 + + 生成一个简单的 100x100 白色 JPEG 图片 + """ + import cv2 + + # 创建白色图片 + image = np.ones((100, 100, 3), dtype=np.uint8) * 255 + + # 编码为 JPEG + success, encoded = cv2.imencode(".jpg", image) + assert success, "图片编码失败" + + return encoded.tobytes() + + +@pytest.fixture +def sample_image_base64(sample_image_bytes) -> str: + """ + 创建测试用的 Base64 编码图片 + """ + return base64.b64encode(sample_image_bytes).decode("utf-8") + + +@pytest.fixture +def sample_png_bytes() -> bytes: + """ + 创建测试用的 PNG 图片字节数据 + """ + import cv2 + + image = np.ones((100, 100, 3), dtype=np.uint8) * 255 + success, encoded = cv2.imencode(".png", image) + assert success, "PNG 编码失败" + + return encoded.tobytes() + + +@pytest.fixture +def invalid_file_bytes() -> bytes: + """ + 创建无效的文件字节数据 (非图片) + """ + return b"This is not an image file content" diff --git a/tests/test_api_health.py b/tests/test_api_health.py new file mode 100644 index 0000000..1d4d5ce --- /dev/null +++ b/tests/test_api_health.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" +健康检查 API 测试 +""" + +import pytest +from fastapi.testclient import TestClient + + +class TestHealthEndpoints: + """健康检查端点测试""" + + def test_health_check_success(self, test_client: TestClient): + """测试健康检查端点 - 正常情况""" + response = test_client.get("/api/v1/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert data["model_loaded"] is True + assert "version" in data + + def test_readiness_check_success(self, test_client: TestClient): + """测试就绪检查端点 - 正常情况""" + response = test_client.get("/api/v1/health/ready") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert data["model_loaded"] is True + + def test_root_endpoint(self, test_client: TestClient): + """测试根路径端点""" + response = test_client.get("/") + + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Vision-OCR API" + assert "version" in data + assert "docs" in data + + +class TestHealthEndpointsUnhealthy: + """健康检查端点测试 - 模型未加载情况""" + + def test_health_check_model_not_loaded(self, test_client: TestClient): + """测试健康检查 - 模型未加载""" + # 临时设置模型未加载状态 + original_state = test_client.app.state.model_loaded + test_client.app.state.model_loaded = False + + try: + response = test_client.get("/api/v1/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "unhealthy" + assert data["model_loaded"] is False + finally: + # 恢复原始状态 + test_client.app.state.model_loaded = original_state + + def test_readiness_check_model_not_loaded(self, test_client: TestClient): + """测试就绪检查 - 模型未加载""" + original_state = test_client.app.state.model_loaded + test_client.app.state.model_loaded = False + + try: + response = test_client.get("/api/v1/health/ready") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "unhealthy" + assert data["model_loaded"] is False + finally: + test_client.app.state.model_loaded = original_state diff --git a/tests/test_api_ocr.py b/tests/test_api_ocr.py new file mode 100644 index 0000000..2a709cf --- /dev/null +++ b/tests/test_api_ocr.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +""" +OCR API 测试 +""" + +import io + +import pytest +from fastapi.testclient import TestClient + + +class TestOCRRecognizeMultipart: + """OCR 识别端点测试 (multipart/form-data)""" + + def test_recognize_success( + self, + test_client: TestClient, + sample_image_bytes: bytes, + ): + """测试 OCR 识别 - 正常情况""" + response = test_client.post( + "/api/v1/ocr/recognize", + files={"file": ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")}, + data={"lang": "ch", "drop_score": "0.5"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"] is not None + assert "text_count" in data["data"] + assert "text_blocks" in data["data"] + assert "processing_time_ms" in data["data"] + + def test_recognize_with_roi( + self, + test_client: TestClient, + sample_image_bytes: bytes, + ): + """测试 OCR 识别 - 带 ROI 参数""" + response = test_client.post( + "/api/v1/ocr/recognize", + files={"file": ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")}, + data={ + "lang": "ch", + "roi_x": "0.1", + "roi_y": "0.1", + "roi_width": "0.8", + "roi_height": "0.8", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + + def test_recognize_with_annotated_image( + self, + test_client: TestClient, + sample_image_bytes: bytes, + ): + """测试 OCR 识别 - 返回标注图片""" + response = test_client.post( + "/api/v1/ocr/recognize", + files={"file": ("test.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")}, + data={"return_annotated_image": "true"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + # 注意: 标注图片只有在有识别结果时才返回 + if data["data"]["text_count"] > 0: + assert data["data"]["annotated_image_base64"] is not None + + def test_recognize_png_image( + self, + test_client: TestClient, + sample_png_bytes: bytes, + ): + """测试 OCR 识别 - PNG 格式""" + response = test_client.post( + "/api/v1/ocr/recognize", + files={"file": ("test.png", io.BytesIO(sample_png_bytes), "image/png")}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + + def test_recognize_invalid_file( + self, + test_client: TestClient, + invalid_file_bytes: bytes, + ): + """测试 OCR 识别 - 无效文件""" + response = test_client.post( + "/api/v1/ocr/recognize", + files={ + "file": ( + "test.txt", + io.BytesIO(invalid_file_bytes), + "text/plain", + ) + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert data["error"] is not None + + def test_recognize_no_file(self, test_client: TestClient): + """测试 OCR 识别 - 未提供文件""" + response = test_client.post("/api/v1/ocr/recognize") + + assert response.status_code == 422 # Validation Error + + +class TestOCRRecognizeBase64: + """OCR 识别端点测试 (Base64 JSON)""" + + def test_recognize_base64_success( + self, + test_client: TestClient, + sample_image_base64: str, + ): + """测试 OCR 识别 (Base64) - 正常情况""" + response = test_client.post( + "/api/v1/ocr/recognize/base64", + json={ + "image_base64": sample_image_base64, + "lang": "ch", + "drop_score": 0.5, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"] is not None + + def test_recognize_base64_with_data_url( + self, + test_client: TestClient, + sample_image_base64: str, + ): + """测试 OCR 识别 (Base64) - Data URL 格式""" + data_url = f"data:image/jpeg;base64,{sample_image_base64}" + response = test_client.post( + "/api/v1/ocr/recognize/base64", + json={"image_base64": data_url}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + + def test_recognize_base64_with_roi( + self, + test_client: TestClient, + sample_image_base64: str, + ): + """测试 OCR 识别 (Base64) - 带 ROI 参数""" + response = test_client.post( + "/api/v1/ocr/recognize/base64", + json={ + "image_base64": sample_image_base64, + "roi": {"x": 0.1, "y": 0.1, "width": 0.8, "height": 0.8}, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + + def test_recognize_base64_invalid(self, test_client: TestClient): + """测试 OCR 识别 (Base64) - 无效 Base64""" + response = test_client.post( + "/api/v1/ocr/recognize/base64", + json={"image_base64": "not-valid-base64!!!"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert data["error"] is not None + + def test_recognize_base64_missing_field(self, test_client: TestClient): + """测试 OCR 识别 (Base64) - 缺少必填字段""" + response = test_client.post( + "/api/v1/ocr/recognize/base64", + json={"lang": "ch"}, # 缺少 image_base64 + ) + + assert response.status_code == 422 # Validation Error + + +class TestExpressMultipart: + """快递单解析端点测试 (multipart/form-data)""" + + def test_express_success( + self, + test_client: TestClient, + sample_image_bytes: bytes, + ): + """测试快递单解析 - 正常情况""" + response = test_client.post( + "/api/v1/ocr/express", + files={"file": ("express.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"] is not None + assert "express_info" in data["data"] + assert "merged_text" in data["data"] + assert "processing_time_ms" in data["data"] + + def test_express_info_structure( + self, + test_client: TestClient, + sample_image_bytes: bytes, + ): + """测试快递单解析 - 响应结构""" + response = test_client.post( + "/api/v1/ocr/express", + files={"file": ("express.jpg", io.BytesIO(sample_image_bytes), "image/jpeg")}, + ) + + assert response.status_code == 200 + data = response.json() + express_info = data["data"]["express_info"] + + # 验证结构完整性 + assert "tracking_number" in express_info + assert "sender" in express_info + assert "receiver" in express_info + assert "courier_company" in express_info + assert "confidence" in express_info + + # 验证 sender/receiver 结构 + assert "name" in express_info["sender"] + assert "phone" in express_info["sender"] + assert "address" in express_info["sender"] + + +class TestExpressBase64: + """快递单解析端点测试 (Base64 JSON)""" + + def test_express_base64_success( + self, + test_client: TestClient, + sample_image_base64: str, + ): + """测试快递单解析 (Base64) - 正常情况""" + response = test_client.post( + "/api/v1/ocr/express/base64", + json={"image_base64": sample_image_base64}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"] is not None + assert "express_info" in data["data"] + + +class TestSecurityValidation: + """安全验证测试""" + + def test_file_size_limit(self, test_client: TestClient): + """测试文件大小限制""" + # 创建一个超大的假文件 (11MB) + large_content = b"x" * (11 * 1024 * 1024) + + response = test_client.post( + "/api/v1/ocr/recognize", + files={"file": ("large.jpg", io.BytesIO(large_content), "image/jpeg")}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert "大小" in data["error"]["message"] or "size" in data["error"]["message"].lower() + + def test_invalid_extension( + self, + test_client: TestClient, + sample_image_bytes: bytes, + ): + """测试无效文件扩展名""" + response = test_client.post( + "/api/v1/ocr/recognize", + files={"file": ("test.exe", io.BytesIO(sample_image_bytes), "application/octet-stream")}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + + def test_magic_bytes_validation(self, test_client: TestClient): + """测试文件魔数验证""" + # 创建一个假的 jpg 文件 (扩展名正确但内容不是图片) + fake_jpg = b"This is not a real JPEG file" + + response = test_client.post( + "/api/v1/ocr/recognize", + files={"file": ("fake.jpg", io.BytesIO(fake_jpg), "image/jpeg")}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False