You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
158 lines
4.0 KiB
Python
158 lines
4.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
依赖注入模块
|
|
提供 FastAPI 依赖项
|
|
"""
|
|
|
|
import base64
|
|
from typing import Optional
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from fastapi import Depends, File, Form, Request, UploadFile
|
|
|
|
from api.exceptions import InvalidImageError, ModelNotLoadedError
|
|
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
|
|
from api.security import decode_base64_image, validate_uploaded_file
|
|
from ocr.pipeline import OCRPipeline, OCRResult
|
|
from utils.config import OCRConfig, PipelineConfig, ROIConfig
|
|
|
|
|
|
def get_ocr_pipeline(request: Request) -> OCRPipeline:
    """
    FastAPI dependency that returns the shared OCR pipeline.

    The pipeline is looked up as the ``ocr_pipeline`` attribute on
    ``request.app.state``.

    Args:
        request: The incoming FastAPI request.

    Returns:
        The ``OCRPipeline`` instance stored on the application state.

    Raises:
        ModelNotLoadedError: If the attribute is missing or set to ``None``.
    """
    app_state = request.app.state
    ocr_pipeline = getattr(app_state, "ocr_pipeline", None)
    if ocr_pipeline is not None:
        return ocr_pipeline
    # Missing and explicitly-None are both treated as "model not loaded".
    raise ModelNotLoadedError()
|
|
|
|
|
|
async def parse_multipart_image(
    file: UploadFile = File(..., description="图片文件"),
) -> bytes:
    """
    FastAPI dependency that reads an uploaded image file and validates it.

    Args:
        file: The uploaded multipart file field.

    Returns:
        The validated image bytes produced by ``validate_uploaded_file``.
    """
    raw_bytes = await file.read()
    upload_name = file.filename
    upload_mime = file.content_type
    return validate_uploaded_file(
        content=raw_bytes,
        filename=upload_name,
        content_type=upload_mime,
    )
|
|
|
|
|
|
async def parse_multipart_params(
    lang: str = Form(default="ch", description="识别语言"),
    use_gpu: bool = Form(default=False, description="是否使用 GPU"),
    drop_score: float = Form(default=0.5, ge=0.0, le=1.0, description="置信度阈值"),
    roi_x: Optional[float] = Form(default=None, ge=0.0, le=1.0, description="ROI X"),
    roi_y: Optional[float] = Form(default=None, ge=0.0, le=1.0, description="ROI Y"),
    roi_width: Optional[float] = Form(
        default=None, gt=0.0, le=1.0, description="ROI 宽度"
    ),
    roi_height: Optional[float] = Form(
        default=None, gt=0.0, le=1.0, description="ROI 高度"
    ),
    return_annotated_image: bool = Form(
        default=False, description="是否返回标注图片"
    ),
) -> OCRRequestParams:
    """
    FastAPI dependency that collects the multipart form fields into one object.

    Range checks on the numeric fields are declared on the ``Form`` defaults
    in the signature above.

    Returns:
        An ``OCRRequestParams`` built from the submitted form fields.
    """
    # Gather the form values first, then hand them to the request model in one go.
    form_fields = dict(
        lang=lang,
        use_gpu=use_gpu,
        drop_score=drop_score,
        roi_x=roi_x,
        roi_y=roi_y,
        roi_width=roi_width,
        roi_height=roi_height,
        return_annotated_image=return_annotated_image,
    )
    return OCRRequestParams(**form_fields)
|
|
|
|
|
|
def decode_image_bytes(content: bytes) -> np.ndarray:
    """
    Decode encoded image bytes into an OpenCV image.

    Args:
        content: Encoded image bytes in any format ``cv2.imdecode`` supports.

    Returns:
        The decoded image as a BGR ``np.ndarray`` (read with ``cv2.IMREAD_COLOR``).

    Raises:
        InvalidImageError: If ``content`` is empty, not bytes-like, or cannot
            be decoded as an image.
    """
    try:
        buffer = np.frombuffer(content, np.uint8)
        # cv2.imdecode raises an opaque assertion error on an empty buffer;
        # treat empty input like any other undecodable payload instead.
        image = cv2.imdecode(buffer, cv2.IMREAD_COLOR) if buffer.size else None
    except Exception as e:
        # Normalise any numpy/OpenCV failure into the API's domain error,
        # keeping the original exception as __cause__ for debugging.
        raise InvalidImageError(f"图片解码失败: {e}") from e

    # imdecode reports "not a decodable image" by returning None, not raising.
    if image is None:
        raise InvalidImageError("图片解码失败")
    return image
|
|
|
|
|
|
def encode_image_base64(image: np.ndarray, format: str = ".jpg") -> str:
    """
    Encode an image as a Base64 string.

    Args:
        image: OpenCV image to encode.
        format: Target file extension, e.g. ``".jpg"`` or ``".png"``. A value
            without the leading dot (e.g. ``"png"``) is also accepted.

    Returns:
        The Base64-encoded image as a ``str``, or ``""`` if ``cv2.imencode``
        reports that encoding failed.
    """
    # cv2.imencode requires the extension to include a leading period;
    # prepend it rather than failing on a bare "png"/"jpg".
    ext = format if format.startswith(".") else f".{format}"
    ok, buffer = cv2.imencode(ext, image)
    if not ok:
        return ""
    return base64.b64encode(buffer.tobytes()).decode("utf-8")
|
|
|
|
|
|
def build_pipeline_config(roi: Optional[ROIParams]) -> PipelineConfig:
    """
    Build the pipeline configuration from optional ROI request parameters.

    Args:
        roi: ROI parameters from the request, or ``None`` when no ROI was given.

    Returns:
        A ``PipelineConfig`` whose ROI is disabled when ``roi`` is ``None``.
        Otherwise the ROI is enabled with ``roi.x``/``roi.y``/``roi.width``/
        ``roi.height`` mapped onto the ``*_ratio`` fields.
    """
    if roi is not None:
        roi_config = ROIConfig(
            enabled=True,
            x_ratio=roi.x,
            y_ratio=roi.y,
            width_ratio=roi.width,
            height_ratio=roi.height,
        )
    else:
        roi_config = ROIConfig(enabled=False)
    return PipelineConfig(roi=roi_config)
|