|
|
# -*- coding: utf-8 -*-
|
|
|
"""
|
|
|
OCR 引擎模块
|
|
|
封装 PaddleOCR,提供统一的 OCR 接口
|
|
|
"""
|
|
|
|
|
|
import os
import warnings
from pathlib import Path
|
|
|
|
|
|
# Environment setup that must happen before `paddleocr` is imported further
# down in this module (presumably the library reads these variables at import
# time).
# Purpose: work around Windows user profiles whose path contains non-ASCII
# (e.g. Chinese) characters by keeping PaddleOCR's model cache inside the
# project instead of the user's home directory.
# NOTE(review): confirm the installed PaddleOCR version actually honors
# PADDLEOCR_HOME; setting det/rec/cls model dirs in OCRConfig is the
# version-independent alternative.

# Project root = two levels above this file; models live in <root>/models,
# which is created at import time if missing.
_PROJECT_ROOT = Path(__file__).parent.parent
_MODELS_DIR = Path(_PROJECT_ROOT, "models")
_MODELS_DIR.mkdir(exist_ok=True)

# Unconditionally overrides any PADDLEOCR_HOME already set in the environment.
os.environ.update(PADDLEOCR_HOME=str(_MODELS_DIR))
|
|
|
|
|
|
import numpy as np
|
|
|
from typing import List, Optional, Any
|
|
|
from dataclasses import dataclass
|
|
|
from paddleocr import PaddleOCR
|
|
|
|
|
|
from utils.config import OCRConfig
|
|
|
|
|
|
|
|
|
@dataclass
class TextBlock:
    """
    A single text region recognized by OCR.

    Attributes:
        text: The recognized text content.
        confidence: Recognition confidence score (0.0 ~ 1.0).
        bbox: Region polygon as a list of ``[x, y]`` points, in the coordinate
            system of the image that was passed to OCR (usually 4 points:
            [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]).
        bbox_offset: ``(x, y)`` offset of the ROI inside the original image,
            used to map ``bbox`` back to original-image coordinates.
    """

    text: str
    confidence: float
    bbox: List[List[float]]
    bbox_offset: tuple = (0, 0)

    @property
    def bbox_with_offset(self) -> List[List[float]]:
        """Return ``bbox`` translated by ``bbox_offset`` (original-image coordinates)."""
        offset_x, offset_y = self.bbox_offset
        return [[p[0] + offset_x, p[1] + offset_y] for p in self.bbox]

    @property
    def center(self) -> tuple:
        """
        Return the centroid ``(x, y)`` of the bbox vertices.

        The result is in ROI coordinates (``bbox_offset`` is NOT applied).
        It averages over the actual number of points instead of assuming
        exactly 4, so non-quadrilateral polygons are handled correctly.
        An empty bbox yields ``(0.0, 0.0)``.
        """
        count = len(self.bbox)
        if count == 0:
            return (0.0, 0.0)
        x_sum = sum(p[0] for p in self.bbox)
        y_sum = sum(p[1] for p in self.bbox)
        return (x_sum / count, y_sum / count)

    @property
    def width(self) -> float:
        """Return the horizontal extent (max x - min x) of the bbox; 0.0 if empty."""
        if not self.bbox:
            return 0.0
        x_coords = [p[0] for p in self.bbox]
        return max(x_coords) - min(x_coords)

    @property
    def height(self) -> float:
        """Return the vertical extent (max y - min y) of the bbox; 0.0 if empty."""
        if not self.bbox:
            return 0.0
        y_coords = [p[1] for p in self.bbox]
        return max(y_coords) - min(y_coords)

    def to_dict(self) -> dict:
        """
        Convert to a plain dict.

        ``bbox``, ``center``, ``width`` and ``height`` are in ROI coordinates;
        ``bbox_with_offset`` is in original-image coordinates.
        """
        return {
            "text": self.text,
            "confidence": self.confidence,
            "bbox": self.bbox,
            "bbox_with_offset": self.bbox_with_offset,
            "center": self.center,
            "width": self.width,
            "height": self.height
        }
