You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

335 lines
9.4 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
OCR processing pipeline module.
Provides the complete processing flow for image OCR recognition and result parsing.
"""
import time
import numpy as np
from typing import List, Optional, Dict, Any, Callable, TYPE_CHECKING
from dataclasses import dataclass, field
from ocr.engine import OCREngine, TextBlock
from utils.config import PipelineConfig, OCRConfig
if TYPE_CHECKING:
    # Imported for static type checkers only. At runtime the parser is
    # imported lazily inside the OCRResult methods that need it
    # (presumably to avoid an import cycle -- confirm against ocr.express_parser).
    from ocr.express_parser import ExpressInfo, ExpressParser
@dataclass
class OCRResult:
    """
    Outcome of running OCR on a single image.

    Attributes:
        image_index: Index of the image (used when processing in batches).
        image_path: Path of the image, if one was supplied.
        timestamp: Timestamp recorded for the result.
        processing_time_ms: Processing time in milliseconds.
        text_blocks: Text blocks that were recognised.
        roi_applied: Whether ROI cropping was applied.
        roi_rect: ROI rectangle as (x, y, w, h) when an ROI was applied.
    """

    image_index: int
    image_path: Optional[str]
    timestamp: float
    processing_time_ms: float
    text_blocks: List[TextBlock]
    roi_applied: bool = False
    roi_rect: Optional[tuple] = None

    @property
    def text_count(self) -> int:
        """Number of recognised text blocks."""
        blocks = self.text_blocks
        return len(blocks)

    @property
    def all_texts(self) -> List[str]:
        """Text of every recognised block, in block order."""
        texts: List[str] = []
        for item in self.text_blocks:
            texts.append(item.text)
        return texts

    @property
    def full_text(self) -> str:
        """All recognised texts joined with newlines."""
        lines = self.all_texts
        return "\n".join(lines)

    @property
    def average_confidence(self) -> float:
        """Mean confidence over all blocks; 0.0 when there are no blocks."""
        blocks = self.text_blocks
        if blocks:
            total = sum(item.confidence for item in blocks)
            return total / len(blocks)
        return 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary view of the result, convenient for JSON serialisation."""
        payload: Dict[str, Any] = {
            "image_index": self.image_index,
            "image_path": self.image_path,
            "timestamp": self.timestamp,
            "processing_time_ms": self.processing_time_ms,
            "text_count": self.text_count,
            "average_confidence": self.average_confidence,
            "roi_applied": self.roi_applied,
            "roi_rect": self.roi_rect,
        }
        # Each block serialises itself through its own to_dict().
        payload["text_blocks"] = [item.to_dict() for item in self.text_blocks]
        return payload

    def filter_by_confidence(self, min_confidence: float) -> "OCRResult":
        """
        Build a copy of this result that keeps only sufficiently confident blocks.

        Args:
            min_confidence: Minimum confidence a block needs in order to be kept.

        Returns:
            A new OCRResult with the same metadata and the filtered block list.
        """
        kept: List[TextBlock] = []
        for item in self.text_blocks:
            if item.confidence >= min_confidence:
                kept.append(item)

        # Metadata is carried over unchanged; only the block list differs.
        return OCRResult(
            image_index=self.image_index,
            image_path=self.image_path,
            timestamp=self.timestamp,
            processing_time_ms=self.processing_time_ms,
            text_blocks=kept,
            roi_applied=self.roi_applied,
            roi_rect=self.roi_rect,
        )

    def parse_express(self) -> "ExpressInfo":
        """
        Parse the recognised blocks into express-waybill information.

        Hands the text blocks to a fresh ExpressParser and returns what its
        parse() method produces.

        Returns:
            The structured express-waybill information.
        """
        # Imported lazily, at call time, rather than at module import time.
        from ocr.express_parser import ExpressParser

        return ExpressParser().parse(self.text_blocks)

    def merge_text(self) -> str:
        """
        Merge the text blocks into one complete text.

        Delegates to ExpressParser.merge_text_blocks. According to the original
        notes the merge is position-based and joins text on the same line; that
        logic lives in the parser and is not visible here.

        Returns:
            The merged text.
        """
        # Imported lazily, at call time, rather than at module import time.
        from ocr.express_parser import ExpressParser

        return ExpressParser().merge_text_blocks(self.text_blocks)
