You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

371 lines
10 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
OCR 路由模块
提供 OCR 识别和快递单解析端点
"""
from typing import Optional
from fastapi import APIRouter, Depends, File, Request, UploadFile
from api.dependencies import (
build_pipeline_config,
decode_image_bytes,
encode_image_base64,
get_ocr_pipeline,
parse_multipart_image,
parse_multipart_params,
)
from api.exceptions import OCRProcessingError
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
from api.schemas.response import (
ErrorDetail,
ExpressInfoData,
ExpressPersonData,
ExpressResponse,
ExpressResultData,
OCRResponse,
OCRResultData,
TextBlockData,
)
from api.security import decode_base64_image
from ocr.pipeline import OCRPipeline, OCRResult
from utils.config import OCRConfig, PipelineConfig, ROIConfig
from visualize.draw import OCRVisualizer
from utils.config import VisualizeConfig
# Router that collects the OCR endpoints defined in this module under the "/ocr" prefix.
router = APIRouter(prefix="/ocr", tags=["OCR 识别"])
# Module-level visualizer instance; starts as None and is filled in lazily by
# _get_visualizer() so the visualizer is not re-created on every call.
_visualizer: Optional[OCRVisualizer] = None
def _get_visualizer() -> OCRVisualizer:
    """Return the shared module-level ``OCRVisualizer``, creating it on first use.

    The instance is built from a default ``VisualizeConfig()`` the first time
    this function is called and cached in the module global ``_visualizer``;
    later calls return that same object.

    Returns:
        The cached ``OCRVisualizer`` instance.
    """
    global _visualizer

    # Fast path: the instance already exists, hand it back unchanged.
    if _visualizer is not None:
        return _visualizer

    _visualizer = OCRVisualizer(VisualizeConfig())
    return _visualizer
def _to_text_block_data(block) -> TextBlockData:
    """Copy one recognised text block into its ``TextBlockData`` response form."""
    return TextBlockData(
        text=block.text,
        confidence=block.confidence,
        bbox=block.bbox,
        bbox_with_offset=block.bbox_with_offset,
        # The response model receives the center as a list.
        center=list(block.center),
        width=block.width,
        height=block.height,
    )


def _convert_ocr_result_to_response(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> OCRResultData:
    """Map an ``OCRResult`` onto the ``OCRResultData`` response model.

    Args:
        result: OCR processing result to convert.
        annotated_image_base64: Base64-encoded annotated image, or ``None``
            when no annotated image is attached.

    Returns:
        An ``OCRResultData`` carrying the timing, counts, ROI information and
        one ``TextBlockData`` per entry of ``result.text_blocks``.
    """
    blocks = [_to_text_block_data(item) for item in result.text_blocks]

    # A falsy roi_rect (e.g. None) is reported as None; otherwise it is
    # converted to a list for the response.
    rect = result.roi_rect
    roi_rect = list(rect) if rect else None

    return OCRResultData(
        processing_time_ms=result.processing_time_ms,
        text_count=result.text_count,
        average_confidence=result.average_confidence,
        roi_applied=result.roi_applied,
        roi_rect=roi_rect,
        text_blocks=blocks,
        annotated_image_base64=annotated_image_base64,
    )
def _convert_express_result_to_response(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> ExpressResultData:
    """Map an ``OCRResult`` onto the express-waybill response model.

    Calls ``result.parse_express()`` for the structured waybill fields and
    ``result.merge_text()`` for the merged text, then packs both into an
    ``ExpressResultData``.

    Args:
        result: OCR processing result to convert.
        annotated_image_base64: Base64-encoded annotated image, or ``None``
            when no annotated image is attached.

    Returns:
        An ``ExpressResultData`` with the parsed sender/receiver details,
        tracking information, merged text and processing time.
    """
    parsed = result.parse_express()
    full_text = result.merge_text()

    sender = ExpressPersonData(
        name=parsed.sender_name,
        phone=parsed.sender_phone,
        address=parsed.sender_address,
    )
    receiver = ExpressPersonData(
        name=parsed.receiver_name,
        phone=parsed.receiver_phone,
        address=parsed.receiver_address,
    )
    info = ExpressInfoData(
        tracking_number=parsed.tracking_number,
        sender=sender,
        receiver=receiver,
        courier_company=parsed.courier_company,
        confidence=parsed.confidence,
        extra_fields=parsed.extra_fields,
        raw_text=parsed.raw_text,
    )

    return ExpressResultData(
        processing_time_ms=result.processing_time_ms,
        express_info=info,
        merged_text=full_text,
        annotated_image_base64=annotated_image_base64,
    )
def _process_ocr(
    image_bytes: bytes,
    pipeline: OCRPipeline,
    roi: Optional[ROIParams] = None,
    drop_score: Optional[float] = None,
    return_annotated_image: bool = False,
) -> tuple[OCRResult, Optional[str]]:
    """Decode an image, run the OCR pipeline on it and optionally annotate it.

    A fresh pipeline config is built for every call and handed to
    ``pipeline.process`` as an argument instead of being stored on the
    pipeline. The original author describes this as thread-safe; that cannot
    be confirmed from this file alone.

    Args:
        image_bytes: Raw bytes of the image.
        pipeline: OCR pipeline used to process the decoded image.
        roi: Optional ROI parameters, forwarded to ``build_pipeline_config``.
        drop_score: Optional confidence threshold forwarded to
            ``pipeline.process`` (per the original docs, results scoring
            below it are filtered out).
        return_annotated_image: Whether to also produce an annotated image.

    Returns:
        A ``(result, annotated_image_base64)`` tuple. The second item is
        ``None`` unless ``return_annotated_image`` is true and the result
        contains at least one text block.

    Raises:
        OCRProcessingError: If ``pipeline.process`` raises any ``Exception``;
            the original exception is attached as ``__cause__``. Errors from
            image decoding, config building or annotation are not wrapped and
            propagate unchanged.
    """
    image = decode_image_bytes(image_bytes)
    # Per-request config: passed to the pipeline rather than set on it.
    pipeline_config = build_pipeline_config(roi)

    try:
        result = pipeline.process(
            image=image,
            pipeline_config=pipeline_config,
            drop_score=drop_score,
        )
    except Exception as e:
        # Chain explicitly so the root cause is preserved as __cause__ and the
        # traceback shows the wrapping as deliberate.
        raise OCRProcessingError(f"OCR 处理失败: {e}") from e

    annotated_image_base64 = None
    # Only draw when requested and there is at least one text block to draw.
    if return_annotated_image and result.text_count > 0:
        annotated = _get_visualizer().draw_result(image, result)
        annotated_image_base64 = encode_image_base64(annotated)

    return result, annotated_image_base64
# ============================================================
# multipart/form-data 端点
# ============================================================
@router.post(
    "/recognize",
    response_model=OCRResponse,
    summary="OCR 识别 (文件上传)",
    description="上传图片文件进行 OCR 识别,支持 jpg/png/bmp/webp 格式",
)
async def recognize_multipart(
    request: Request,
    file: UploadFile = File(..., description="图片文件"),
    params: OCRRequestParams = Depends(parse_multipart_params),
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
    """OCR recognition endpoint (multipart/form-data).

    Reads the uploaded file through ``parse_multipart_image``, runs
    ``_process_ocr`` on it and returns the recognised text blocks.

    Args:
        request: Incoming request; not used in the function body.
        file: Uploaded image file.
        params: OCR parameters parsed from the multipart form.
        pipeline: OCR pipeline provided by dependency injection.

    Returns:
        ``OCRResponse`` with ``success=True`` and the result data, or with
        ``success=False`` and an ``ErrorDetail`` (exception class name and
        message) if any exception is raised along the way.

    NOTE(review): ``_process_ocr`` is a synchronous call made from this
    ``async`` handler — confirm whether it should be offloaded to a worker.
    """
    try:
        payload = await parse_multipart_image(file)
        ocr_result, annotated = _process_ocr(
            payload,
            pipeline,
            roi=params.get_roi(),
            drop_score=params.drop_score,
            return_annotated_image=params.return_annotated_image,
        )
        # Built inside the try so conversion failures are also reported
        # through the error response below.
        data = _convert_ocr_result_to_response(ocr_result, annotated)
        return OCRResponse(success=True, data=data)
    except Exception as exc:
        failure = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return OCRResponse(success=False, error=failure)
@router.post(
    "/express",
    response_model=ExpressResponse,
    summary="快递单解析 (文件上传)",
    description="上传快递单图片进行 OCR 识别并解析结构化信息",
)
async def express_multipart(
    request: Request,
    file: UploadFile = File(..., description="快递单图片"),
    params: OCRRequestParams = Depends(parse_multipart_params),
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
    """Express-waybill parsing endpoint (multipart/form-data).

    Reads the uploaded waybill image through ``parse_multipart_image``, runs
    ``_process_ocr`` on it and returns the structured waybill information.

    Args:
        request: Incoming request; not used in the function body.
        file: Uploaded waybill image.
        params: OCR parameters parsed from the multipart form.
        pipeline: OCR pipeline provided by dependency injection.

    Returns:
        ``ExpressResponse`` with ``success=True`` and the parsed data, or with
        ``success=False`` and an ``ErrorDetail`` (exception class name and
        message) if any exception is raised along the way.

    NOTE(review): ``_process_ocr`` is a synchronous call made from this
    ``async`` handler — confirm whether it should be offloaded to a worker.
    """
    try:
        payload = await parse_multipart_image(file)
        ocr_result, annotated = _process_ocr(
            payload,
            pipeline,
            roi=params.get_roi(),
            drop_score=params.drop_score,
            return_annotated_image=params.return_annotated_image,
        )
        # Built inside the try so parsing/conversion failures are also
        # reported through the error response below.
        data = _convert_express_result_to_response(ocr_result, annotated)
        return ExpressResponse(success=True, data=data)
    except Exception as exc:
        failure = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return ExpressResponse(success=False, error=failure)
# ============================================================
# JSON (Base64) 端点
# ============================================================
@router.post(
    "/recognize/base64",
    response_model=OCRResponse,
    summary="OCR 识别 (Base64)",
    description="提交 Base64 编码的图片进行 OCR 识别",
)
async def recognize_base64(
    request: Request,
    body: OCRRequestBase64,
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
    """OCR recognition endpoint (JSON body with a Base64 image).

    Decodes ``body.image_base64`` through ``decode_base64_image``, runs
    ``_process_ocr`` on the bytes and returns the recognised text blocks.

    Args:
        request: Incoming request; not used in the function body.
        body: JSON request carrying the Base64 image and OCR options.
        pipeline: OCR pipeline provided by dependency injection.

    Returns:
        ``OCRResponse`` with ``success=True`` and the result data, or with
        ``success=False`` and an ``ErrorDetail`` (exception class name and
        message) if any exception is raised along the way.

    NOTE(review): ``_process_ocr`` is a synchronous call made from this
    ``async`` handler — confirm whether it should be offloaded to a worker.
    """
    try:
        raw_image = decode_base64_image(body.image_base64)
        ocr_result, annotated = _process_ocr(
            raw_image,
            pipeline,
            roi=body.roi,
            drop_score=body.drop_score,
            return_annotated_image=body.return_annotated_image,
        )
        # Built inside the try so conversion failures are also reported
        # through the error response below.
        data = _convert_ocr_result_to_response(ocr_result, annotated)
        return OCRResponse(success=True, data=data)
    except Exception as exc:
        failure = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return OCRResponse(success=False, error=failure)
@router.post(
    "/express/base64",
    response_model=ExpressResponse,
    summary="快递单解析 (Base64)",
    description="提交 Base64 编码的快递单图片进行解析",
)
async def express_base64(
    request: Request,
    body: OCRRequestBase64,
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
    """Express-waybill parsing endpoint (JSON body with a Base64 image).

    Decodes ``body.image_base64`` through ``decode_base64_image``, runs
    ``_process_ocr`` on the bytes and returns the structured waybill
    information.

    Args:
        request: Incoming request; not used in the function body.
        body: JSON request carrying the Base64 image and OCR options.
        pipeline: OCR pipeline provided by dependency injection.

    Returns:
        ``ExpressResponse`` with ``success=True`` and the parsed data, or with
        ``success=False`` and an ``ErrorDetail`` (exception class name and
        message) if any exception is raised along the way.

    NOTE(review): ``_process_ocr`` is a synchronous call made from this
    ``async`` handler — confirm whether it should be offloaded to a worker.
    """
    try:
        raw_image = decode_base64_image(body.image_base64)
        ocr_result, annotated = _process_ocr(
            raw_image,
            pipeline,
            roi=body.roi,
            drop_score=body.drop_score,
            return_annotated_image=body.return_annotated_image,
        )
        # Built inside the try so parsing/conversion failures are also
        # reported through the error response below.
        data = _convert_express_result_to_response(ocr_result, annotated)
        return ExpressResponse(success=True, data=data)
    except Exception as exc:
        failure = ErrorDetail(code=type(exc).__name__, message=str(exc))
        return ExpressResponse(success=False, error=failure)