fix: 修复线程安全和 API 参数失效问题

- 修复 OCRPipeline 并发请求时的线程安全问题
- 修复 drop_score 参数未生效的问题
- 集中管理版本号,消除代码重复
- 添加图片尺寸验证和日志系统
master
蒋尚宏 4 weeks ago
parent 9ee92136d8
commit b3d52a0684

@ -4,6 +4,11 @@ Vision-OCR REST API 模块
提供 HTTP 接口访问 OCR 功能 提供 HTTP 接口访问 OCR 功能
""" """
from api.version import API_VERSION
# 导出版本号
__version__ = API_VERSION
from api.main import app from api.main import app
__all__ = ["app"] __all__ = ["app", "__version__"]

@ -102,13 +102,31 @@ def decode_image_bytes(content: bytes) -> np.ndarray:
OpenCV 图像 (BGR 格式) OpenCV 图像 (BGR 格式)
Raises: Raises:
InvalidImageError: 图片解码失败 InvalidImageError: 图片解码失败或尺寸不符合要求
""" """
# 图片尺寸限制
MIN_IMAGE_SIZE = 10 # 最小 10x10 像素
MAX_IMAGE_SIZE = 10000 # 最大 10000x10000 像素
try: try:
nparr = np.frombuffer(content, np.uint8) nparr = np.frombuffer(content, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if image is None: if image is None:
raise InvalidImageError("图片解码失败") raise InvalidImageError("图片解码失败")
# 验证图片尺寸
height, width = image.shape[:2]
if width < MIN_IMAGE_SIZE or height < MIN_IMAGE_SIZE:
raise InvalidImageError(
f"图片尺寸过小,最小要求 {MIN_IMAGE_SIZE}x{MIN_IMAGE_SIZE} 像素,"
f"当前尺寸 {width}x{height}"
)
if width > MAX_IMAGE_SIZE or height > MAX_IMAGE_SIZE:
raise InvalidImageError(
f"图片尺寸过大,最大允许 {MAX_IMAGE_SIZE}x{MAX_IMAGE_SIZE} 像素,"
f"当前尺寸 {width}x{height}"
)
return image return image
except Exception as e: except Exception as e:
if isinstance(e, InvalidImageError): if isinstance(e, InvalidImageError):

@ -4,6 +4,7 @@ Vision-OCR REST API 主入口
基于 FastAPI OCR 服务 基于 FastAPI OCR 服务
""" """
import logging
import os import os
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -24,9 +25,18 @@ from fastapi.responses import JSONResponse
from api.exceptions import OCRAPIException from api.exceptions import OCRAPIException
from api.routes import health_router, ocr_router from api.routes import health_router, ocr_router
from api.version import API_VERSION
from ocr.pipeline import OCRPipeline from ocr.pipeline import OCRPipeline
from utils.config import OCRConfig, PipelineConfig from utils.config import OCRConfig, PipelineConfig
# Configure logging for the service: INFO level, with each record formatted as
# "timestamp - logger name - level - message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Named logger used by the lifespan hooks and the exception handlers below.
logger = logging.getLogger("vision-ocr")
def _get_ocr_config() -> OCRConfig: def _get_ocr_config() -> OCRConfig:
""" """
@ -64,38 +74,37 @@ async def lifespan(app: FastAPI):
启动时预加载 OCR 模型关闭时清理资源 启动时预加载 OCR 模型关闭时清理资源
""" """
# 启动: 加载 OCR 模型 # 启动: 加载 OCR 模型
print("[INFO] 正在初始化 OCR API 服务...") logger.info("正在初始化 OCR API 服务...")
ocr_config = _get_ocr_config() ocr_config = _get_ocr_config()
if ocr_config.det_model_dir is None: if ocr_config.det_model_dir is None:
print("[WARN] 模型未在项目目录中找到") logger.warning("模型未在项目目录中找到")
print("[WARN] 对于 Windows 中文用户名用户,请先运行:") logger.warning("对于 Windows 中文用户名用户,请先运行: python download_models.py")
print("[WARN] python download_models.py") logger.info("回退到默认 PaddleOCR 模型路径...")
print("[INFO] 回退到默认 PaddleOCR 模型路径...")
pipeline_config = PipelineConfig() pipeline_config = PipelineConfig()
print("[INFO] 正在加载 OCR 模型...") logger.info("正在加载 OCR 模型...")
pipeline = OCRPipeline(ocr_config, pipeline_config) pipeline = OCRPipeline(ocr_config, pipeline_config)
pipeline.initialize() pipeline.initialize()
app.state.ocr_pipeline = pipeline app.state.ocr_pipeline = pipeline
app.state.model_loaded = True app.state.model_loaded = True
print("[INFO] OCR 模型加载完成,服务已就绪") logger.info("OCR 模型加载完成,服务已就绪")
yield yield
# 关闭: 清理资源 # 关闭: 清理资源
print("[INFO] 正在关闭 OCR API 服务...") logger.info("正在关闭 OCR API 服务...")
app.state.model_loaded = False app.state.model_loaded = False
print("[INFO] 服务已关闭") logger.info("服务已关闭")
# 创建 FastAPI 应用 # 创建 FastAPI 应用
app = FastAPI( app = FastAPI(
title="Vision-OCR API", title="Vision-OCR API",
description="基于 PaddleOCR 的图片 OCR 识别服务,支持通用文字识别和快递单解析", description="基于 PaddleOCR 的图片 OCR 识别服务,支持通用文字识别和快递单解析",
version="1.0.0", version=API_VERSION,
lifespan=lifespan, lifespan=lifespan,
docs_url="/docs", docs_url="/docs",
redoc_url="/redoc", redoc_url="/redoc",
@ -116,6 +125,12 @@ app.add_middleware(
@app.exception_handler(OCRAPIException) @app.exception_handler(OCRAPIException)
async def ocr_exception_handler(request: Request, exc: OCRAPIException): async def ocr_exception_handler(request: Request, exc: OCRAPIException):
"""处理 OCR API 自定义异常""" """处理 OCR API 自定义异常"""
logger.warning(
"OCR API 异常: %s - %s (路径: %s)",
type(exc).__name__,
exc.message,
request.url.path,
)
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content={ content={
@ -132,6 +147,13 @@ async def ocr_exception_handler(request: Request, exc: OCRAPIException):
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception): async def general_exception_handler(request: Request, exc: Exception):
"""处理未捕获的异常""" """处理未捕获的异常"""
logger.error(
"未捕获的异常: %s - %s (路径: %s)",
type(exc).__name__,
str(exc),
request.url.path,
exc_info=True,
)
return JSONResponse( return JSONResponse(
status_code=500, status_code=500,
content={ content={
@ -155,7 +177,7 @@ async def root():
"""API 根路径,返回基本信息""" """API 根路径,返回基本信息"""
return { return {
"name": "Vision-OCR API", "name": "Vision-OCR API",
"version": "1.0.0", "version": API_VERSION,
"docs": "/docs", "docs": "/docs",
"health": "/api/v1/health", "health": "/api/v1/health",
} }

@ -7,12 +7,10 @@
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from api.schemas.response import HealthResponse from api.schemas.response import HealthResponse
from api.version import API_VERSION
router = APIRouter(prefix="/health", tags=["健康检查"]) router = APIRouter(prefix="/health", tags=["健康检查"])
# API 版本
API_VERSION = "1.0.0"
@router.get( @router.get(
"", "",

@ -36,6 +36,22 @@ from utils.config import VisualizeConfig
router = APIRouter(prefix="/ocr", tags=["OCR 识别"]) router = APIRouter(prefix="/ocr", tags=["OCR 识别"])
# 模块级别的 Visualizer 实例(避免重复创建)
_visualizer: Optional[OCRVisualizer] = None
def _get_visualizer() -> OCRVisualizer:
    """
    Return the module-level OCRVisualizer, creating it on first use.

    The instance is stored in the module global ``_visualizer`` so that
    later calls reuse it instead of constructing a new visualizer.

    Returns:
        The shared OCRVisualizer instance.
    """
    global _visualizer
    # Fast path: an instance was already created by an earlier call.
    if _visualizer is not None:
        return _visualizer
    _visualizer = OCRVisualizer(VisualizeConfig())
    return _visualizer
def _convert_ocr_result_to_response( def _convert_ocr_result_to_response(
result: OCRResult, result: OCRResult,
@ -75,19 +91,63 @@ def _convert_ocr_result_to_response(
) )
def _convert_express_result_to_response(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> ExpressResultData:
    """
    Convert an OCRResult into the express-waybill response data model.

    Args:
        result: OCR processing result; its ``parse_express()`` and
            ``merge_text()`` methods supply the parsed waybill fields and
            the merged text.
        annotated_image_base64: Base64 encoding of the annotated image,
            or None when no annotated image is returned.

    Returns:
        The express-waybill response data model.
    """
    # Parse the waybill fields, then build the merged text.
    parsed = result.parse_express()
    full_text = result.merge_text()

    sender = ExpressPersonData(
        name=parsed.sender_name,
        phone=parsed.sender_phone,
        address=parsed.sender_address,
    )
    receiver = ExpressPersonData(
        name=parsed.receiver_name,
        phone=parsed.receiver_phone,
        address=parsed.receiver_address,
    )
    info = ExpressInfoData(
        tracking_number=parsed.tracking_number,
        sender=sender,
        receiver=receiver,
        courier_company=parsed.courier_company,
        confidence=parsed.confidence,
        extra_fields=parsed.extra_fields,
        raw_text=parsed.raw_text,
    )

    return ExpressResultData(
        processing_time_ms=result.processing_time_ms,
        express_info=info,
        merged_text=full_text,
        annotated_image_base64=annotated_image_base64,
    )
def _process_ocr( def _process_ocr(
image_bytes: bytes, image_bytes: bytes,
pipeline: OCRPipeline, pipeline: OCRPipeline,
roi: Optional[ROIParams] = None, roi: Optional[ROIParams] = None,
drop_score: Optional[float] = None,
return_annotated_image: bool = False, return_annotated_image: bool = False,
) -> tuple[OCRResult, Optional[str]]: ) -> tuple[OCRResult, Optional[str]]:
""" """
执行 OCR 处理 执行 OCR 处理线程安全
Args: Args:
image_bytes: 图片字节数据 image_bytes: 图片字节数据
pipeline: OCR 管道 pipeline: OCR 管道
roi: ROI 参数 roi: ROI 参数
drop_score: 置信度阈值低于此值的结果将被过滤
return_annotated_image: 是否返回标注图片 return_annotated_image: 是否返回标注图片
Returns: Returns:
@ -96,27 +156,23 @@ def _process_ocr(
# 解码图片 # 解码图片
image = decode_image_bytes(image_bytes) image = decode_image_bytes(image_bytes)
# 构建管道配置 # 构建管道配置(每次请求独立的配置,线程安全)
pipeline_config = build_pipeline_config(roi) pipeline_config = build_pipeline_config(roi)
# 临时更新管道配置
original_config = pipeline._pipeline_config
pipeline._pipeline_config = pipeline_config
try: try:
# 执行 OCR # 执行 OCR传递临时配置不修改共享状态
result = pipeline.process(image) result = pipeline.process(
image=image,
pipeline_config=pipeline_config,
drop_score=drop_score,
)
except Exception as e: except Exception as e:
raise OCRProcessingError(f"OCR 处理失败: {str(e)}") raise OCRProcessingError(f"OCR 处理失败: {str(e)}")
finally:
# 恢复原始配置
pipeline._pipeline_config = original_config
# 生成标注图片 # 生成标注图片
annotated_image_base64 = None annotated_image_base64 = None
if return_annotated_image and result.text_count > 0: if return_annotated_image and result.text_count > 0:
visualizer = OCRVisualizer(VisualizeConfig()) annotated = _get_visualizer().draw_result(image, result)
annotated = visualizer.draw_result(image, result)
annotated_image_base64 = encode_image_base64(annotated) annotated_image_base64 = encode_image_base64(annotated)
return result, annotated_image_base64 return result, annotated_image_base64
@ -153,6 +209,7 @@ async def recognize_multipart(
image_bytes=image_bytes, image_bytes=image_bytes,
pipeline=pipeline, pipeline=pipeline,
roi=params.get_roi(), roi=params.get_roi(),
drop_score=params.drop_score,
return_annotated_image=params.return_annotated_image, return_annotated_image=params.return_annotated_image,
) )
@ -198,38 +255,14 @@ async def express_multipart(
image_bytes=image_bytes, image_bytes=image_bytes,
pipeline=pipeline, pipeline=pipeline,
roi=params.get_roi(), roi=params.get_roi(),
drop_score=params.drop_score,
return_annotated_image=params.return_annotated_image, return_annotated_image=params.return_annotated_image,
) )
# 解析快递单信息
express_info = result.parse_express()
merged_text = result.merge_text()
# 构建响应 # 构建响应
return ExpressResponse( return ExpressResponse(
success=True, success=True,
data=ExpressResultData( data=_convert_express_result_to_response(result, annotated_base64),
processing_time_ms=result.processing_time_ms,
express_info=ExpressInfoData(
tracking_number=express_info.tracking_number,
sender=ExpressPersonData(
name=express_info.sender_name,
phone=express_info.sender_phone,
address=express_info.sender_address,
),
receiver=ExpressPersonData(
name=express_info.receiver_name,
phone=express_info.receiver_phone,
address=express_info.receiver_address,
),
courier_company=express_info.courier_company,
confidence=express_info.confidence,
extra_fields=express_info.extra_fields,
raw_text=express_info.raw_text,
),
merged_text=merged_text,
annotated_image_base64=annotated_base64,
),
) )
except Exception as e: except Exception as e:
@ -272,6 +305,7 @@ async def recognize_base64(
image_bytes=image_bytes, image_bytes=image_bytes,
pipeline=pipeline, pipeline=pipeline,
roi=body.roi, roi=body.roi,
drop_score=body.drop_score,
return_annotated_image=body.return_annotated_image, return_annotated_image=body.return_annotated_image,
) )
@ -316,38 +350,14 @@ async def express_base64(
image_bytes=image_bytes, image_bytes=image_bytes,
pipeline=pipeline, pipeline=pipeline,
roi=body.roi, roi=body.roi,
drop_score=body.drop_score,
return_annotated_image=body.return_annotated_image, return_annotated_image=body.return_annotated_image,
) )
# 解析快递单信息
express_info = result.parse_express()
merged_text = result.merge_text()
# 构建响应 # 构建响应
return ExpressResponse( return ExpressResponse(
success=True, success=True,
data=ExpressResultData( data=_convert_express_result_to_response(result, annotated_base64),
processing_time_ms=result.processing_time_ms,
express_info=ExpressInfoData(
tracking_number=express_info.tracking_number,
sender=ExpressPersonData(
name=express_info.sender_name,
phone=express_info.sender_phone,
address=express_info.sender_address,
),
receiver=ExpressPersonData(
name=express_info.receiver_name,
phone=express_info.receiver_phone,
address=express_info.receiver_address,
),
courier_company=express_info.courier_company,
confidence=express_info.confidence,
extra_fields=express_info.extra_fields,
raw_text=express_info.raw_text,
),
merged_text=merged_text,
annotated_image_base64=annotated_base64,
),
) )
except Exception as e: except Exception as e:

@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
"""
Version management module.

Keeps the API version number in a single place instead of defining it
separately in several modules.
"""
# API version number - change it here to update every reference.
API_VERSION = "1.0.0"

@ -187,7 +187,7 @@ class OCRPipeline:
image: np.ndarray image: np.ndarray
) -> tuple: ) -> tuple:
""" """
应用 ROI 裁剪 应用 ROI 裁剪使用默认配置
Args: Args:
image: 原始图片 image: 原始图片
@ -195,7 +195,24 @@ class OCRPipeline:
Returns: Returns:
(裁剪后的图像, ROI 偏移量, ROI 矩形) (裁剪后的图像, ROI 偏移量, ROI 矩形)
""" """
roi_config = self._pipeline_config.roi return self._apply_roi_with_config(image, self._pipeline_config)
def _apply_roi_with_config(
self,
image: np.ndarray,
config: PipelineConfig
) -> tuple:
"""
应用 ROI 裁剪使用指定配置线程安全
Args:
image: 原始图片
config: 管道配置
Returns:
(裁剪后的图像, ROI 偏移量, ROI 矩形)
"""
roi_config = config.roi
if not roi_config.enabled: if not roi_config.enabled:
return image, (0, 0), None return image, (0, 0), None
@ -245,7 +262,9 @@ class OCRPipeline:
def process( def process(
self, self,
image: np.ndarray, image: np.ndarray,
image_path: Optional[str] = None image_path: Optional[str] = None,
pipeline_config: Optional[PipelineConfig] = None,
drop_score: Optional[float] = None,
) -> OCRResult: ) -> OCRResult:
""" """
处理单张图片 处理单张图片
@ -253,6 +272,8 @@ class OCRPipeline:
Args: Args:
image: 输入图片 (numpy array, BGR 格式) image: 输入图片 (numpy array, BGR 格式)
image_path: 图片路径可选用于结果记录 image_path: 图片路径可选用于结果记录
pipeline_config: 临时管道配置可选用于单次请求的配置覆盖线程安全
drop_score: 置信度阈值可选用于过滤低置信度结果
Returns: Returns:
OCR 结果 OCR 结果
@ -260,8 +281,11 @@ class OCRPipeline:
self._image_counter += 1 self._image_counter += 1
start_time = time.time() start_time = time.time()
# 使用临时配置或默认配置(线程安全:不修改共享状态)
config = pipeline_config if pipeline_config is not None else self._pipeline_config
# 应用 ROI 裁剪 # 应用 ROI 裁剪
cropped_image, roi_offset, roi_rect = self._apply_roi(image) cropped_image, roi_offset, roi_rect = self._apply_roi_with_config(image, config)
# 图片预处理 # 图片预处理
processed_image = self._preprocess_image(cropped_image) processed_image = self._preprocess_image(cropped_image)
@ -269,6 +293,13 @@ class OCRPipeline:
# 执行 OCR # 执行 OCR
text_blocks = self._engine.recognize(processed_image, roi_offset) text_blocks = self._engine.recognize(processed_image, roi_offset)
# 应用置信度过滤(如果指定了 drop_score
if drop_score is not None:
text_blocks = [
block for block in text_blocks
if block.confidence >= drop_score
]
# 计算处理耗时 # 计算处理耗时
processing_time_ms = (time.time() - start_time) * 1000 processing_time_ms = (time.time() - start_time) * 1000
@ -279,7 +310,7 @@ class OCRPipeline:
timestamp=time.time(), timestamp=time.time(),
processing_time_ms=processing_time_ms, processing_time_ms=processing_time_ms,
text_blocks=text_blocks, text_blocks=text_blocks,
roi_applied=self._pipeline_config.roi.enabled, roi_applied=config.roi.enabled,
roi_rect=roi_rect roi_rect=roi_rect
) )

Loading…
Cancel
Save