|
|
# -*- coding: utf-8 -*-
|
|
|
"""
|
|
|
OCR 路由模块
|
|
|
提供 OCR 识别和快递单解析端点
|
|
|
"""
|
|
|
|
|
|
import logging
from typing import Optional

from fastapi import APIRouter, Depends, File, Request, UploadFile

from api.dependencies import (
    build_pipeline_config,
    decode_image_bytes,
    encode_image_base64,
    get_ocr_pipeline,
    parse_multipart_image,
    parse_multipart_params,
)
from api.exceptions import OCRProcessingError
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
from api.schemas.response import (
    ErrorDetail,
    ExpressInfoData,
    ExpressPersonData,
    ExpressResponse,
    ExpressResultData,
    OCRResponse,
    OCRResultData,
    TextBlockData,
)
from api.security import decode_base64_image
from ocr.pipeline import OCRPipeline, OCRResult
from utils.config import OCRConfig, PipelineConfig, ROIConfig
from utils.config import VisualizeConfig
from visualize.draw import OCRVisualizer
|
|
|
|
|
|
# Router for this module: every endpoint below is mounted under the "/ocr"
# prefix and grouped under the "OCR 识别" tag in the OpenAPI docs.
router = APIRouter(prefix="/ocr", tags=["OCR 识别"])


# Module-level OCRVisualizer instance. It starts as None, is created lazily by
# _get_visualizer() on first use, and is then reused so it is not rebuilt
# for every request.
_visualizer: Optional[OCRVisualizer] = None
|
|
|
|
|
|
|
|
|
def _get_visualizer() -> OCRVisualizer:
    """Return the shared OCRVisualizer, building it on first use.

    The instance is stored in the module-level ``_visualizer`` variable and
    constructed with a default ``VisualizeConfig()``. Later calls return that
    same object instead of creating a new one.

    Returns:
        The module-wide OCRVisualizer instance.
    """
    global _visualizer

    # Fast path: already built by an earlier call.
    if _visualizer is not None:
        return _visualizer

    _visualizer = OCRVisualizer(VisualizeConfig())
    return _visualizer
