diff --git a/api/__init__.py b/api/__init__.py index 96f020c..390d669 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -4,6 +4,11 @@ Vision-OCR REST API 模块 提供 HTTP 接口访问 OCR 功能 """ +from api.version import API_VERSION + +# 导出版本号 +__version__ = API_VERSION + from api.main import app -__all__ = ["app"] +__all__ = ["app", "__version__"] diff --git a/api/dependencies.py b/api/dependencies.py index 88614b4..801f303 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -102,13 +102,31 @@ def decode_image_bytes(content: bytes) -> np.ndarray: OpenCV 图像 (BGR 格式) Raises: - InvalidImageError: 图片解码失败 + InvalidImageError: 图片解码失败或尺寸不符合要求 """ + # 图片尺寸限制 + MIN_IMAGE_SIZE = 10 # 最小 10x10 像素 + MAX_IMAGE_SIZE = 10000 # 最大 10000x10000 像素 + try: nparr = np.frombuffer(content, np.uint8) image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if image is None: raise InvalidImageError("图片解码失败") + + # 验证图片尺寸 + height, width = image.shape[:2] + if width < MIN_IMAGE_SIZE or height < MIN_IMAGE_SIZE: + raise InvalidImageError( + f"图片尺寸过小,最小要求 {MIN_IMAGE_SIZE}x{MIN_IMAGE_SIZE} 像素," + f"当前尺寸 {width}x{height}" + ) + if width > MAX_IMAGE_SIZE or height > MAX_IMAGE_SIZE: + raise InvalidImageError( + f"图片尺寸过大,最大允许 {MAX_IMAGE_SIZE}x{MAX_IMAGE_SIZE} 像素," + f"当前尺寸 {width}x{height}" + ) + return image except Exception as e: if isinstance(e, InvalidImageError): diff --git a/api/main.py b/api/main.py index 6ba1f02..5250c58 100644 --- a/api/main.py +++ b/api/main.py @@ -4,6 +4,7 @@ Vision-OCR REST API 主入口 基于 FastAPI 的 OCR 服务 """ +import logging import os import sys from contextlib import asynccontextmanager @@ -24,9 +25,18 @@ from fastapi.responses import JSONResponse from api.exceptions import OCRAPIException from api.routes import health_router, ocr_router +from api.version import API_VERSION from ocr.pipeline import OCRPipeline from utils.config import OCRConfig, PipelineConfig +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger("vision-ocr") + def _get_ocr_config() -> OCRConfig: """ @@ -64,38 +74,37 @@ async def lifespan(app: FastAPI): 启动时预加载 OCR 模型,关闭时清理资源 """ # 启动: 加载 OCR 模型 - print("[INFO] 正在初始化 OCR API 服务...") + logger.info("正在初始化 OCR API 服务...") ocr_config = _get_ocr_config() if ocr_config.det_model_dir is None: - print("[WARN] 模型未在项目目录中找到") - print("[WARN] 对于 Windows 中文用户名用户,请先运行:") - print("[WARN] python download_models.py") - print("[INFO] 回退到默认 PaddleOCR 模型路径...") + logger.warning("模型未在项目目录中找到") + logger.warning("对于 Windows 中文用户名用户,请先运行: python download_models.py") + logger.info("回退到默认 PaddleOCR 模型路径...") pipeline_config = PipelineConfig() - print("[INFO] 正在加载 OCR 模型...") + logger.info("正在加载 OCR 模型...") pipeline = OCRPipeline(ocr_config, pipeline_config) pipeline.initialize() app.state.ocr_pipeline = pipeline app.state.model_loaded = True - print("[INFO] OCR 模型加载完成,服务已就绪") + logger.info("OCR 模型加载完成,服务已就绪") yield # 关闭: 清理资源 - print("[INFO] 正在关闭 OCR API 服务...") + logger.info("正在关闭 OCR API 服务...") app.state.model_loaded = False - print("[INFO] 服务已关闭") + logger.info("服务已关闭") # 创建 FastAPI 应用 app = FastAPI( title="Vision-OCR API", description="基于 PaddleOCR 的图片 OCR 识别服务,支持通用文字识别和快递单解析", - version="1.0.0", + version=API_VERSION, lifespan=lifespan, docs_url="/docs", redoc_url="/redoc", @@ -116,6 +125,12 @@ app.add_middleware( @app.exception_handler(OCRAPIException) async def ocr_exception_handler(request: Request, exc: OCRAPIException): """处理 OCR API 自定义异常""" + logger.warning( + "OCR API 异常: %s - %s (路径: %s)", + type(exc).__name__, + exc.message, + request.url.path, + ) return JSONResponse( status_code=exc.status_code, content={ @@ -132,6 +147,13 @@ async def ocr_exception_handler(request: Request, exc: OCRAPIException): @app.exception_handler(Exception) async def general_exception_handler(request: Request, exc: Exception): """处理未捕获的异常""" + logger.error( + "未捕获的异常: %s - %s (路径: %s)", + type(exc).__name__, + str(exc), + request.url.path, + exc_info=True, + ) return JSONResponse( status_code=500, content={ @@ -155,7 +177,7 @@ async def root(): """API 根路径,返回基本信息""" return { "name": "Vision-OCR API", - "version": "1.0.0", + "version": API_VERSION, "docs": "/docs", "health": "/api/v1/health", } diff --git a/api/routes/health.py b/api/routes/health.py index ebb9119..c0dd12c 100644 --- a/api/routes/health.py +++ b/api/routes/health.py @@ -7,12 +7,10 @@ from fastapi import APIRouter, Request from api.schemas.response import HealthResponse +from api.version import API_VERSION router = APIRouter(prefix="/health", tags=["健康检查"]) -# API 版本 -API_VERSION = "1.0.0" - @router.get( "", diff --git a/api/routes/ocr.py b/api/routes/ocr.py index 199d0da..b4b3c8e 100644 --- a/api/routes/ocr.py +++ b/api/routes/ocr.py @@ -36,6 +36,22 @@ from utils.config import VisualizeConfig router = APIRouter(prefix="/ocr", tags=["OCR 识别"]) +# 模块级别的 Visualizer 实例(避免重复创建) +_visualizer: Optional[OCRVisualizer] = None + + +def _get_visualizer() -> OCRVisualizer: + """ + 获取 Visualizer 实例(单例模式,避免重复创建) + + Returns: + OCRVisualizer 实例 + """ + global _visualizer + if _visualizer is None: + _visualizer = OCRVisualizer(VisualizeConfig()) + return _visualizer + def _convert_ocr_result_to_response( result: OCRResult, @@ -75,19 +91,63 @@ def _convert_ocr_result_to_response( ) +def _convert_express_result_to_response( + result: OCRResult, + annotated_image_base64: Optional[str] = None, +) -> ExpressResultData: + """ + 将 OCRResult 转换为快递单响应数据模型 + + Args: + result: OCR 处理结果 + annotated_image_base64: 标注图片的 Base64 编码 + + Returns: + 快递单响应数据模型 + """ + # 解析快递单信息 + express_info = result.parse_express() + merged_text = result.merge_text() + + return ExpressResultData( + processing_time_ms=result.processing_time_ms, + express_info=ExpressInfoData( + tracking_number=express_info.tracking_number, + sender=ExpressPersonData( + name=express_info.sender_name, + phone=express_info.sender_phone, + address=express_info.sender_address, + ), + receiver=ExpressPersonData( + name=express_info.receiver_name, + phone=express_info.receiver_phone, + address=express_info.receiver_address, + ), + courier_company=express_info.courier_company, + confidence=express_info.confidence, + extra_fields=express_info.extra_fields, + raw_text=express_info.raw_text, + ), + merged_text=merged_text, + annotated_image_base64=annotated_image_base64, + ) + + def _process_ocr( image_bytes: bytes, pipeline: OCRPipeline, roi: Optional[ROIParams] = None, + drop_score: Optional[float] = None, return_annotated_image: bool = False, ) -> tuple[OCRResult, Optional[str]]: """ - 执行 OCR 处理 + 执行 OCR 处理(线程安全) Args: image_bytes: 图片字节数据 pipeline: OCR 管道 roi: ROI 参数 + drop_score: 置信度阈值,低于此值的结果将被过滤 return_annotated_image: 是否返回标注图片 Returns: @@ -96,27 +156,23 @@ def _process_ocr( # 解码图片 image = decode_image_bytes(image_bytes) - # 构建管道配置 + # 构建管道配置(每次请求独立的配置,线程安全) pipeline_config = build_pipeline_config(roi) - # 临时更新管道配置 - original_config = pipeline._pipeline_config - pipeline._pipeline_config = pipeline_config - try: - # 执行 OCR - result = pipeline.process(image) + # 执行 OCR(传递临时配置,不修改共享状态) + result = pipeline.process( + image=image, + pipeline_config=pipeline_config, + drop_score=drop_score, + ) except Exception as e: raise OCRProcessingError(f"OCR 处理失败: {str(e)}") - finally: - # 恢复原始配置 - pipeline._pipeline_config = original_config # 生成标注图片 annotated_image_base64 = None if return_annotated_image and result.text_count > 0: - visualizer = OCRVisualizer(VisualizeConfig()) - annotated = visualizer.draw_result(image, result) + annotated = _get_visualizer().draw_result(image, result) annotated_image_base64 = encode_image_base64(annotated) return result, annotated_image_base64 @@ -153,6 +209,7 @@ async def recognize_multipart( image_bytes=image_bytes, pipeline=pipeline, roi=params.get_roi(), + drop_score=params.drop_score, return_annotated_image=params.return_annotated_image, ) @@ -198,38 +255,14 @@ async def express_multipart( image_bytes=image_bytes, pipeline=pipeline, roi=params.get_roi(), + drop_score=params.drop_score, return_annotated_image=params.return_annotated_image, ) - # 解析快递单信息 - express_info = result.parse_express() - merged_text = result.merge_text() - # 构建响应 return ExpressResponse( success=True, - data=ExpressResultData( - processing_time_ms=result.processing_time_ms, - express_info=ExpressInfoData( - tracking_number=express_info.tracking_number, - sender=ExpressPersonData( - name=express_info.sender_name, - phone=express_info.sender_phone, - address=express_info.sender_address, - ), - receiver=ExpressPersonData( - name=express_info.receiver_name, - phone=express_info.receiver_phone, - address=express_info.receiver_address, - ), - courier_company=express_info.courier_company, - confidence=express_info.confidence, - extra_fields=express_info.extra_fields, - raw_text=express_info.raw_text, - ), - merged_text=merged_text, - annotated_image_base64=annotated_base64, - ), + data=_convert_express_result_to_response(result, annotated_base64), ) except Exception as e: @@ -272,6 +305,7 @@ async def recognize_base64( image_bytes=image_bytes, pipeline=pipeline, roi=body.roi, + drop_score=body.drop_score, return_annotated_image=body.return_annotated_image, ) @@ -316,38 +350,14 @@ async def express_base64( image_bytes=image_bytes, pipeline=pipeline, roi=body.roi, + drop_score=body.drop_score, return_annotated_image=body.return_annotated_image, ) - # 解析快递单信息 - express_info = result.parse_express() - merged_text = result.merge_text() - # 构建响应 return ExpressResponse( success=True, - data=ExpressResultData( - processing_time_ms=result.processing_time_ms, - express_info=ExpressInfoData( - tracking_number=express_info.tracking_number, - sender=ExpressPersonData( - name=express_info.sender_name, - phone=express_info.sender_phone, - address=express_info.sender_address, - ), - receiver=ExpressPersonData( - name=express_info.receiver_name, - phone=express_info.receiver_phone, - address=express_info.receiver_address, - ), - courier_company=express_info.courier_company, - confidence=express_info.confidence, - extra_fields=express_info.extra_fields, - raw_text=express_info.raw_text, - ), - merged_text=merged_text, - annotated_image_base64=annotated_base64, - ), + data=_convert_express_result_to_response(result, annotated_base64), ) except Exception as e: diff --git a/api/version.py b/api/version.py new file mode 100644 index 0000000..3b358e2 --- /dev/null +++ b/api/version.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +""" +版本号管理模块 +统一管理 API 版本号,避免分散定义 +""" + +# API 版本号 - 修改此处即可更新所有引用 +API_VERSION = "1.0.0" diff --git a/ocr/pipeline.py b/ocr/pipeline.py index bca7031..a1e3472 100644 --- a/ocr/pipeline.py +++ b/ocr/pipeline.py @@ -187,7 +187,7 @@ class OCRPipeline: image: np.ndarray ) -> tuple: """ - 应用 ROI 裁剪 + 应用 ROI 裁剪(使用默认配置) Args: image: 原始图片 @@ -195,7 +195,24 @@ class OCRPipeline: Returns: (裁剪后的图像, ROI 偏移量, ROI 矩形) """ - roi_config = self._pipeline_config.roi + return self._apply_roi_with_config(image, self._pipeline_config) + + def _apply_roi_with_config( + self, + image: np.ndarray, + config: PipelineConfig + ) -> tuple: + """ + 应用 ROI 裁剪(使用指定配置,线程安全) + + Args: + image: 原始图片 + config: 管道配置 + + Returns: + (裁剪后的图像, ROI 偏移量, ROI 矩形) + """ + roi_config = config.roi if not roi_config.enabled: return image, (0, 0), None @@ -245,7 +262,9 @@ class OCRPipeline: def process( self, image: np.ndarray, - image_path: Optional[str] = None + image_path: Optional[str] = None, + pipeline_config: Optional[PipelineConfig] = None, + drop_score: Optional[float] = None, ) -> OCRResult: """ 处理单张图片 @@ -253,6 +272,8 @@ class OCRPipeline: Args: image: 输入图片 (numpy array, BGR 格式) image_path: 图片路径(可选,用于结果记录) + pipeline_config: 临时管道配置(可选,用于单次请求的配置覆盖,线程安全) + drop_score: 置信度阈值(可选,用于过滤低置信度结果) Returns: OCR 结果 @@ -260,8 +281,11 @@ class OCRPipeline: self._image_counter += 1 start_time = time.time() + # 使用临时配置或默认配置(线程安全:不修改共享状态) + config = pipeline_config if pipeline_config is not None else self._pipeline_config + # 应用 ROI 裁剪 - cropped_image, roi_offset, roi_rect = self._apply_roi(image) + cropped_image, roi_offset, roi_rect = self._apply_roi_with_config(image, config) # 图片预处理 processed_image = self._preprocess_image(cropped_image) @@ -269,6 +293,13 @@ class OCRPipeline: # 执行 OCR text_blocks = self._engine.recognize(processed_image, roi_offset) + # 应用置信度过滤(如果指定了 drop_score) + if drop_score is not None: + text_blocks = [ + block for block in text_blocks + if block.confidence >= drop_score + ] + # 计算处理耗时 processing_time_ms = (time.time() - start_time) * 1000 @@ -279,7 +310,7 @@ class OCRPipeline: timestamp=time.time(), processing_time_ms=processing_time_ms, text_blocks=text_blocks, - roi_applied=self._pipeline_config.roi.enabled, + roi_applied=config.roi.enabled, roi_rect=roi_rect )