# -*- coding: utf-8 -*-
"""
OCR routing module.

Provides the OCR recognition and express-waybill parsing endpoints.
"""

from typing import Optional

from fastapi import APIRouter, Depends, File, Request, UploadFile

from api.dependencies import (
    build_pipeline_config,
    decode_image_bytes,
    encode_image_base64,
    get_ocr_pipeline,
    parse_multipart_image,
    parse_multipart_params,
)
from api.exceptions import OCRProcessingError
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
from api.schemas.response import (
    ErrorDetail,
    ExpressInfoData,
    ExpressPersonData,
    ExpressResponse,
    ExpressResultData,
    OCRResponse,
    OCRResultData,
    TextBlockData,
)
from api.security import decode_base64_image
from ocr.pipeline import OCRPipeline, OCRResult
from utils.config import OCRConfig, PipelineConfig, ROIConfig
from visualize.draw import OCRVisualizer
from utils.config import VisualizeConfig

router = APIRouter(prefix="/ocr", tags=["OCR 识别"])

# Module-level visualizer instance, created lazily and then reused.
_visualizer: Optional[OCRVisualizer] = None


def _get_visualizer() -> OCRVisualizer:
    """
    Return the shared visualizer, creating it on first use.

    The instance is built with a default ``VisualizeConfig()`` and cached in
    the module-level ``_visualizer`` so later calls reuse it.

    Returns:
        The cached OCRVisualizer instance.
    """
    global _visualizer
    if _visualizer is None:
        _visualizer = OCRVisualizer(VisualizeConfig())
    return _visualizer