|
|
|
|
|
|
|
|
|
def _convert_ocr_result_to_response(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> OCRResultData:
    """Map a pipeline ``OCRResult`` onto the ``OCRResultData`` response model.

    Args:
        result: Result object produced by the OCR pipeline.
        annotated_image_base64: Base64 string of the annotated image, or
            ``None`` when no annotated image is returned.

    Returns:
        An ``OCRResultData`` carrying the summary fields of ``result`` plus one
        ``TextBlockData`` per entry of ``result.text_blocks``.
    """

    def _as_block_data(tb):
        # One response item per recognized text block; ``center`` is copied
        # into a list for the response model.
        return TextBlockData(
            text=tb.text,
            confidence=tb.confidence,
            bbox=tb.bbox,
            bbox_with_offset=tb.bbox_with_offset,
            center=list(tb.center),
            width=tb.width,
            height=tb.height,
        )

    blocks_out = list(map(_as_block_data, result.text_blocks))

    return OCRResultData(
        processing_time_ms=result.processing_time_ms,
        text_count=result.text_count,
        average_confidence=result.average_confidence,
        roi_applied=result.roi_applied,
        # A falsy roi_rect (None / empty) is reported as None.
        roi_rect=None if not result.roi_rect else list(result.roi_rect),
        text_blocks=blocks_out,
        annotated_image_base64=annotated_image_base64,
    )
|
|
|
|
|
|
|
|
|
def _convert_express_result_to_response(
    result: OCRResult,
    annotated_image_base64: Optional[str] = None,
) -> ExpressResultData:
    """Map a pipeline ``OCRResult`` onto the ``ExpressResultData`` response model.

    The structured waybill fields come from ``result.parse_express()``. The
    text placed in ``merged_text`` comes from ``result.merge_text()``.

    Args:
        result: Result object produced by the OCR pipeline.
        annotated_image_base64: Base64 string of the annotated image, or
            ``None`` when no annotated image is returned.

    Returns:
        An ``ExpressResultData`` with the parsed express info, the merged
        text and the optional annotated image.
    """
    parsed = result.parse_express()
    full_text = result.merge_text()

    sender = ExpressPersonData(
        name=parsed.sender_name,
        phone=parsed.sender_phone,
        address=parsed.sender_address,
    )
    receiver = ExpressPersonData(
        name=parsed.receiver_name,
        phone=parsed.receiver_phone,
        address=parsed.receiver_address,
    )
    info = ExpressInfoData(
        tracking_number=parsed.tracking_number,
        sender=sender,
        receiver=receiver,
        courier_company=parsed.courier_company,
        confidence=parsed.confidence,
        extra_fields=parsed.extra_fields,
        raw_text=parsed.raw_text,
    )

    return ExpressResultData(
        processing_time_ms=result.processing_time_ms,
        express_info=info,
        merged_text=full_text,
        annotated_image_base64=annotated_image_base64,
    )
|
|
|
|
|
|
|
|
|
def _process_ocr(
    image_bytes: bytes,
    pipeline: OCRPipeline,
    roi: Optional[ROIParams] = None,
    drop_score: Optional[float] = None,
    return_annotated_image: bool = False,
) -> tuple[OCRResult, Optional[str]]:
    """Decode an image, run the OCR pipeline on it and optionally annotate it.

    A new pipeline config is built from ``roi`` on every call and passed to
    ``pipeline.process`` as an argument. The shared ``pipeline`` object is
    not reconfigured per request.

    Args:
        image_bytes: Image file contents as bytes.
        pipeline: OCR pipeline that performs recognition.
        roi: Optional region-of-interest parameters for this request.
        drop_score: Confidence threshold; results scoring below it are
            filtered out by the pipeline.
        return_annotated_image: Whether to render and return an annotated
            copy of the image.

    Returns:
        ``(result, annotated_image_base64)``. The second element is ``None``
        unless ``return_annotated_image`` is true AND at least one text block
        was recognized (``result.text_count > 0``).

    Raises:
        OCRProcessingError: If ``pipeline.process`` raises. The original
            exception is chained as ``__cause__``.
        Errors raised by ``decode_image_bytes``, ``build_pipeline_config`` or
        the visualizer propagate unchanged.
    """
    image = decode_image_bytes(image_bytes)

    # Per-request config object, passed explicitly to the pipeline below.
    pipeline_config = build_pipeline_config(roi)

    try:
        result = pipeline.process(
            image=image,
            pipeline_config=pipeline_config,
            drop_score=drop_score,
        )
    except Exception as e:
        # Chain explicitly so the underlying traceback is preserved as
        # __cause__ for logging and debugging.
        raise OCRProcessingError(f"OCR 处理失败: {str(e)}") from e

    annotated_image_base64 = None
    # Skip drawing when nothing was recognized; callers then receive None.
    if return_annotated_image and result.text_count > 0:
        annotated = _get_visualizer().draw_result(image, result)
        annotated_image_base64 = encode_image_base64(annotated)

    return result, annotated_image_base64
|
|
|
|
|
|
|
|
|
# ============================================================
|
|
|
# multipart/form-data 端点
|
|
|
# ============================================================
|
|
|
|
|
|
|
|
|
@router.post(
    "/recognize",
    response_model=OCRResponse,
    summary="OCR 识别 (文件上传)",
    description="上传图片文件进行 OCR 识别,支持 jpg/png/bmp/webp 格式",
)
async def recognize_multipart(
    request: Request,
    file: UploadFile = File(..., description="图片文件"),
    params: OCRRequestParams = Depends(parse_multipart_params),
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
    """OCR recognition endpoint (multipart/form-data).

    Reads and validates the uploaded image file, runs text recognition on it
    and returns the recognized text blocks.

    ``request`` is not used in the body.

    Any exception raised while reading the file, running OCR or building the
    response is logged with its traceback. It is then returned to the client
    as ``OCRResponse(success=False)``, with an ``ErrorDetail`` whose ``code``
    is the exception class name and whose ``message`` is ``str(exc)``.

    NOTE(review): ``_process_ocr`` is synchronous and is called directly from
    this ``async def`` handler, so OCR work runs on the event loop thread.
    Offloading it (e.g. ``run_in_threadpool``) requires first confirming that
    the OCR pipeline and visualizer are safe to call concurrently.
    """
    try:
        image_bytes = await parse_multipart_image(file)

        result, annotated_base64 = _process_ocr(
            image_bytes=image_bytes,
            pipeline=pipeline,
            roi=params.get_roi(),
            drop_score=params.drop_score,
            return_annotated_image=params.return_annotated_image,
        )

        return OCRResponse(
            success=True,
            data=_convert_ocr_result_to_response(result, annotated_base64),
        )

    except Exception as e:
        # Previously the error was swallowed into the response with no
        # server-side trace; keep the response unchanged but log the cause.
        logging.getLogger(__name__).exception(
            "OCR recognize request failed (multipart)"
        )
        return OCRResponse(
            success=False,
            error=ErrorDetail(
                code=type(e).__name__,
                message=str(e),
            ),
        )
|
|
|
|
|
|
|
|
|
@router.post(
    "/express",
    response_model=ExpressResponse,
    summary="快递单解析 (文件上传)",
    description="上传快递单图片进行 OCR 识别并解析结构化信息",
)
async def express_multipart(
    request: Request,
    file: UploadFile = File(..., description="快递单图片"),
    params: OCRRequestParams = Depends(parse_multipart_params),
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
    """Express-waybill parsing endpoint (multipart/form-data).

    Reads and validates the uploaded waybill image, runs OCR on it and
    returns the structured express information extracted from the result.

    ``request`` is not used in the body.

    Any exception raised while reading the file, running OCR or building the
    response is logged with its traceback. It is then returned to the client
    as ``ExpressResponse(success=False)``, with an ``ErrorDetail`` whose
    ``code`` is the exception class name and whose ``message`` is
    ``str(exc)``.

    NOTE(review): ``_process_ocr`` is synchronous and is called directly from
    this ``async def`` handler, so OCR work runs on the event loop thread.
    Offloading it requires first confirming the pipeline is safe to call
    concurrently.
    """
    try:
        image_bytes = await parse_multipart_image(file)

        result, annotated_base64 = _process_ocr(
            image_bytes=image_bytes,
            pipeline=pipeline,
            roi=params.get_roi(),
            drop_score=params.drop_score,
            return_annotated_image=params.return_annotated_image,
        )

        return ExpressResponse(
            success=True,
            data=_convert_express_result_to_response(result, annotated_base64),
        )

    except Exception as e:
        # Keep the error envelope unchanged but record the cause server-side.
        logging.getLogger(__name__).exception(
            "Express parse request failed (multipart)"
        )
        return ExpressResponse(
            success=False,
            error=ErrorDetail(
                code=type(e).__name__,
                message=str(e),
            ),
        )
|
|
|
|
|
|
|
|
|
# ============================================================
|
|
|
# JSON (Base64) 端点
|
|
|
# ============================================================
|
|
|
|
|
|
|
|
|
@router.post(
    "/recognize/base64",
    response_model=OCRResponse,
    summary="OCR 识别 (Base64)",
    description="提交 Base64 编码的图片进行 OCR 识别",
)
async def recognize_base64(
    request: Request,
    body: OCRRequestBase64,
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> OCRResponse:
    """OCR recognition endpoint (JSON body with a Base64 image).

    Decodes and validates ``body.image_base64``, runs text recognition on it
    and returns the recognized text blocks.

    ``request`` is not used in the body.

    Any exception raised while decoding, running OCR or building the response
    is logged with its traceback. It is then returned to the client as
    ``OCRResponse(success=False)``, with an ``ErrorDetail`` whose ``code`` is
    the exception class name and whose ``message`` is ``str(exc)``.

    NOTE(review): ``_process_ocr`` is synchronous and is called directly from
    this ``async def`` handler, so OCR work runs on the event loop thread.
    Offloading it requires first confirming the pipeline is safe to call
    concurrently.
    """
    try:
        image_bytes = decode_base64_image(body.image_base64)

        result, annotated_base64 = _process_ocr(
            image_bytes=image_bytes,
            pipeline=pipeline,
            roi=body.roi,
            drop_score=body.drop_score,
            return_annotated_image=body.return_annotated_image,
        )

        return OCRResponse(
            success=True,
            data=_convert_ocr_result_to_response(result, annotated_base64),
        )

    except Exception as e:
        # Keep the error envelope unchanged but record the cause server-side.
        logging.getLogger(__name__).exception(
            "OCR recognize request failed (base64)"
        )
        return OCRResponse(
            success=False,
            error=ErrorDetail(
                code=type(e).__name__,
                message=str(e),
            ),
        )
|
|
|
|
|
|
|
|
|
@router.post(
    "/express/base64",
    response_model=ExpressResponse,
    summary="快递单解析 (Base64)",
    description="提交 Base64 编码的快递单图片进行解析",
)
async def express_base64(
    request: Request,
    body: OCRRequestBase64,
    pipeline: OCRPipeline = Depends(get_ocr_pipeline),
) -> ExpressResponse:
    """Express-waybill parsing endpoint (JSON body with a Base64 image).

    Decodes and validates ``body.image_base64``, runs OCR on it and returns
    the structured express information extracted from the result.

    ``request`` is not used in the body.

    Any exception raised while decoding, running OCR or building the response
    is logged with its traceback. It is then returned to the client as
    ``ExpressResponse(success=False)``, with an ``ErrorDetail`` whose ``code``
    is the exception class name and whose ``message`` is ``str(exc)``.

    NOTE(review): ``_process_ocr`` is synchronous and is called directly from
    this ``async def`` handler, so OCR work runs on the event loop thread.
    Offloading it requires first confirming the pipeline is safe to call
    concurrently.
    """
    try:
        image_bytes = decode_base64_image(body.image_base64)

        result, annotated_base64 = _process_ocr(
            image_bytes=image_bytes,
            pipeline=pipeline,
            roi=body.roi,
            drop_score=body.drop_score,
            return_annotated_image=body.return_annotated_image,
        )

        return ExpressResponse(
            success=True,
            data=_convert_express_result_to_response(result, annotated_base64),
        )

    except Exception as e:
        # Keep the error envelope unchanged but record the cause server-side.
        logging.getLogger(__name__).exception(
            "Express parse request failed (base64)"
        )
        return ExpressResponse(
            success=False,
            error=ErrorDetail(
                code=type(e).__name__,
                message=str(e),
            ),
        )
|