You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

361 lines
11 KiB
Python

# -*- coding: utf-8 -*-
"""
OCR routing module.
Provides the OCR recognition and express waybill parsing endpoints.
"""
from typing import Optional
from fastapi import APIRouter, Depends, File, Request, UploadFile
from api.dependencies import (
build_pipeline_config,
decode_image_bytes,
encode_image_base64,
get_ocr_pipeline,
parse_multipart_image,
parse_multipart_params,
)
from api.exceptions import OCRProcessingError
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
from api.schemas.response import (
ErrorDetail,
ExpressInfoData,
ExpressPersonData,
ExpressResponse,
ExpressResultData,
OCRResponse,
OCRResultData,
TextBlockData,
)
from api.security import decode_base64_image
from ocr.pipeline import OCRPipeline, OCRResult
from utils.config import OCRConfig, PipelineConfig, ROIConfig
from visualize.draw import OCRVisualizer
from utils.config import VisualizeConfig
# Router that groups every endpoint declared in this module under the
# "/ocr" URL prefix.
router = APIRouter(prefix="/ocr", tags=["OCR 识别"])
def _convert_ocr_result_to_response(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> OCRResultData:
    """
    Build the API response payload from a pipeline ``OCRResult``.

    Args:
        result: OCR result produced by the pipeline.
        annotated_image_base64: Base64-encoded annotated image, if one
            was generated.

    Returns:
        The populated ``OCRResultData`` response model.
    """
    # Copy each recognized block field-by-field into its response model.
    blocks = []
    for item in result.text_blocks:
        blocks.append(
            TextBlockData(
                text=item.text,
                confidence=item.confidence,
                bbox=item.bbox,
                bbox_with_offset=item.bbox_with_offset,
                center=list(item.center),
                width=item.width,
                height=item.height,
            )
        )

    # A falsy ROI rectangle is reported as None; otherwise it is
    # returned as a list.
    roi_rect = None
    if result.roi_rect:
        roi_rect = list(result.roi_rect)

    return OCRResultData(
        processing_time_ms=result.processing_time_ms,
        text_count=result.text_count,
        average_confidence=result.average_confidence,
        roi_applied=result.roi_applied,
        roi_rect=roi_rect,
        text_blocks=blocks,
        annotated_image_base64=annotated_image_base64,
    )
def _process_ocr(
    image_bytes: bytes,
    pipeline: OCRPipeline,
    roi: Optional[ROIParams] = None,
    return_annotated_image: bool = False,
) -> tuple[OCRResult, Optional[str]]:
    """
    Run OCR on raw image bytes and optionally render an annotated image.

    Args:
        image_bytes: Raw bytes of the image to process.
        pipeline: OCR pipeline used to process the decoded image.
        roi: Optional ROI parameters used to build the per-call
            pipeline configuration.
        return_annotated_image: Whether to also produce a Base64-encoded
            annotated image.

    Returns:
        A tuple ``(result, annotated_image_base64)``. The second element
        is ``None`` unless an annotated image was requested and the
        result contains at least one text block.

    Raises:
        OCRProcessingError: If ``pipeline.process`` raises. The original
            exception is attached as ``__cause__``. Errors from decoding
            the image or building the configuration happen before the
            ``try`` and propagate unwrapped.
    """
    # Decode the image and build the per-call pipeline configuration.
    image = decode_image_bytes(image_bytes)
    pipeline_config = build_pipeline_config(roi)

    # Temporarily swap in the per-call configuration.
    # NOTE(review): this overwrites a private attribute on the pipeline
    # object handed in by the caller; if that object is shared between
    # requests this is only safe while calls are serialized - confirm.
    original_config = pipeline._pipeline_config
    pipeline._pipeline_config = pipeline_config
    try:
        result = pipeline.process(image)
    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise OCRProcessingError(f"OCR 处理失败: {e!s}") from e
    finally:
        # Always restore the configuration that was in place on entry.
        pipeline._pipeline_config = original_config

    # Render the annotated image only when requested and when there is
    # at least one recognized text block.
    annotated_image_base64 = None
    if return_annotated_image and result.text_count > 0:
        visualizer = OCRVisualizer(VisualizeConfig())
        annotated = visualizer.draw_result(image, result)
        annotated_image_base64 = encode_image_base64(annotated)
    return result, annotated_image_base64
# ============================================================
# multipart/form-data endpoints
# ============================================================
@router.post(
    "/recognize",
    response_model=OCRResponse,
    summary="OCR 识别 (文件上传)",
    description="上传图片文件进行 OCR 识别,支持 jpg/png/bmp/webp 格式",
)
async def recognize_multipart(
    request: Request,
    file: UploadFile = File(..., description="图片文件"),
    params: OCRRequestParams = Depends(parse_multipart_params),
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
    """
    OCR recognition endpoint (multipart/form-data).

    Runs text recognition on an uploaded image file. Any exception
    raised while handling the request is reported in the response body
    (``success=False`` plus an ``ErrorDetail``) instead of propagating.
    """
    try:
        # Validate and read the uploaded file.
        raw_image = await parse_multipart_image(file)

        # Run OCR with the request's ROI / annotation options.
        roi = params.get_roi()
        want_annotated = params.return_annotated_image
        ocr_result, annotated = _process_ocr(
            image_bytes=raw_image,
            pipeline=pipeline,
            roi=roi,
            return_annotated_image=want_annotated,
        )

        # Built inside the try so conversion failures are also reported
        # as an error response.
        payload = _convert_ocr_result_to_response(ocr_result, annotated)
        return OCRResponse(success=True, data=payload)
    except Exception as exc:
        failure = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return OCRResponse(success=False, error=failure)
@router.post(
    "/express",
    response_model=ExpressResponse,
    summary="快递单解析 (文件上传)",
    description="上传快递单图片进行 OCR 识别并解析结构化信息",
)
async def express_multipart(
    request: Request,
    file: UploadFile = File(..., description="快递单图片"),
    params: OCRRequestParams = Depends(parse_multipart_params),
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
    """
    Express waybill parsing endpoint (multipart/form-data).

    Runs OCR on an uploaded waybill image and returns the structured
    information parsed from the result. Any exception raised while
    handling the request is reported in the response body
    (``success=False`` plus an ``ErrorDetail``) instead of propagating.
    """
    try:
        # Validate and read the uploaded file.
        raw_image = await parse_multipart_image(file)

        # Run OCR with the request's ROI / annotation options.
        roi = params.get_roi()
        want_annotated = params.return_annotated_image
        ocr_result, annotated = _process_ocr(
            image_bytes=raw_image,
            pipeline=pipeline,
            roi=roi,
            return_annotated_image=want_annotated,
        )

        # Parse the waybill fields and the merged text from the result.
        info = ocr_result.parse_express()
        full_text = ocr_result.merge_text()

        # Assemble the nested response models from the inside out.
        sender = ExpressPersonData(
            name=info.sender_name,
            phone=info.sender_phone,
            address=info.sender_address,
        )
        receiver = ExpressPersonData(
            name=info.receiver_name,
            phone=info.receiver_phone,
            address=info.receiver_address,
        )
        express_payload = ExpressInfoData(
            tracking_number=info.tracking_number,
            sender=sender,
            receiver=receiver,
            courier_company=info.courier_company,
            confidence=info.confidence,
            extra_fields=info.extra_fields,
            raw_text=info.raw_text,
        )
        result_payload = ExpressResultData(
            processing_time_ms=ocr_result.processing_time_ms,
            express_info=express_payload,
            merged_text=full_text,
            annotated_image_base64=annotated,
        )
        return ExpressResponse(success=True, data=result_payload)
    except Exception as exc:
        failure = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return ExpressResponse(success=False, error=failure)
# ============================================================
# JSON (Base64) endpoints
# ============================================================
@router.post(
    "/recognize/base64",
    response_model=OCRResponse,
    summary="OCR 识别 (Base64)",
    description="提交 Base64 编码的图片进行 OCR 识别",
)
async def recognize_base64(
    request: Request,
    body: OCRRequestBase64,
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
    """
    OCR recognition endpoint (JSON with a Base64 image).

    Runs text recognition on the Base64-encoded image carried in the
    request body. Any exception raised while handling the request is
    reported in the response body (``success=False`` plus an
    ``ErrorDetail``) instead of propagating.
    """
    try:
        # Decode and validate the Base64 image.
        raw_image = decode_base64_image(body.image_base64)

        # Run OCR with the request's ROI / annotation options.
        roi = body.roi
        want_annotated = body.return_annotated_image
        ocr_result, annotated = _process_ocr(
            image_bytes=raw_image,
            pipeline=pipeline,
            roi=roi,
            return_annotated_image=want_annotated,
        )

        # Built inside the try so conversion failures are also reported
        # as an error response.
        payload = _convert_ocr_result_to_response(ocr_result, annotated)
        return OCRResponse(success=True, data=payload)
    except Exception as exc:
        failure = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return OCRResponse(success=False, error=failure)
@router.post(
    "/express/base64",
    response_model=ExpressResponse,
    summary="快递单解析 (Base64)",
    description="提交 Base64 编码的快递单图片进行解析",
)
async def express_base64(
    request: Request,
    body: OCRRequestBase64,
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
    """
    Express waybill parsing endpoint (JSON with a Base64 image).

    Runs OCR on the Base64-encoded waybill image carried in the request
    body and returns the structured information parsed from the result.
    Any exception raised while handling the request is reported in the
    response body (``success=False`` plus an ``ErrorDetail``) instead of
    propagating.
    """
    try:
        # Decode and validate the Base64 image.
        raw_image = decode_base64_image(body.image_base64)

        # Run OCR with the request's ROI / annotation options.
        roi = body.roi
        want_annotated = body.return_annotated_image
        ocr_result, annotated = _process_ocr(
            image_bytes=raw_image,
            pipeline=pipeline,
            roi=roi,
            return_annotated_image=want_annotated,
        )

        # Parse the waybill fields and the merged text from the result.
        info = ocr_result.parse_express()
        full_text = ocr_result.merge_text()

        # Assemble the nested response models from the inside out.
        sender = ExpressPersonData(
            name=info.sender_name,
            phone=info.sender_phone,
            address=info.sender_address,
        )
        receiver = ExpressPersonData(
            name=info.receiver_name,
            phone=info.receiver_phone,
            address=info.receiver_address,
        )
        express_payload = ExpressInfoData(
            tracking_number=info.tracking_number,
            sender=sender,
            receiver=receiver,
            courier_company=info.courier_company,
            confidence=info.confidence,
            extra_fields=info.extra_fields,
            raw_text=info.raw_text,
        )
        result_payload = ExpressResultData(
            processing_time_ms=ocr_result.processing_time_ms,
            express_info=express_payload,
            merged_text=full_text,
            annotated_image_base64=annotated,
        )
        return ExpressResponse(success=True, data=result_payload)
    except Exception as exc:
        failure = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return ExpressResponse(success=False, error=failure)