You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
304 lines
8.3 KiB
Python
304 lines
8.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
OCR 处理管道模块
|
|
提供图片 OCR 识别和结果解析的完整处理流程
|
|
"""
|
|
|
|
import time
|
|
import numpy as np
|
|
from typing import List, Optional, Dict, Any, Callable, TYPE_CHECKING
|
|
from dataclasses import dataclass, field
|
|
|
|
from ocr.engine import OCREngine, TextBlock
|
|
from utils.config import PipelineConfig, OCRConfig
|
|
|
|
if TYPE_CHECKING:
|
|
from ocr.express_parser import ExpressInfo, ExpressParser
|
|
|
|
|
|
@dataclass
class OCRResult:
    """
    Outcome of running OCR on a single image.

    Attributes:
        image_index: Index of the image (used when processing in batches).
        image_path: Path of the source image, if one was supplied.
        timestamp: Processing timestamp.
        processing_time_ms: Processing duration in milliseconds.
        text_blocks: The recognized text blocks.
        roi_applied: Whether ROI cropping was applied.
        roi_rect: ROI rectangle (x, y, w, h) when ROI cropping was applied.
    """

    image_index: int
    image_path: Optional[str]
    timestamp: float
    processing_time_ms: float
    text_blocks: List[TextBlock]
    roi_applied: bool = False
    roi_rect: Optional[tuple] = None

    @property
    def text_count(self) -> int:
        """Number of recognized text blocks."""
        blocks = self.text_blocks
        return len(blocks)

    @property
    def all_texts(self) -> List[str]:
        """The ``text`` of every block, in block order."""
        return [tb.text for tb in self.text_blocks]

    @property
    def full_text(self) -> str:
        """All block texts joined with newlines."""
        separator = "\n"
        return separator.join(self.all_texts)

    @property
    def average_confidence(self) -> float:
        """Mean ``confidence`` over all blocks; 0.0 when there are no blocks."""
        blocks = self.text_blocks
        if blocks:
            total = sum(tb.confidence for tb in blocks)
            return total / len(blocks)
        return 0.0

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a plain dict, convenient for JSON serialization.

        Each text block is converted with its own ``to_dict()``.
        """
        return dict(
            image_index=self.image_index,
            image_path=self.image_path,
            timestamp=self.timestamp,
            processing_time_ms=self.processing_time_ms,
            text_count=self.text_count,
            average_confidence=self.average_confidence,
            roi_applied=self.roi_applied,
            roi_rect=self.roi_rect,
            text_blocks=[tb.to_dict() for tb in self.text_blocks],
        )

    def filter_by_confidence(self, min_confidence: float) -> "OCRResult":
        """
        Return a copy keeping only blocks whose confidence meets a threshold.

        Args:
            min_confidence: Minimum confidence; blocks with
                ``confidence >= min_confidence`` are kept.

        Returns:
            A new OCRResult with every other field copied unchanged.
        """
        kept = [tb for tb in self.text_blocks if tb.confidence >= min_confidence]
        return OCRResult(
            text_blocks=kept,
            image_index=self.image_index,
            image_path=self.image_path,
            timestamp=self.timestamp,
            processing_time_ms=self.processing_time_ms,
            roi_applied=self.roi_applied,
            roi_rect=self.roi_rect,
        )

    def parse_express(self) -> "ExpressInfo":
        """
        Parse express-waybill information from the text blocks.

        Delegates to ``ExpressParser.parse``.

        Returns:
            The structured express-waybill information produced by the parser.
        """
        # Imported lazily; the module-level import exists only under TYPE_CHECKING.
        from ocr.express_parser import ExpressParser
        return ExpressParser().parse(self.text_blocks)

    def merge_text(self) -> str:
        """
        Merge the text blocks into a single text.

        Delegates to ``ExpressParser.merge_text_blocks``, which is described
        as merging blocks on the same line using their positions.

        Returns:
            The merged text.
        """
        # Imported lazily; the module-level import exists only under TYPE_CHECKING.
        from ocr.express_parser import ExpressParser
        return ExpressParser().merge_text_blocks(self.text_blocks)
