You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

304 lines
8.3 KiB
Python

# -*- coding: utf-8 -*-
"""
OCR 处理管道模块
提供图片 OCR 识别和结果解析的完整处理流程
"""
import time
from dataclasses import dataclass, field, replace
from typing import List, Optional, Dict, Any, Callable, TYPE_CHECKING

import numpy as np

from ocr.engine import OCREngine, TextBlock
from utils.config import PipelineConfig, OCRConfig

if TYPE_CHECKING:
    from ocr.express_parser import ExpressInfo, ExpressParser
@dataclass
class OCRResult:
    """
    Result of running OCR on a single image.

    Attributes:
        image_index: Index of the image (used for batch processing).
        image_path: Path of the image, if one was supplied.
        timestamp: Timestamp recorded when the result was built.
        processing_time_ms: Processing time in milliseconds.
        text_blocks: Recognized text blocks.
        roi_applied: Whether ROI cropping was applied.
        roi_rect: ROI rectangle (x, y, w, h) when ROI cropping was applied.
    """

    image_index: int
    image_path: Optional[str]
    timestamp: float
    processing_time_ms: float
    text_blocks: List[TextBlock]
    roi_applied: bool = False
    roi_rect: Optional[tuple] = None

    @property
    def text_count(self) -> int:
        """Number of recognized text blocks."""
        return len(self.text_blocks)

    @property
    def all_texts(self) -> List[str]:
        """Text of every recognized block, in block order."""
        return [block.text for block in self.text_blocks]

    @property
    def full_text(self) -> str:
        """All recognized texts joined with newlines."""
        return "\n".join(self.all_texts)

    @property
    def average_confidence(self) -> float:
        """Mean confidence over all blocks; 0.0 when there are no blocks."""
        if not self.text_blocks:
            return 0.0
        return sum(b.confidence for b in self.text_blocks) / len(self.text_blocks)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the result to a plain dict, convenient for JSON output."""
        return {
            "image_index": self.image_index,
            "image_path": self.image_path,
            "timestamp": self.timestamp,
            "processing_time_ms": self.processing_time_ms,
            "text_count": self.text_count,
            "average_confidence": self.average_confidence,
            "roi_applied": self.roi_applied,
            "roi_rect": self.roi_rect,
            "text_blocks": [block.to_dict() for block in self.text_blocks],
        }

    def filter_by_confidence(self, min_confidence: float) -> "OCRResult":
        """
        Return a copy that keeps only sufficiently confident blocks.

        Args:
            min_confidence: Minimum confidence threshold (inclusive).

        Returns:
            A new OCRResult; this instance is left unchanged.
        """
        kept_blocks = [
            block for block in self.text_blocks
            if block.confidence >= min_confidence
        ]
        # replace() copies every other field, so fields added to the
        # dataclass later are carried over without touching this method.
        return replace(self, text_blocks=kept_blocks)

    def parse_express(self) -> "ExpressInfo":
        """
        Parse the text blocks into structured express-waybill information.

        Delegates to ExpressParser.parse with this result's text blocks.

        Returns:
            The structured express-waybill information.
        """
        # Imported at call time: at module level ExpressParser is only
        # imported under TYPE_CHECKING (presumably to avoid a circular
        # import -- confirm against ocr.express_parser).
        from ocr.express_parser import ExpressParser

        parser = ExpressParser()
        return parser.parse(self.text_blocks)

    def merge_text(self) -> str:
        """
        Merge the text blocks into one complete text.

        Delegates to ExpressParser.merge_text_blocks, which the original
        documentation describes as position-aware (blocks on the same line
        are merged together).

        Returns:
            The merged text.
        """
        # Imported at call time; see parse_express for the rationale.
        from ocr.express_parser import ExpressParser

        parser = ExpressParser()
        return parser.merge_text_blocks(self.text_blocks)
class OCRPipeline:
    """
    OCR processing pipeline.

    Handles ROI cropping, image preprocessing, the OCR engine call and
    packaging of the outcome into an OCRResult.
    """

    def __init__(
        self,
        ocr_config: OCRConfig,
        pipeline_config: Optional[PipelineConfig] = None
    ):
        """
        Create the pipeline and its OCR engine.

        Args:
            ocr_config: OCR engine configuration.
            pipeline_config: Pipeline configuration; a default
                PipelineConfig() is used when omitted.
        """
        self._ocr_config = ocr_config
        self._pipeline_config = pipeline_config or PipelineConfig()
        self._engine = OCREngine(ocr_config)
        self._image_counter: int = 0
        # Extension point: image preprocessing callbacks, run in order.
        self._image_preprocessors: List[Callable[[np.ndarray], np.ndarray]] = []
        # Extension point: result postprocessing callbacks, run in order.
        self._result_postprocessors: List[Callable[[OCRResult], OCRResult]] = []

    def initialize(self) -> None:
        """Initialize the underlying OCR engine (intended to preload the model)."""
        self._engine.initialize()

    def add_preprocessor(
        self,
        preprocessor: Callable[[np.ndarray], np.ndarray]
    ) -> None:
        """
        Register an image preprocessor.

        Args:
            preprocessor: Function that takes an image and returns the
                processed image. Preprocessors run in registration order.
        """
        self._image_preprocessors.append(preprocessor)

    def add_postprocessor(
        self,
        postprocessor: Callable[[OCRResult], OCRResult]
    ) -> None:
        """
        Register a result postprocessor.

        Args:
            postprocessor: Function that takes an OCRResult and returns the
                processed result. Postprocessors run in registration order.
        """
        self._result_postprocessors.append(postprocessor)

    def _apply_roi(
        self,
        image: np.ndarray
    ) -> tuple:
        """
        Crop the image to the configured ROI.

        Args:
            image: Original image.

        Returns:
            (cropped image, ROI offset (x, y), ROI rectangle (x, y, w, h)).
            When ROI is disabled the image is returned as is with offset
            (0, 0) and rectangle None.
        """
        roi_config = self._pipeline_config.roi
        if not roi_config.enabled:
            return image, (0, 0), None

        h, w = image.shape[:2]
        x, y, roi_w, roi_h = roi_config.get_roi_rect(w, h)

        # Clamp the origin into the image first, then the size. The size is
        # also clamped from below: a negative width/height would turn into a
        # negative slice end and silently crop the wrong region.
        x = max(0, min(x, w - 1))
        y = max(0, min(y, h - 1))
        roi_w = max(0, min(roi_w, w - x))
        roi_h = max(0, min(roi_h, h - y))

        cropped = image[y:y + roi_h, x:x + roi_w]
        return cropped, (x, y), (x, y, roi_w, roi_h)

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Run every registered preprocessor over the image, in order.

        Args:
            image: Image to preprocess.

        Returns:
            The output of the last preprocessor, or the input image when
            none are registered.
        """
        processed = image
        for preprocessor in self._image_preprocessors:
            processed = preprocessor(processed)
        return processed

    def _postprocess_result(self, result: OCRResult) -> OCRResult:
        """
        Run every registered postprocessor over the result, in order.

        Args:
            result: Result to postprocess.

        Returns:
            The output of the last postprocessor, or the input result when
            none are registered.
        """
        processed = result
        for postprocessor in self._result_postprocessors:
            processed = postprocessor(processed)
        return processed

    def process(
        self,
        image: np.ndarray,
        image_path: Optional[str] = None
    ) -> OCRResult:
        """
        Process a single image.

        Args:
            image: Input image (numpy array; BGR format according to the
                original documentation).
            image_path: Image path (optional, recorded in the result only).

        Returns:
            The OCR result, after all registered postprocessors have run.
        """
        self._image_counter += 1
        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments the way time.time() can.
        started = time.perf_counter()

        # ROI crop -> preprocessing -> OCR.
        cropped_image, roi_offset, roi_rect = self._apply_roi(image)
        processed_image = self._preprocess_image(cropped_image)
        # The ROI offset is handed to the engine -- presumably so block
        # coordinates refer to the full image; confirm in OCREngine.
        text_blocks = self._engine.recognize(processed_image, roi_offset)

        # Postprocessing time is intentionally not part of this figure,
        # matching the original measurement window.
        processing_time_ms = (time.perf_counter() - started) * 1000

        result = OCRResult(
            image_index=self._image_counter,
            image_path=image_path,
            timestamp=time.time(),  # wall-clock time of the result
            processing_time_ms=processing_time_ms,
            text_blocks=text_blocks,
            roi_applied=self._pipeline_config.roi.enabled,
            roi_rect=roi_rect
        )
        return self._postprocess_result(result)

    def reset_counter(self) -> None:
        """Reset the image counter to zero."""
        self._image_counter = 0

    @property
    def image_counter(self) -> int:
        """Number of images passed to process() since the last reset."""
        return self._image_counter

    @property
    def config(self) -> PipelineConfig:
        """The pipeline configuration in use."""
        return self._pipeline_config