You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

158 lines
4.0 KiB
Python

# -*- coding: utf-8 -*-
"""
依赖注入模块
提供 FastAPI 依赖项
"""
import base64
from typing import Optional
import cv2
import numpy as np
from fastapi import Depends, File, Form, Request, UploadFile
from api.exceptions import InvalidImageError, ModelNotLoadedError
from api.schemas.request import OCRRequestBase64, OCRRequestParams, ROIParams
from api.security import decode_base64_image, validate_uploaded_file
from ocr.pipeline import OCRPipeline, OCRResult
from utils.config import OCRConfig, PipelineConfig, ROIConfig
def get_ocr_pipeline(request: Request) -> OCRPipeline:
    """
    Return the OCR pipeline stored on the application state.

    The pipeline is looked up as ``request.app.state.ocr_pipeline``; a missing
    attribute and an explicit ``None`` are treated the same way.

    Args:
        request: The incoming FastAPI request, used to reach the app state.

    Returns:
        The OCRPipeline instance attached to the application.

    Raises:
        ModelNotLoadedError: If no pipeline has been attached to the app state.
    """
    app_state = request.app.state
    loaded = getattr(app_state, "ocr_pipeline", None)
    if loaded is not None:
        return loaded
    raise ModelNotLoadedError()
async def parse_multipart_image(
    file: UploadFile = File(..., description="图片文件"),
) -> bytes:
    """
    Read a multipart-uploaded image and hand it to the upload validator.

    Args:
        file: The uploaded file from the multipart form.

    Returns:
        Whatever ``validate_uploaded_file`` returns for the read bytes
        (the validated image content).
    """
    raw_bytes = await file.read()
    upload_name = file.filename
    upload_mime = file.content_type
    return validate_uploaded_file(
        content=raw_bytes,
        filename=upload_name,
        content_type=upload_mime,
    )
async def parse_multipart_params(
    lang: str = Form(default="ch", description="识别语言"),
    use_gpu: bool = Form(default=False, description="是否使用 GPU"),
    drop_score: float = Form(default=0.5, ge=0.0, le=1.0, description="置信度阈值"),
    roi_x: Optional[float] = Form(default=None, ge=0.0, le=1.0, description="ROI X"),
    roi_y: Optional[float] = Form(default=None, ge=0.0, le=1.0, description="ROI Y"),
    roi_width: Optional[float] = Form(
        default=None, gt=0.0, le=1.0, description="ROI 宽度"
    ),
    roi_height: Optional[float] = Form(
        default=None, gt=0.0, le=1.0, description="ROI 高度"
    ),
    return_annotated_image: bool = Form(
        default=False, description="是否返回标注图片"
    ),
) -> OCRRequestParams:
    """
    Collect the multipart form fields into an OCR request-parameter object.

    Range constraints (e.g. ``drop_score`` in [0, 1], ROI width/height > 0)
    are enforced by the ``Form`` declarations before this body runs.

    Returns:
        An OCRRequestParams built from the submitted form values.
    """
    form_fields = dict(
        lang=lang,
        use_gpu=use_gpu,
        drop_score=drop_score,
        roi_x=roi_x,
        roi_y=roi_y,
        roi_width=roi_width,
        roi_height=roi_height,
        return_annotated_image=return_annotated_image,
    )
    return OCRRequestParams(**form_fields)
def decode_image_bytes(content: bytes) -> np.ndarray:
    """
    Decode encoded image bytes into an OpenCV image array.

    Args:
        content: Encoded image bytes (any format OpenCV can decode).

    Returns:
        The decoded image, loaded with ``cv2.IMREAD_COLOR`` (3-channel BGR).

    Raises:
        InvalidImageError: If ``content`` is empty, if OpenCV cannot decode it,
            or if decoding raises. The underlying exception, when there is one,
            is chained as ``__cause__``.
    """
    # An empty buffer can never be a valid image; fail fast with a clear error
    # instead of relying on OpenCV's internal assertion message.
    if not content:
        raise InvalidImageError("图片解码失败")

    # Keep the try body to just the calls that can raise, so our own
    # InvalidImageError below is never caught and re-wrapped.
    try:
        buffer = np.frombuffer(content, dtype=np.uint8)
        image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    except Exception as exc:
        # Preserve the original traceback for debugging via exception chaining.
        raise InvalidImageError(f"图片解码失败: {exc}") from exc

    # cv2.imdecode signals "not a decodable image" by returning None.
    if image is None:
        raise InvalidImageError("图片解码失败")
    return image
def encode_image_base64(image: np.ndarray, format: str = ".jpg") -> str:
    """
    Encode an OpenCV image and return it as a Base64 string.

    Args:
        image: The image array to encode.
        format: Target file extension, e.g. ``".jpg"`` or ``".png"``. A bare
            extension without the leading dot (``"png"``) is also accepted.
            (The name shadows the ``format`` builtin but is kept so existing
            keyword callers keep working.)

    Returns:
        The Base64-encoded image bytes as a UTF-8 string, or ``""`` if
        ``cv2.imencode`` reports failure.
    """
    # cv2.imencode selects the encoder from a dotted extension; without the
    # dot it cannot find an encoder, so normalize "png" -> ".png".
    extension = format if format.startswith(".") else f".{format}"
    success, encoded = cv2.imencode(extension, image)
    if not success:
        return ""
    return base64.b64encode(encoded.tobytes()).decode("utf-8")
def build_pipeline_config(roi: Optional[ROIParams]) -> PipelineConfig:
    """
    Build a pipeline configuration from optional ROI request parameters.

    Args:
        roi: ROI parameters from the request, or None when no ROI was given.

    Returns:
        A PipelineConfig whose ROIConfig is disabled when ``roi`` is None,
        otherwise enabled with x/y/width/height copied from ``roi``.
    """
    if roi is not None:
        roi_config = ROIConfig(
            enabled=True,
            x_ratio=roi.x,
            y_ratio=roi.y,
            width_ratio=roi.width,
            height_ratio=roi.height,
        )
    else:
        roi_config = ROIConfig(enabled=False)
    return PipelineConfig(roi=roi_config)