fix: 修复线程安全和 API 参数失效问题

- 修复 OCRPipeline 并发请求时的线程安全问题
- 修复 drop_score 参数未生效的问题
- 集中管理版本号,消除代码重复
- 添加图片尺寸验证和日志系统
master
蒋尚宏 4 weeks ago
parent 9ee92136d8
commit b3d52a0684

@ -4,6 +4,11 @@ Vision-OCR REST API 模块
提供 HTTP 接口访问 OCR 功能
"""
from api.version import API_VERSION
# 导出版本号
__version__ = API_VERSION
from api.main import app
__all__ = ["app"]
__all__ = ["app", "__version__"]

@ -102,13 +102,31 @@ def decode_image_bytes(content: bytes) -> np.ndarray:
OpenCV 图像 (BGR 格式)
Raises:
InvalidImageError: 图片解码失败
InvalidImageError: 图片解码失败或尺寸不符合要求
"""
# 图片尺寸限制
MIN_IMAGE_SIZE = 10 # 最小 10x10 像素
MAX_IMAGE_SIZE = 10000 # 最大 10000x10000 像素
try:
nparr = np.frombuffer(content, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if image is None:
raise InvalidImageError("图片解码失败")
# 验证图片尺寸
height, width = image.shape[:2]
if width < MIN_IMAGE_SIZE or height < MIN_IMAGE_SIZE:
raise InvalidImageError(
f"图片尺寸过小,最小要求 {MIN_IMAGE_SIZE}x{MIN_IMAGE_SIZE} 像素,"
f"当前尺寸 {width}x{height}"
)
if width > MAX_IMAGE_SIZE or height > MAX_IMAGE_SIZE:
raise InvalidImageError(
f"图片尺寸过大,最大允许 {MAX_IMAGE_SIZE}x{MAX_IMAGE_SIZE} 像素,"
f"当前尺寸 {width}x{height}"
)
return image
except Exception as e:
if isinstance(e, InvalidImageError):

@ -4,6 +4,7 @@ Vision-OCR REST API 主入口
基于 FastAPI OCR 服务
"""
import logging
import os
import sys
from contextlib import asynccontextmanager
@ -24,9 +25,18 @@ from fastapi.responses import JSONResponse
from api.exceptions import OCRAPIException
from api.routes import health_router, ocr_router
from api.version import API_VERSION
from ocr.pipeline import OCRPipeline
from utils.config import OCRConfig, PipelineConfig
# 配置日志
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("vision-ocr")
def _get_ocr_config() -> OCRConfig:
"""
@ -64,38 +74,37 @@ async def lifespan(app: FastAPI):
启动时预加载 OCR 模型关闭时清理资源
"""
# 启动: 加载 OCR 模型
print("[INFO] 正在初始化 OCR API 服务...")
logger.info("正在初始化 OCR API 服务...")
ocr_config = _get_ocr_config()
if ocr_config.det_model_dir is None:
print("[WARN] 模型未在项目目录中找到")
print("[WARN] 对于 Windows 中文用户名用户,请先运行:")
print("[WARN] python download_models.py")
print("[INFO] 回退到默认 PaddleOCR 模型路径...")
logger.warning("模型未在项目目录中找到")
logger.warning("对于 Windows 中文用户名用户,请先运行: python download_models.py")
logger.info("回退到默认 PaddleOCR 模型路径...")
pipeline_config = PipelineConfig()
print("[INFO] 正在加载 OCR 模型...")
logger.info("正在加载 OCR 模型...")
pipeline = OCRPipeline(ocr_config, pipeline_config)
pipeline.initialize()
app.state.ocr_pipeline = pipeline
app.state.model_loaded = True
print("[INFO] OCR 模型加载完成,服务已就绪")
logger.info("OCR 模型加载完成,服务已就绪")
yield
# 关闭: 清理资源
print("[INFO] 正在关闭 OCR API 服务...")
logger.info("正在关闭 OCR API 服务...")
app.state.model_loaded = False
print("[INFO] 服务已关闭")
logger.info("服务已关闭")
# 创建 FastAPI 应用
app = FastAPI(
title="Vision-OCR API",
description="基于 PaddleOCR 的图片 OCR 识别服务,支持通用文字识别和快递单解析",
version="1.0.0",
version=API_VERSION,
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
@ -116,6 +125,12 @@ app.add_middleware(
@app.exception_handler(OCRAPIException)
async def ocr_exception_handler(request: Request, exc: OCRAPIException):
"""处理 OCR API 自定义异常"""
logger.warning(
"OCR API 异常: %s - %s (路径: %s)",
type(exc).__name__,
exc.message,
request.url.path,
)
return JSONResponse(
status_code=exc.status_code,
content={
@ -132,6 +147,13 @@ async def ocr_exception_handler(request: Request, exc: OCRAPIException):
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""处理未捕获的异常"""
logger.error(
"未捕获的异常: %s - %s (路径: %s)",
type(exc).__name__,
str(exc),
request.url.path,
exc_info=True,
)
return JSONResponse(
status_code=500,
content={
@ -155,7 +177,7 @@ async def root():
"""API 根路径,返回基本信息"""
return {
"name": "Vision-OCR API",
"version": "1.0.0",
"version": API_VERSION,
"docs": "/docs",
"health": "/api/v1/health",
}

@ -7,12 +7,10 @@
from fastapi import APIRouter, Request
from api.schemas.response import HealthResponse
from api.version import API_VERSION
router = APIRouter(prefix="/health", tags=["健康检查"])
# API 版本
API_VERSION = "1.0.0"
@router.get(
"",

@ -36,6 +36,22 @@ from utils.config import VisualizeConfig
router = APIRouter(prefix="/ocr", tags=["OCR 识别"])
# 模块级别的 Visualizer 实例(避免重复创建)
_visualizer: Optional[OCRVisualizer] = None
def _get_visualizer() -> OCRVisualizer:
    """
    Return the shared module-level OCRVisualizer, creating it on first use.

    The instance is cached in the module global ``_visualizer`` so that
    later calls reuse it instead of constructing a new visualizer each time.

    Returns:
        The cached OCRVisualizer instance.
    """
    global _visualizer
    # Fast path: an instance was already created by an earlier call.
    if _visualizer is not None:
        return _visualizer
    # First call: build the visualizer with a default VisualizeConfig.
    _visualizer = OCRVisualizer(VisualizeConfig())
    return _visualizer
def _convert_ocr_result_to_response(
result: OCRResult,
@ -75,19 +91,63 @@ def _convert_ocr_result_to_response(
)
def _convert_express_result_to_response(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> ExpressResultData:
    """
    Convert an OCRResult into the express-waybill response data model.

    Args:
        result: OCR processing result; its ``parse_express()`` and
            ``merge_text()`` outputs populate the response.
        annotated_image_base64: Base64 encoding of the annotated image,
            or None when no annotated image is returned.

    Returns:
        ExpressResultData holding the parsed waybill fields, the merged
        text, the processing time and the optional annotated image.
    """
    # Parse the waybill fields first, then merge the recognized text.
    parsed = result.parse_express()
    full_text = result.merge_text()

    sender = ExpressPersonData(
        name=parsed.sender_name,
        phone=parsed.sender_phone,
        address=parsed.sender_address,
    )
    receiver = ExpressPersonData(
        name=parsed.receiver_name,
        phone=parsed.receiver_phone,
        address=parsed.receiver_address,
    )
    info = ExpressInfoData(
        tracking_number=parsed.tracking_number,
        sender=sender,
        receiver=receiver,
        courier_company=parsed.courier_company,
        confidence=parsed.confidence,
        extra_fields=parsed.extra_fields,
        raw_text=parsed.raw_text,
    )

    return ExpressResultData(
        processing_time_ms=result.processing_time_ms,
        express_info=info,
        merged_text=full_text,
        annotated_image_base64=annotated_image_base64,
    )
def _process_ocr(
image_bytes: bytes,
pipeline: OCRPipeline,
roi: Optional[ROIParams] = None,
drop_score: Optional[float] = None,
return_annotated_image: bool = False,
) -> tuple[OCRResult, Optional[str]]:
"""
执行 OCR 处理
执行 OCR 处理线程安全
Args:
image_bytes: 图片字节数据
pipeline: OCR 管道
roi: ROI 参数
drop_score: 置信度阈值低于此值的结果将被过滤
return_annotated_image: 是否返回标注图片
Returns:
@ -96,27 +156,23 @@ def _process_ocr(
# 解码图片
image = decode_image_bytes(image_bytes)
# 构建管道配置
# 构建管道配置(每次请求独立的配置,线程安全)
pipeline_config = build_pipeline_config(roi)
# 临时更新管道配置
original_config = pipeline._pipeline_config
pipeline._pipeline_config = pipeline_config
try:
# 执行 OCR
result = pipeline.process(image)
# 执行 OCR传递临时配置不修改共享状态
result = pipeline.process(
image=image,
pipeline_config=pipeline_config,
drop_score=drop_score,
)
except Exception as e:
raise OCRProcessingError(f"OCR 处理失败: {str(e)}")
finally:
# 恢复原始配置
pipeline._pipeline_config = original_config
# 生成标注图片
annotated_image_base64 = None
if return_annotated_image and result.text_count > 0:
visualizer = OCRVisualizer(VisualizeConfig())
annotated = visualizer.draw_result(image, result)
annotated = _get_visualizer().draw_result(image, result)
annotated_image_base64 = encode_image_base64(annotated)
return result, annotated_image_base64
@ -153,6 +209,7 @@ async def recognize_multipart(
image_bytes=image_bytes,
pipeline=pipeline,
roi=params.get_roi(),
drop_score=params.drop_score,
return_annotated_image=params.return_annotated_image,
)
@ -198,38 +255,14 @@ async def express_multipart(
image_bytes=image_bytes,
pipeline=pipeline,
roi=params.get_roi(),
drop_score=params.drop_score,
return_annotated_image=params.return_annotated_image,
)
# 解析快递单信息
express_info = result.parse_express()
merged_text = result.merge_text()
# 构建响应
return ExpressResponse(
success=True,
data=ExpressResultData(
processing_time_ms=result.processing_time_ms,
express_info=ExpressInfoData(
tracking_number=express_info.tracking_number,
sender=ExpressPersonData(
name=express_info.sender_name,
phone=express_info.sender_phone,
address=express_info.sender_address,
),
receiver=ExpressPersonData(
name=express_info.receiver_name,
phone=express_info.receiver_phone,
address=express_info.receiver_address,
),
courier_company=express_info.courier_company,
confidence=express_info.confidence,
extra_fields=express_info.extra_fields,
raw_text=express_info.raw_text,
),
merged_text=merged_text,
annotated_image_base64=annotated_base64,
),
data=_convert_express_result_to_response(result, annotated_base64),
)
except Exception as e:
@ -272,6 +305,7 @@ async def recognize_base64(
image_bytes=image_bytes,
pipeline=pipeline,
roi=body.roi,
drop_score=body.drop_score,
return_annotated_image=body.return_annotated_image,
)
@ -316,38 +350,14 @@ async def express_base64(
image_bytes=image_bytes,
pipeline=pipeline,
roi=body.roi,
drop_score=body.drop_score,
return_annotated_image=body.return_annotated_image,
)
# 解析快递单信息
express_info = result.parse_express()
merged_text = result.merge_text()
# 构建响应
return ExpressResponse(
success=True,
data=ExpressResultData(
processing_time_ms=result.processing_time_ms,
express_info=ExpressInfoData(
tracking_number=express_info.tracking_number,
sender=ExpressPersonData(
name=express_info.sender_name,
phone=express_info.sender_phone,
address=express_info.sender_address,
),
receiver=ExpressPersonData(
name=express_info.receiver_name,
phone=express_info.receiver_phone,
address=express_info.receiver_address,
),
courier_company=express_info.courier_company,
confidence=express_info.confidence,
extra_fields=express_info.extra_fields,
raw_text=express_info.raw_text,
),
merged_text=merged_text,
annotated_image_base64=annotated_base64,
),
data=_convert_express_result_to_response(result, annotated_base64),
)
except Exception as e:

@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
"""
Version management module.

Defines the API version string in one place so that it is not declared
separately in several modules.
"""
# API version string - modules that import API_VERSION pick up a change made here.
API_VERSION = "1.0.0"

@ -187,7 +187,7 @@ class OCRPipeline:
image: np.ndarray
) -> tuple:
"""
应用 ROI 裁剪
应用 ROI 裁剪使用默认配置
Args:
image: 原始图片
@ -195,7 +195,24 @@ class OCRPipeline:
Returns:
(裁剪后的图像, ROI 偏移量, ROI 矩形)
"""
roi_config = self._pipeline_config.roi
return self._apply_roi_with_config(image, self._pipeline_config)
def _apply_roi_with_config(
self,
image: np.ndarray,
config: PipelineConfig
) -> tuple:
"""
应用 ROI 裁剪使用指定配置线程安全
Args:
image: 原始图片
config: 管道配置
Returns:
(裁剪后的图像, ROI 偏移量, ROI 矩形)
"""
roi_config = config.roi
if not roi_config.enabled:
return image, (0, 0), None
@ -245,7 +262,9 @@ class OCRPipeline:
def process(
self,
image: np.ndarray,
image_path: Optional[str] = None
image_path: Optional[str] = None,
pipeline_config: Optional[PipelineConfig] = None,
drop_score: Optional[float] = None,
) -> OCRResult:
"""
处理单张图片
@ -253,6 +272,8 @@ class OCRPipeline:
Args:
image: 输入图片 (numpy array, BGR 格式)
image_path: 图片路径可选用于结果记录
pipeline_config: 临时管道配置可选用于单次请求的配置覆盖线程安全
drop_score: 置信度阈值可选用于过滤低置信度结果
Returns:
OCR 结果
@ -260,8 +281,11 @@ class OCRPipeline:
self._image_counter += 1
start_time = time.time()
# 使用临时配置或默认配置(线程安全:不修改共享状态)
config = pipeline_config if pipeline_config is not None else self._pipeline_config
# 应用 ROI 裁剪
cropped_image, roi_offset, roi_rect = self._apply_roi(image)
cropped_image, roi_offset, roi_rect = self._apply_roi_with_config(image, config)
# 图片预处理
processed_image = self._preprocess_image(cropped_image)
@ -269,6 +293,13 @@ class OCRPipeline:
# 执行 OCR
text_blocks = self._engine.recognize(processed_image, roi_offset)
# 应用置信度过滤(如果指定了 drop_score
if drop_score is not None:
text_blocks = [
block for block in text_blocks
if block.confidence >= drop_score
]
# 计算处理耗时
processing_time_ms = (time.time() - start_time) * 1000
@ -279,7 +310,7 @@ class OCRPipeline:
timestamp=time.time(),
processing_time_ms=processing_time_ms,
text_blocks=text_blocks,
roi_applied=self._pipeline_config.roi.enabled,
roi_applied=config.roi.enabled,
roi_rect=roi_rect
)

Loading…
Cancel
Save