def _convert_ocr_result_to_response(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> OCRResultData:
    """
    Convert an OCRResult into the OCR response data model.

    Args:
        result: OCR processing result.
        annotated_image_base64: Base64-encoded annotated image, if any.

    Returns:
        The populated OCRResultData.
    """
    text_blocks = [
        TextBlockData(
            text=block.text,
            confidence=block.confidence,
            bbox=block.bbox,
            bbox_with_offset=block.bbox_with_offset,
            center=list(block.center),
            width=block.width,
            height=block.height,
        )
        for block in result.text_blocks
    ]

    return OCRResultData(
        processing_time_ms=result.processing_time_ms,
        text_count=result.text_count,
        average_confidence=result.average_confidence,
        roi_applied=result.roi_applied,
        # A falsy roi_rect is reported as None rather than an empty list.
        roi_rect=list(result.roi_rect) if result.roi_rect else None,
        text_blocks=text_blocks,
        annotated_image_base64=annotated_image_base64,
    )


def _convert_express_result_to_response(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> ExpressResultData:
    """
    Convert an OCRResult into the express-waybill response data model.

    Args:
        result: OCR processing result.
        annotated_image_base64: Base64-encoded annotated image, if any.

    Returns:
        The populated ExpressResultData.
    """
    # Parse the structured waybill fields, then the merged plain text.
    express_info = result.parse_express()
    merged_text = result.merge_text()

    return ExpressResultData(
        processing_time_ms=result.processing_time_ms,
        express_info=ExpressInfoData(
            tracking_number=express_info.tracking_number,
            sender=ExpressPersonData(
                name=express_info.sender_name,
                phone=express_info.sender_phone,
                address=express_info.sender_address,
            ),
            receiver=ExpressPersonData(
                name=express_info.receiver_name,
                phone=express_info.receiver_phone,
                address=express_info.receiver_address,
            ),
            courier_company=express_info.courier_company,
            confidence=express_info.confidence,
            extra_fields=express_info.extra_fields,
            raw_text=express_info.raw_text,
        ),
        merged_text=merged_text,
        annotated_image_base64=annotated_image_base64,
    )


def _process_ocr(
    image_bytes: bytes,
    pipeline: OCRPipeline,
    roi: Optional[ROIParams] = None,
    drop_score: Optional[float] = None,
    return_annotated_image: bool = False,
) -> tuple[OCRResult, Optional[str]]:
    """
    Decode an image, run it through the OCR pipeline, and optionally annotate it.

    A fresh pipeline config is built on every call and passed to
    ``pipeline.process`` as an argument instead of being stored on the
    pipeline. NOTE(review): whether ``pipeline.process`` itself is safe to
    call concurrently is not visible from this module - confirm.

    Args:
        image_bytes: Raw image bytes.
        pipeline: OCR pipeline to run.
        roi: Optional ROI parameters used to build the pipeline config.
        drop_score: Confidence threshold forwarded to ``pipeline.process``
            (presumably results below it are filtered out - confirm there).
        return_annotated_image: Whether to render and return an annotated image.

    Returns:
        A ``(result, annotated_image_base64)`` tuple. The second item is None
        unless an annotated image was requested and at least one text block
        was found.

    Raises:
        OCRProcessingError: If ``pipeline.process`` raises; the original
            exception is attached as ``__cause__``.
    """
    # Decode the image.
    image = decode_image_bytes(image_bytes)

    # Build a per-call pipeline config.
    pipeline_config = build_pipeline_config(roi)

    try:
        # Run OCR with the per-call config passed explicitly.
        result = pipeline.process(
            image=image,
            pipeline_config=pipeline_config,
            drop_score=drop_score,
        )
    except Exception as e:
        # Chain explicitly so the underlying failure stays in the traceback.
        raise OCRProcessingError(f"OCR 处理失败: {e!s}") from e

    # Render the annotated image only when requested and there is text to draw.
    annotated_image_base64 = None
    if return_annotated_image and result.text_count > 0:
        annotated = _get_visualizer().draw_result(image, result)
        annotated_image_base64 = encode_image_base64(annotated)

    return result, annotated_image_base64


# ============================================================
# multipart/form-data endpoints
# ============================================================


@router.post(
    "/recognize",
    response_model=OCRResponse,
    summary="OCR 识别 (文件上传)",
    description="上传图片文件进行 OCR 识别，支持 jpg/png/bmp/webp 格式",
)
async def recognize_multipart(
    request: Request,
    file: UploadFile = File(..., description="图片文件"),
    params: OCRRequestParams = Depends(parse_multipart_params),
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
    """
    OCR recognition endpoint (multipart/form-data).

    Reads the uploaded image, runs OCR on it, and returns the recognized
    text blocks. Any exception raised along the way is caught and reported
    as ``success=False`` with the exception class name as the error code and
    ``str(exception)`` as the message.

    ``request`` is not referenced in the body.
    """
    try:
        # Read the upload through the shared multipart helper.
        raw_image = await parse_multipart_image(file)

        # Run OCR with the caller-supplied options.
        ocr_result, annotated_b64 = _process_ocr(
            raw_image,
            pipeline,
            roi=params.get_roi(),
            drop_score=params.drop_score,
            return_annotated_image=params.return_annotated_image,
        )

        # Build the success envelope.
        payload = _convert_ocr_result_to_response(ocr_result, annotated_b64)
        return OCRResponse(success=True, data=payload)
    except Exception as exc:
        failure = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return OCRResponse(success=False, error=failure)


@router.post(
    "/express",
    response_model=ExpressResponse,
    summary="快递单解析 (文件上传)",
    description="上传快递单图片进行 OCR 识别并解析结构化信息",
)
async def express_multipart(
    request: Request,
    file: UploadFile = File(..., description="快递单图片"),
    params: OCRRequestParams = Depends(parse_multipart_params),
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
    """
    Express-waybill parsing endpoint (multipart/form-data).

    Reads the uploaded waybill image, runs OCR on it, and returns the
    structured waybill information. Any exception raised along the way is
    caught and reported as ``success=False`` with the exception class name as
    the error code and ``str(exception)`` as the message.

    ``request`` is not referenced in the body.
    """
    try:
        # Read the upload through the shared multipart helper.
        raw_image = await parse_multipart_image(file)

        # Run OCR with the caller-supplied options.
        ocr_result, annotated_b64 = _process_ocr(
            raw_image,
            pipeline,
            roi=params.get_roi(),
            drop_score=params.drop_score,
            return_annotated_image=params.return_annotated_image,
        )

        # Build the success envelope.
        payload = _convert_express_result_to_response(ocr_result, annotated_b64)
        return ExpressResponse(success=True, data=payload)
    except Exception as exc:
        failure = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return ExpressResponse(success=False, error=failure)


# ============================================================
# JSON (Base64) endpoints
# ============================================================


@router.post(
    "/recognize/base64",
    response_model=OCRResponse,
    summary="OCR 识别 (Base64)",
    description="提交 Base64 编码的图片进行 OCR 识别",
)
async def recognize_base64(
    request: Request,
    body: OCRRequestBase64,
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
    """
    OCR recognition endpoint (JSON body with a Base64 image).

    Decodes ``body.image_base64``, runs OCR on the result, and returns the
    recognized text blocks. Any exception raised along the way is caught and
    reported as ``success=False`` with the exception class name as the error
    code and ``str(exception)`` as the message.

    ``request`` is not referenced in the body.
    """
    try:
        # Decode the Base64 payload through the security helper.
        decoded_image = decode_base64_image(body.image_base64)

        # Run OCR with the options carried in the request body.
        ocr_outcome, annotated_b64 = _process_ocr(
            decoded_image,
            pipeline,
            roi=body.roi,
            drop_score=body.drop_score,
            return_annotated_image=body.return_annotated_image,
        )

        # Build the success envelope.
        response_data = _convert_ocr_result_to_response(ocr_outcome, annotated_b64)
        return OCRResponse(success=True, data=response_data)
    except Exception as exc:
        error_detail = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return OCRResponse(success=False, error=error_detail)


@router.post(
    "/express/base64",
    response_model=ExpressResponse,
    summary="快递单解析 (Base64)",
    description="提交 Base64 编码的快递单图片进行解析",
)
async def express_base64(
    request: Request,
    body: OCRRequestBase64,
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
    """
    Express-waybill parsing endpoint (JSON body with a Base64 image).

    Decodes ``body.image_base64``, runs OCR on the result, and returns the
    structured waybill information. Any exception raised along the way is
    caught and reported as ``success=False`` with the exception class name as
    the error code and ``str(exception)`` as the message.

    ``request`` is not referenced in the body.
    """
    try:
        # Decode the Base64 payload through the security helper.
        decoded_image = decode_base64_image(body.image_base64)

        # Run OCR with the options carried in the request body.
        ocr_outcome, annotated_b64 = _process_ocr(
            decoded_image,
            pipeline,
            roi=body.roi,
            drop_score=body.drop_score,
            return_annotated_image=body.return_annotated_image,
        )

        # Build the success envelope.
        response_data = _convert_express_result_to_response(ocr_outcome, annotated_b64)
        return ExpressResponse(success=True, data=response_data)
    except Exception as exc:
        error_detail = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return ExpressResponse(success=False, error=error_detail)