class OCRPipeline:
    """
    OCR processing pipeline.

    Responsible for ROI cropping, invoking the OCR engine and wrapping the
    engine output in an OCRResult. Optional pre-processing callbacks run on
    the image before recognition, and optional post-processing callbacks run
    on the result afterwards.
    """

    def __init__(
        self,
        ocr_config: OCRConfig,
        pipeline_config: Optional[PipelineConfig] = None
    ):
        """
        Create the pipeline and its OCR engine.

        Args:
            ocr_config: OCR engine configuration.
            pipeline_config: Pipeline configuration; a default PipelineConfig()
                is created when omitted.
        """
        self._ocr_config = ocr_config
        # Explicit None check, consistent with how process() selects its config.
        self._pipeline_config = (
            pipeline_config if pipeline_config is not None else PipelineConfig()
        )
        self._engine = OCREngine(ocr_config)
        self._image_counter: int = 0
        # Extension point: image pre-processing callbacks, run in insertion order.
        self._image_preprocessors: List[Callable[[np.ndarray], np.ndarray]] = []
        # Extension point: result post-processing callbacks, run in insertion order.
        self._result_postprocessors: List[Callable[[OCRResult], OCRResult]] = []

    def initialize(self) -> None:
        """Initialise the OCR engine (per the original notes, this preloads the OCR model)."""
        self._engine.initialize()

    def add_preprocessor(
        self,
        preprocessor: Callable[[np.ndarray], np.ndarray]
    ) -> None:
        """
        Register an image pre-processor.

        Args:
            preprocessor: Callable that takes an image and returns the processed image.
        """
        self._image_preprocessors.append(preprocessor)

    def add_postprocessor(
        self,
        postprocessor: Callable[[OCRResult], OCRResult]
    ) -> None:
        """
        Register a result post-processor.

        Args:
            postprocessor: Callable that takes an OCRResult and returns the processed result.
        """
        self._result_postprocessors.append(postprocessor)

    def _apply_roi(
        self,
        image: np.ndarray
    ) -> tuple:
        """
        Apply ROI cropping using the pipeline's default configuration.

        Args:
            image: Original image.

        Returns:
            (cropped image, ROI offset, ROI rectangle)
        """
        return self._apply_roi_with_config(image, self._pipeline_config)

    def _apply_roi_with_config(
        self,
        image: np.ndarray,
        config: PipelineConfig
    ) -> tuple:
        """
        Apply ROI cropping using an explicitly supplied configuration.

        This method only reads its arguments; it does not touch pipeline state.

        Args:
            image: Original image.
            config: Pipeline configuration whose ``roi`` settings are used.

        Returns:
            (cropped image, ROI offset, ROI rectangle). When ROI is disabled the
            image is returned untouched with offset (0, 0) and rectangle None.
        """
        roi_config = config.roi
        if not roi_config.enabled:
            return image, (0, 0), None

        h, w = image.shape[:2]
        x, y, roi_w, roi_h = roi_config.get_roi_rect(w, h)

        # Keep the origin inside the image.
        x = max(0, min(x, w - 1))
        y = max(0, min(y, h - 1))
        # Keep the size within the image AND non-negative: a negative size
        # would turn into a negative slice end, which NumPy interprets as
        # "count from the end", silently cropping the wrong region.
        roi_w = max(0, min(roi_w, w - x))
        roi_h = max(0, min(roi_h, h - y))

        cropped = image[y:y + roi_h, x:x + roi_w]
        return cropped, (x, y), (x, y, roi_w, roi_h)

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Run every registered pre-processor over the image, in order.

        Args:
            image: Original image.

        Returns:
            The image after all pre-processors have been applied.
        """
        processed = image
        for preprocessor in self._image_preprocessors:
            processed = preprocessor(processed)
        return processed

    def _postprocess_result(self, result: OCRResult) -> OCRResult:
        """
        Run every registered post-processor over the result, in order.

        Args:
            result: Original result.

        Returns:
            The result after all post-processors have been applied.
        """
        processed = result
        for postprocessor in self._result_postprocessors:
            processed = postprocessor(processed)
        return processed

    def process(
        self,
        image: np.ndarray,
        image_path: Optional[str] = None,
        pipeline_config: Optional[PipelineConfig] = None,
        drop_score: Optional[float] = None,
    ) -> OCRResult:
        """
        Process a single image.

        Args:
            image: Input image (numpy array; the original notes say BGR format).
            image_path: Image path (optional, only recorded in the result).
            pipeline_config: Per-call configuration override (optional). The
                pipeline's stored default configuration is not modified.
            drop_score: Confidence threshold (optional); blocks below it are dropped.

        Returns:
            The OCR result, after post-processing.
        """
        self._image_counter += 1
        # Capture the index now so the result reports the index assigned to
        # THIS call even if the counter changes while OCR is running.
        # NOTE(review): the increment itself is not synchronised.
        image_index = self._image_counter
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        started = time.perf_counter()

        # Use the per-call configuration when given, otherwise the default.
        config = pipeline_config if pipeline_config is not None else self._pipeline_config

        # Apply ROI cropping.
        cropped_image, roi_offset, roi_rect = self._apply_roi_with_config(image, config)

        # Image pre-processing.
        processed_image = self._preprocess_image(cropped_image)

        # Run OCR; the ROI offset is handed to the engine (presumably so that
        # block coordinates refer to the full image -- confirm in OCREngine).
        text_blocks = self._engine.recognize(processed_image, roi_offset)

        # Apply confidence filtering (only if drop_score was specified).
        if drop_score is not None:
            text_blocks = [
                block for block in text_blocks
                if block.confidence >= drop_score
            ]

        # Elapsed time covers ROI, pre-processing, OCR and filtering, but not
        # the post-processors below.
        processing_time_ms = (time.perf_counter() - started) * 1000

        # Build the result; the timestamp is wall-clock time.
        result = OCRResult(
            image_index=image_index,
            image_path=image_path,
            timestamp=time.time(),
            processing_time_ms=processing_time_ms,
            text_blocks=text_blocks,
            roi_applied=config.roi.enabled,
            roi_rect=roi_rect
        )

        # Result post-processing.
        return self._postprocess_result(result)

    def reset_counter(self) -> None:
        """Reset the image counter to zero."""
        self._image_counter = 0

    @property
    def image_counter(self) -> int:
        """Number of process() calls since construction or the last reset."""
        return self._image_counter

    @property
    def config(self) -> PipelineConfig:
        """The pipeline's default configuration."""
        return self._pipeline_config