|
|
|
|
|
|
class OCRPipeline:
    """
    OCR processing pipeline.

    Responsible for ROI cropping, invoking the OCR engine and wrapping the
    engine output in an ``OCRResult``. Registered image preprocessors and
    result postprocessors are applied in registration order.
    """

    def __init__(
        self,
        ocr_config: OCRConfig,
        pipeline_config: Optional[PipelineConfig] = None
    ):
        """
        Initialize the OCR pipeline.

        Args:
            ocr_config: OCR engine configuration, passed to ``OCREngine``.
            pipeline_config: Pipeline configuration (optional). A default
                ``PipelineConfig()`` is used when it is None.
        """
        self._ocr_config = ocr_config
        self._pipeline_config = (
            pipeline_config if pipeline_config is not None else PipelineConfig()
        )
        self._engine = OCREngine(ocr_config)
        self._image_counter: int = 0

        # Extension point: image preprocessing callbacks, applied in order.
        self._image_preprocessors: List[Callable[[np.ndarray], np.ndarray]] = []

        # Extension point: result postprocessing callbacks, applied in order.
        self._result_postprocessors: List[Callable[[OCRResult], OCRResult]] = []

    def initialize(self) -> None:
        """Initialize the pipeline by calling the engine's ``initialize()`` (model preload)."""
        self._engine.initialize()

    def add_preprocessor(
        self,
        preprocessor: Callable[[np.ndarray], np.ndarray]
    ) -> None:
        """
        Register an image preprocessor.

        Args:
            preprocessor: Function that receives an image and returns the
                processed image.
        """
        self._image_preprocessors.append(preprocessor)

    def add_postprocessor(
        self,
        postprocessor: Callable[[OCRResult], OCRResult]
    ) -> None:
        """
        Register a result postprocessor.

        Args:
            postprocessor: Function that receives an OCRResult and returns
                the processed result.
        """
        self._result_postprocessors.append(postprocessor)

    def _apply_roi(
        self,
        image: np.ndarray
    ) -> tuple:
        """
        Apply ROI cropping according to ``pipeline_config.roi``.

        Args:
            image: Original image; ``image.shape[:2]`` is read as (height, width).

        Returns:
            ``(cropped_image, roi_offset, roi_rect)`` where:
            - cropped_image is a slice of ``image`` (``image`` itself when
              ROI is disabled);
            - roi_offset is ``(x, y)`` of the crop's top-left corner in the
              original image, ``(0, 0)`` when ROI is disabled;
            - roi_rect is the clamped ``(x, y, w, h)`` actually used, or
              None when ROI is disabled.
        """
        roi_config = self._pipeline_config.roi

        if not roi_config.enabled:
            return image, (0, 0), None

        h, w = image.shape[:2]
        x, y, roi_w, roi_h = roi_config.get_roi_rect(w, h)
        # Coerce to built-in int: NumPy slicing rejects float indices, and
        # NumPy integer scalars in roi_rect would break JSON serialization
        # of OCRResult.to_dict().
        x, y, roi_w, roi_h = int(x), int(y), int(roi_w), int(roi_h)

        # Boundary checks: clamp the origin into the image, then the size to
        # what remains.
        x = max(0, min(x, w - 1))
        y = max(0, min(y, h - 1))
        # max(0, ...) guards against negative sizes: a negative slice stop
        # makes NumPy count from the end of the axis and silently return
        # the wrong region instead of an empty one.
        roi_w = max(0, min(roi_w, w - x))
        roi_h = max(0, min(roi_h, h - y))

        cropped = image[y:y + roi_h, x:x + roi_w]
        return cropped, (x, y), (x, y, roi_w, roi_h)

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Run every registered preprocessor in order, feeding each the previous output.

        Args:
            image: Image to preprocess.

        Returns:
            The preprocessed image (``image`` unchanged if none are registered).
        """
        processed = image
        for preprocessor in self._image_preprocessors:
            processed = preprocessor(processed)
        return processed

    def _postprocess_result(self, result: OCRResult) -> OCRResult:
        """
        Run every registered postprocessor in order, feeding each the previous output.

        Args:
            result: Raw result.

        Returns:
            The postprocessed result (``result`` unchanged if none are registered).
        """
        processed = result
        for postprocessor in self._result_postprocessors:
            processed = postprocessor(processed)
        return processed

    def process(
        self,
        image: np.ndarray,
        image_path: Optional[str] = None
    ) -> OCRResult:
        """
        Process a single image.

        Args:
            image: Input image (numpy array, BGR format).
            image_path: Image path (optional, only recorded in the result).

        Returns:
            The OCR result after all postprocessors have run. Its
            ``processing_time_ms`` covers ROI cropping, preprocessing and
            recognition, but not postprocessing.
        """
        self._image_counter += 1
        # perf_counter is monotonic and high-resolution, so durations are not
        # skewed by wall-clock adjustments the way time.time() deltas can be.
        start = time.perf_counter()

        # Apply ROI cropping
        cropped_image, roi_offset, roi_rect = self._apply_roi(image)

        # Image preprocessing
        processed_image = self._preprocess_image(cropped_image)

        # Run OCR.
        # NOTE(review): roi_offset is computed on the crop *before*
        # preprocessing; a preprocessor that resizes or pads the image would
        # make any offset-based coordinate mapping in recognize() inaccurate.
        # Confirm preprocessors keep geometry.
        text_blocks = self._engine.recognize(processed_image, roi_offset)

        processing_time_ms = (time.perf_counter() - start) * 1000

        result = OCRResult(
            image_index=self._image_counter,
            image_path=image_path,
            # Wall-clock time taken after recognition finished.
            timestamp=time.time(),
            processing_time_ms=processing_time_ms,
            text_blocks=text_blocks,
            # _apply_roi returns a rect exactly when ROI cropping was applied.
            roi_applied=roi_rect is not None,
            roi_rect=roi_rect
        )

        # Result postprocessing
        return self._postprocess_result(result)

    def reset_counter(self) -> None:
        """Reset the processed-image counter to 0."""
        self._image_counter = 0

    @property
    def image_counter(self) -> int:
        """Number of images processed since creation or the last reset."""
        return self._image_counter

    @property
    def config(self) -> PipelineConfig:
        """The pipeline configuration in use."""
        return self._pipeline_config
|