# -*- coding: utf-8 -*-
"""
OCR router module.

Exposes OCR recognition and express-waybill parsing endpoints, each in a
multipart/form-data (file upload) flavour and a JSON (Base64) flavour.
"""

import logging
from typing import Optional

from fastapi import APIRouter, Depends, File, Request, UploadFile

from api.dependencies import (
    build_pipeline_config,
    decode_image_bytes,
    encode_image_base64,
    get_ocr_pipeline,
    parse_multipart_image,
    parse_multipart_params,
)
from api.exceptions import OCRProcessingError
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
from api.schemas.response import (
    ErrorDetail,
    ExpressInfoData,
    ExpressPersonData,
    ExpressResponse,
    ExpressResultData,
    OCRResponse,
    OCRResultData,
    TextBlockData,
)
from api.security import decode_base64_image
from ocr.pipeline import OCRPipeline, OCRResult
from utils.config import (
    OCRConfig,
    PipelineConfig,
    ROIConfig,
    VisualizeConfig,
)
from visualize.draw import OCRVisualizer

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/ocr", tags=["OCR 识别"])


def _convert_ocr_result_to_response(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> OCRResultData:
    """
    Convert an ``OCRResult`` into the ``OCRResultData`` response model.

    Args:
        result: OCR processing result.
        annotated_image_base64: Base64-encoded annotated image, if any.

    Returns:
        The response data model.
    """
    text_blocks = [
        TextBlockData(
            text=block.text,
            confidence=block.confidence,
            bbox=block.bbox,
            bbox_with_offset=block.bbox_with_offset,
            center=list(block.center),
            width=block.width,
            height=block.height,
        )
        for block in result.text_blocks
    ]

    return OCRResultData(
        processing_time_ms=result.processing_time_ms,
        text_count=result.text_count,
        average_confidence=result.average_confidence,
        roi_applied=result.roi_applied,
        roi_rect=list(result.roi_rect) if result.roi_rect else None,
        text_blocks=text_blocks,
        annotated_image_base64=annotated_image_base64,
    )


def _build_express_result_data(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> ExpressResultData:
    """
    Parse waybill fields out of ``result`` and wrap them in the response model.

    Args:
        result: OCR processing result.
        annotated_image_base64: Base64-encoded annotated image, if any.

    Returns:
        The express-parsing response data model.
    """
    express_info = result.parse_express()
    merged_text = result.merge_text()

    return ExpressResultData(
        processing_time_ms=result.processing_time_ms,
        express_info=ExpressInfoData(
            tracking_number=express_info.tracking_number,
            sender=ExpressPersonData(
                name=express_info.sender_name,
                phone=express_info.sender_phone,
                address=express_info.sender_address,
            ),
            receiver=ExpressPersonData(
                name=express_info.receiver_name,
                phone=express_info.receiver_phone,
                address=express_info.receiver_address,
            ),
            courier_company=express_info.courier_company,
            confidence=express_info.confidence,
            extra_fields=express_info.extra_fields,
            raw_text=express_info.raw_text,
        ),
        merged_text=merged_text,
        annotated_image_base64=annotated_image_base64,
    )


def _error_detail(exc: Exception) -> ErrorDetail:
    """
    Log ``exc`` (with traceback) and convert it into the client error payload.

    The payload shape (``code`` = exception class name, ``message`` =
    ``str(exc)``) is unchanged from the previous inline handlers.
    """
    logger.error("OCR request failed: %s", exc, exc_info=exc)
    return ErrorDetail(
        code=type(exc).__name__,
        message=str(exc),
    )


def _process_ocr(
    image_bytes: bytes,
    pipeline: OCRPipeline,
    roi: Optional[ROIParams] = None,
    return_annotated_image: bool = False,
) -> tuple[OCRResult, Optional[str]]:
    """
    Run OCR on raw image bytes.

    Args:
        image_bytes: Raw image bytes (already validated by the caller).
        pipeline: OCR pipeline to run.
        roi: Optional ROI parameters used to build a per-request pipeline config.
        return_annotated_image: Whether to render and return an annotated image.

    Returns:
        ``(result, annotated_image_base64)``. The second element is ``None``
        unless an annotated image was requested and at least one text block
        was recognized.

    Raises:
        OCRProcessingError: If ``pipeline.process`` fails; the original
            exception is chained as ``__cause__``.
    """
    image = decode_image_bytes(image_bytes)
    pipeline_config = build_pipeline_config(roi)

    # Temporarily swap the pipeline's config for this request and always
    # restore it afterwards.
    # NOTE(review): this mutates a private attribute of the pipeline returned
    # by get_ocr_pipeline (presumably shared across requests). It is only
    # race-free because the async handlers call this function synchronously,
    # with no await between swap and restore. Offloading it to a thread pool
    # would require passing the config per call instead — confirm first.
    original_config = pipeline._pipeline_config
    pipeline._pipeline_config = pipeline_config
    try:
        result = pipeline.process(image)
    except Exception as e:
        # Chain the cause so the underlying traceback is preserved.
        raise OCRProcessingError(f"OCR 处理失败: {str(e)}") from e
    finally:
        pipeline._pipeline_config = original_config

    annotated_image_base64 = None
    if return_annotated_image and result.text_count > 0:
        visualizer = OCRVisualizer(VisualizeConfig())
        annotated = visualizer.draw_result(image, result)
        annotated_image_base64 = encode_image_base64(annotated)

    return result, annotated_image_base64


# ============================================================
# multipart/form-data endpoints
# ============================================================


@router.post(
    "/recognize",
    response_model=OCRResponse,
    summary="OCR 识别 (文件上传)",
    description="上传图片文件进行 OCR 识别,支持 jpg/png/bmp/webp 格式",
)
async def recognize_multipart(
    request: Request,
    file: UploadFile = File(..., description="图片文件"),
    params: OCRRequestParams = Depends(parse_multipart_params),
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
    """
    OCR recognition endpoint (multipart/form-data).

    Accepts an uploaded image file and returns the recognized text blocks.
    Any failure is reported as ``success=False`` with an error payload.
    """
    try:
        image_bytes = await parse_multipart_image(file)
        result, annotated_base64 = _process_ocr(
            image_bytes=image_bytes,
            pipeline=pipeline,
            roi=params.get_roi(),
            return_annotated_image=params.return_annotated_image,
        )
        return OCRResponse(
            success=True,
            data=_convert_ocr_result_to_response(result, annotated_base64),
        )
    except Exception as e:
        return OCRResponse(success=False, error=_error_detail(e))


@router.post(
    "/express",
    response_model=ExpressResponse,
    summary="快递单解析 (文件上传)",
    description="上传快递单图片进行 OCR 识别并解析结构化信息",
)
async def express_multipart(
    request: Request,
    file: UploadFile = File(..., description="快递单图片"),
    params: OCRRequestParams = Depends(parse_multipart_params),
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
    """
    Express waybill parsing endpoint (multipart/form-data).

    Accepts an uploaded waybill image, runs OCR and extracts structured
    sender/receiver/tracking information. Any failure is reported as
    ``success=False`` with an error payload.
    """
    try:
        image_bytes = await parse_multipart_image(file)
        result, annotated_base64 = _process_ocr(
            image_bytes=image_bytes,
            pipeline=pipeline,
            roi=params.get_roi(),
            return_annotated_image=params.return_annotated_image,
        )
        return ExpressResponse(
            success=True,
            data=_build_express_result_data(result, annotated_base64),
        )
    except Exception as e:
        return ExpressResponse(success=False, error=_error_detail(e))


# ============================================================
# JSON (Base64) endpoints
# ============================================================


@router.post(
    "/recognize/base64",
    response_model=OCRResponse,
    summary="OCR 识别 (Base64)",
    description="提交 Base64 编码的图片进行 OCR 识别",
)
async def recognize_base64(
    request: Request,
    body: OCRRequestBase64,
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
    """
    OCR recognition endpoint (JSON body with a Base64 image).

    Decodes and validates the Base64 image, then returns the recognized
    text blocks. Any failure is reported as ``success=False``.
    """
    try:
        image_bytes = decode_base64_image(body.image_base64)
        result, annotated_base64 = _process_ocr(
            image_bytes=image_bytes,
            pipeline=pipeline,
            roi=body.roi,
            return_annotated_image=body.return_annotated_image,
        )
        return OCRResponse(
            success=True,
            data=_convert_ocr_result_to_response(result, annotated_base64),
        )
    except Exception as e:
        return OCRResponse(success=False, error=_error_detail(e))


@router.post(
    "/express/base64",
    response_model=ExpressResponse,
    summary="快递单解析 (Base64)",
    description="提交 Base64 编码的快递单图片进行解析",
)
async def express_base64(
    request: Request,
    body: OCRRequestBase64,
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
    """
    Express waybill parsing endpoint (JSON body with a Base64 image).

    Decodes and validates the Base64 waybill image, runs OCR and extracts
    structured information. Any failure is reported as ``success=False``.
    """
    try:
        image_bytes = decode_base64_image(body.image_base64)
        result, annotated_base64 = _process_ocr(
            image_bytes=image_bytes,
            pipeline=pipeline,
            roi=body.roi,
            return_annotated_image=body.return_annotated_image,
        )
        return ExpressResponse(
            success=True,
            data=_build_express_result_data(result, annotated_base64),
        )
    except Exception as e:
        return ExpressResponse(success=False, error=_error_detail(e))