You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

221 lines
6.0 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
OCR 引擎模块
封装 PaddleOCR提供统一的 OCR 接口
"""
import os
from pathlib import Path
# Redirect PaddleOCR's model cache into <project>/models. This must run
# before `from paddleocr import PaddleOCR` below, so the environment is
# already in place when paddleocr is imported. Purpose: avoid the default
# cache location under the user's home directory, which breaks on Windows
# accounts whose user name contains Chinese (non-ASCII) characters.
_PROJECT_ROOT = Path(__file__).parent.parent
_MODELS_DIR = _PROJECT_ROOT / "models"
_MODELS_DIR.mkdir(parents=True, exist_ok=True)
os.environ["PADDLEOCR_HOME"] = str(_MODELS_DIR)
# NOTE(review): PaddleOCR 2.7+ reads its base directory from
# PADDLE_OCR_BASE_DIR (falling back to ~/.paddleocr); PADDLEOCR_HOME alone
# may not be honoured by the installed version, so set both — confirm
# against the pinned paddleocr release.
os.environ["PADDLE_OCR_BASE_DIR"] = str(_MODELS_DIR)
import numpy as np
from typing import List, Optional, Any
from dataclasses import dataclass
from paddleocr import PaddleOCR
from utils.config import OCRConfig
@dataclass
class TextBlock:
    """
    A single text region recognized by OCR.

    Attributes:
        text: The recognized text content.
        confidence: Recognition confidence (0.0 ~ 1.0).
        bbox: Polygon vertices of the region as [[x, y], ...]; PaddleOCR
            produces four points [[x1,y1], [x2,y2], [x3,y3], [x4,y4]].
        bbox_offset: (x, y) offset of the ROI the image was cropped from,
            used to map ``bbox`` back to full-image coordinates.
    """
    text: str
    confidence: float
    bbox: List[List[float]]
    bbox_offset: tuple = (0, 0)

    @property
    def bbox_with_offset(self) -> List[List[float]]:
        """Return ``bbox`` shifted by ``bbox_offset`` (full-image coordinates)."""
        offset_x, offset_y = self.bbox_offset
        return [[p[0] + offset_x, p[1] + offset_y] for p in self.bbox]

    @property
    def center(self) -> tuple:
        """
        Return the (x, y) centroid of the bbox vertices.

        Averages over however many vertices ``bbox`` has rather than assuming
        exactly four, so non-quadrilateral polygons get a correct center.
        Coordinates are ROI-relative (``bbox_offset`` is not applied).
        An empty bbox yields (0.0, 0.0).
        """
        count = len(self.bbox)
        if count == 0:
            return (0.0, 0.0)
        sum_x = sum(p[0] for p in self.bbox)
        sum_y = sum(p[1] for p in self.bbox)
        return (sum_x / count, sum_y / count)

    @property
    def width(self) -> float:
        """Return the horizontal extent (max x - min x) of the bbox."""
        x_coords = [p[0] for p in self.bbox]
        return max(x_coords) - min(x_coords)

    @property
    def height(self) -> float:
        """Return the vertical extent (max y - min y) of the bbox."""
        y_coords = [p[1] for p in self.bbox]
        return max(y_coords) - min(y_coords)

    def to_dict(self) -> dict:
        """Return the block as a plain dict, including derived geometry."""
        return {
            "text": self.text,
            "confidence": self.confidence,
            "bbox": self.bbox,
            "bbox_with_offset": self.bbox_with_offset,
            "center": self.center,
            "width": self.width,
            "height": self.height
        }
class OCREngine:
    """
    OCR engine wrapper.

    Wraps PaddleOCR behind a small interface: lazy model loading,
    single/batch recognition returning ``TextBlock`` lists, and config
    updates that rebuild the underlying PaddleOCR instance.
    """

    def __init__(self, config: OCRConfig):
        """
        Create the engine without loading any model.

        Args:
            config: OCR configuration. It is stored by reference, so
                ``update_config`` mutates this same object.
        """
        self._config = config
        self._ocr: Optional[PaddleOCR] = None

    def initialize(self) -> None:
        """
        Build the PaddleOCR instance if it does not exist yet.

        Initialization is deferred so that importing this module does not
        load models. Uses the PaddleOCR 2.x constructor arguments.
        """
        if self._ocr is not None:
            return

        params = {
            "lang": self._config.lang,
            "use_angle_cls": self._config.use_angle_cls,
            "use_gpu": self._config.use_gpu,
            "det_db_thresh": self._config.det_db_thresh,
            "det_db_box_thresh": self._config.det_db_box_thresh,
            "drop_score": self._config.drop_score,
            "show_log": self._config.show_log
        }
        # Custom model directories (workaround for non-ASCII paths) are only
        # passed when configured, so PaddleOCR keeps its defaults otherwise.
        for key in ("det_model_dir", "rec_model_dir", "cls_model_dir"):
            model_dir = getattr(self._config, key)
            if model_dir:
                params[key] = model_dir

        # PaddleOCR 2.x API
        self._ocr = PaddleOCR(**params)

    def recognize(
        self,
        image: np.ndarray,
        roi_offset: tuple = (0, 0)
    ) -> List[TextBlock]:
        """
        Run OCR on one image.

        Args:
            image: Input image (numpy array, BGR or grayscale).
            roi_offset: (x, y) offset of the ROI, stored on each block so
                coordinates can be mapped back to the full image.

        Returns:
            The recognized text blocks; empty when nothing is found.
        """
        if self._ocr is None:
            self.initialize()

        # PaddleOCR 2.x API
        result = self._ocr.ocr(image, cls=self._config.use_angle_cls)
        return self._parse_result(result, roi_offset)

    def _parse_result(self, result: Any, roi_offset: tuple) -> List[TextBlock]:
        """
        Convert raw PaddleOCR 2.x ``ocr()`` output into ``TextBlock`` objects.

        Expected shape: ``None``/empty, or a list with one entry per page,
        each entry ``None`` or a list of ``[bbox, (text, confidence)]`` items.
        Malformed items and items below ``drop_score`` are skipped.
        """
        text_blocks: List[TextBlock] = []
        if result is None or len(result) == 0:
            return text_blocks

        for page in result:
            if page is None:
                continue
            for item in page:
                if item is None or len(item) < 2:
                    continue
                bbox = item[0]       # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
                text_info = item[1]  # (text, confidence)
                if len(text_info) < 2:
                    continue
                confidence = float(text_info[1])
                # drop_score is also handed to PaddleOCR in initialize();
                # re-checking here keeps the threshold enforced on whatever
                # the installed PaddleOCR version returns.
                if confidence < self._config.drop_score:
                    continue
                text_blocks.append(TextBlock(
                    text=text_info[0],
                    confidence=confidence,
                    bbox=bbox,
                    bbox_offset=roi_offset
                ))
        return text_blocks

    def recognize_batch(
        self,
        images: List[np.ndarray],
        roi_offsets: Optional[List[tuple]] = None
    ) -> List[List[TextBlock]]:
        """
        Run OCR on several images, one ``recognize`` call per image.

        Args:
            images: Input images.
            roi_offsets: Optional per-image ROI offsets aligned with
                ``images``. When omitted, every image uses (0, 0).

        Returns:
            One list of text blocks per input image, in input order.

        Raises:
            ValueError: If ``roi_offsets`` is given and its length differs
                from ``images``.
        """
        if roi_offsets is None:
            return [self.recognize(img) for img in images]
        if len(roi_offsets) != len(images):
            raise ValueError(
                f"roi_offsets has {len(roi_offsets)} entries, "
                f"but images has {len(images)}"
            )
        return [
            self.recognize(img, offset)
            for img, offset in zip(images, roi_offsets)
        ]

    @property
    def config(self) -> OCRConfig:
        """Return the current configuration object."""
        return self._config

    def update_config(self, **kwargs) -> None:
        """
        Update configuration fields and rebuild the PaddleOCR instance.

        Only keys that already exist as attributes on the config object are
        applied. Unknown keys are still ignored, but now emit a warning so
        typos are not lost silently.

        Args:
            **kwargs: Config attribute names and their new values.
        """
        import warnings

        for key, value in kwargs.items():
            if hasattr(self._config, key):
                setattr(self._config, key, value)
            else:
                warnings.warn(
                    f"Unknown OCR config option ignored: {key!r}",
                    stacklevel=2
                )

        # Drop the old instance so initialize() rebuilds it with new settings.
        self._ocr = None
        self.initialize()