|
|
|
|
|
|
|
|
|
class OCREngine:
    """
    OCR engine wrapper around PaddleOCR.

    Provides a small, uniform interface for running OCR on numpy images.
    The underlying ``PaddleOCR`` instance is created lazily on first use, or
    by an explicit :meth:`initialize` call. Constructing this object does not
    load any models.
    """

    def __init__(self, config: OCRConfig):
        """
        Create an engine bound to ``config`` without loading models.

        Args:
            config: OCR configuration. It is stored by reference, so
                :meth:`update_config` mutates this same object in place.
        """
        self._config = config
        self._ocr: Optional[PaddleOCR] = None

    def initialize(self) -> None:
        """
        Create the underlying PaddleOCR instance if it does not exist yet.

        Initialization is deferred so that models are not loaded at import or
        construction time. Repeated calls are no-ops until the instance is
        reset by :meth:`update_config`. The keyword arguments below target the
        PaddleOCR 2.x constructor.
        """
        if self._ocr is not None:
            return

        params = {
            "lang": self._config.lang,
            "use_angle_cls": self._config.use_angle_cls,
            "use_gpu": self._config.use_gpu,
            "det_db_thresh": self._config.det_db_thresh,
            "det_db_box_thresh": self._config.det_db_box_thresh,
            "drop_score": self._config.drop_score,
            "show_log": self._config.show_log
        }

        # Pass model directories only when explicitly configured. Custom paths
        # avoid the default model location (e.g. a Windows home directory
        # containing non-ASCII characters).
        if self._config.det_model_dir:
            params["det_model_dir"] = self._config.det_model_dir
        if self._config.rec_model_dir:
            params["rec_model_dir"] = self._config.rec_model_dir
        if self._config.cls_model_dir:
            params["cls_model_dir"] = self._config.cls_model_dir

        # PaddleOCR 2.x API
        self._ocr = PaddleOCR(**params)

    def recognize(
        self,
        image: np.ndarray,
        roi_offset: tuple = (0, 0)
    ) -> List[TextBlock]:
        """
        Run OCR on a single image.

        Args:
            image: Input image as a numpy array (BGR or grayscale).
            roi_offset: ``(x, y)`` offset of ``image`` inside the original
                image. It is stored on each result so coordinates can be
                mapped back.

        Returns:
            Text blocks whose confidence is at least ``config.drop_score``,
            with ``bbox`` in ``image`` coordinates. Returns an empty list when
            nothing is recognized.
        """
        # Lazily create the PaddleOCR instance on first use.
        if self._ocr is None:
            self.initialize()

        # PaddleOCR 2.x API
        result = self._ocr.ocr(image, cls=self._config.use_angle_cls)

        text_blocks: List[TextBlock] = []

        # Expected format: [[item, item, ...], ...] with one inner list per
        # page/image. The whole result, or an inner entry, may be None when
        # nothing is detected.
        if result is None or len(result) == 0:
            return text_blocks

        for page in result:
            if page is None:
                continue
            for item in page:
                # Each item is expected to be [bbox, (text, confidence)].
                if item is None or len(item) < 2:
                    continue

                bbox = item[0]  # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
                text_info = item[1]  # (text, confidence)
                if len(text_info) < 2:
                    continue

                text = text_info[0]
                confidence = float(text_info[1])

                # drop_score is also passed to PaddleOCR. Re-checking it here
                # keeps the threshold guaranteed regardless of how the library
                # applies it.
                if confidence < self._config.drop_score:
                    continue

                text_blocks.append(TextBlock(
                    text=text,
                    confidence=confidence,
                    bbox=bbox,
                    bbox_offset=roi_offset
                ))

        return text_blocks

    def recognize_batch(
        self,
        images: List[np.ndarray],
        roi_offsets: Optional[List[tuple]] = None
    ) -> List[List[TextBlock]]:
        """
        Run OCR on several images, one :meth:`recognize` call per image.

        Args:
            images: Input images.
            roi_offsets: Optional ``(x, y)`` ROI offset per image, in the same
                order as ``images``. When omitted, every image uses ``(0, 0)``.

        Returns:
            One list of text blocks per input image, in input order.

        Raises:
            ValueError: If ``roi_offsets`` is given and its length differs
                from the length of ``images``.
        """
        if roi_offsets is None:
            return [self.recognize(img) for img in images]

        if len(roi_offsets) != len(images):
            raise ValueError(
                f"roi_offsets has {len(roi_offsets)} entries, "
                f"but {len(images)} images were given"
            )
        return [
            self.recognize(img, offset)
            for img, offset in zip(images, roi_offsets)
        ]

    @property
    def config(self) -> OCRConfig:
        """The configuration object this engine uses (same object passed to ``__init__``)."""
        return self._config

    def update_config(self, **kwargs) -> None:
        """
        Update configuration fields and rebuild the PaddleOCR instance.

        Keys that exist on the config object are set on it in place. Because
        the config is held by reference, the caller's object changes too.
        Unknown keys are ignored, but a ``UserWarning`` naming them is emitted
        so that typos do not pass silently.

        Args:
            **kwargs: Config attribute names and their new values.
        """
        unknown_keys = []
        for key, value in kwargs.items():
            if hasattr(self._config, key):
                setattr(self._config, key, value)
            else:
                unknown_keys.append(key)

        if unknown_keys:
            warnings.warn(
                "OCREngine.update_config ignored unknown config keys: "
                + ", ".join(unknown_keys),
                stacklevel=2,
            )

        # Drop the current instance so initialize() rebuilds it with the new config.
        self._ocr = None
        self.initialize()
|