You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

243 lines
6.7 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
配置管理模块
集中管理所有可配置参数,便于维护和扩展
"""
from dataclasses import dataclass, field
from typing import Optional, Tuple, List
from enum import Enum
class InputMode(Enum):
    """Enumeration of the supported image input modes."""

    # A single image file.
    SINGLE = "single"
    # An explicit list of several image files.
    BATCH = "batch"
    # A batch of images taken from a directory.
    DIRECTORY = "directory"
@dataclass
class InputConfig:
    """
    Image input configuration.

    Attributes:
        mode: Input mode.
        image_path: Path of the single image.
        image_paths: List of paths for several images.
        directory: Path of the image directory.
        pattern: File matching pattern (e.g. "*.jpg").
        recursive: Whether to search subdirectories recursively.

    Raises:
        ValueError: Raised by ``__post_init__`` when the field required by
            ``mode`` is missing or empty. With all defaults the mode is
            SINGLE and ``image_path`` is None, so construction raises.
    """

    mode: InputMode = InputMode.SINGLE
    image_path: Optional[str] = None
    image_paths: Optional[List[str]] = None
    directory: Optional[str] = None
    pattern: Optional[str] = None
    recursive: bool = False

    def __post_init__(self):
        """Validate that the field required by the selected mode is set."""
        # One row per mode: (mode, value that mode requires, error message).
        # Rows are checked in order and the first violated one raises.
        requirements = (
            (InputMode.SINGLE, self.image_path, "单张图片模式下必须指定 image_path"),
            (InputMode.BATCH, self.image_paths, "批量模式下必须指定 image_paths"),
            (InputMode.DIRECTORY, self.directory, "目录模式下必须指定 directory"),
        )
        for required_mode, required_value, message in requirements:
            if self.mode == required_mode and not required_value:
                raise ValueError(message)
@dataclass
class ROIConfig:
    """
    Region of interest (ROI) configuration.

    Uses normalized coordinates (0.0 ~ 1.0) so the same configuration can be
    applied to frames of different resolutions.

    Attributes:
        enabled: Whether ROI cropping is enabled.
        x_ratio: x coordinate of the ROI's top-left corner, as a ratio of
            the frame width.
        y_ratio: y coordinate of the ROI's top-left corner, as a ratio of
            the frame height.
        width_ratio: ROI width as a ratio of the frame width.
        height_ratio: ROI height as a ratio of the frame height.
    """

    enabled: bool = False
    x_ratio: float = 0.1
    y_ratio: float = 0.1
    width_ratio: float = 0.8
    height_ratio: float = 0.8

    @staticmethod
    def _clip_span(start: int, length: int, limit: int) -> Tuple[int, int]:
        """Clip the 1-D span [start, start + length) to [0, limit].

        Returns:
            (clipped_start, clipped_length); the length is never negative.
        """
        limit = max(limit, 0)
        # A negative length is treated as an empty span.
        end = start + max(length, 0)
        clipped_start = min(max(start, 0), limit)
        clipped_end = min(max(end, 0), limit)
        return clipped_start, clipped_end - clipped_start

    def get_roi_rect(self, frame_width: int, frame_height: int) -> Tuple[int, int, int, int]:
        """
        Compute the actual ROI rectangle for a frame of the given size.

        The rectangle derived from the ratios is clipped to the frame, so
        the result always satisfies ``0 <= x``, ``0 <= y``,
        ``x + width <= frame_width`` and ``y + height <= frame_height``
        even when the configured ratios are negative or reach past the
        frame edge. Configurations that already lie inside the frame yield
        the same values as an unclipped computation.

        Args:
            frame_width: Frame width in pixels.
            frame_height: Frame height in pixels.

        Returns:
            (x, y, width, height) in pixel coordinates.
        """
        x, width = self._clip_span(
            int(frame_width * self.x_ratio),
            int(frame_width * self.width_ratio),
            frame_width,
        )
        y, height = self._clip_span(
            int(frame_height * self.y_ratio),
            int(frame_height * self.height_ratio),
            frame_height,
        )
        return x, y, width, height
@dataclass
class OCRConfig:
    """
    OCR engine configuration (adapted to the PaddleOCR 2.x API).

    Attributes:
        lang: Recognition language, e.g. "ch" (Chinese) or "en" (English).
        use_angle_cls: Whether to enable the orientation classifier.
        use_gpu: Whether to use GPU acceleration.
        det_db_thresh: Text detection threshold.
        det_db_box_thresh: Detection box threshold.
        drop_score: Results with a confidence below this value are
            filtered out.
        show_log: Whether to show PaddleOCR logs.
        det_model_dir: Path of the detection model directory.
        rec_model_dir: Path of the recognition model directory.
        cls_model_dir: Path of the classification model directory.
    """

    # Engine switches.
    lang: str = "ch"
    use_angle_cls: bool = True
    use_gpu: bool = False

    # Thresholds.
    det_db_thresh: float = 0.3
    det_db_box_thresh: float = 0.5
    drop_score: float = 0.5

    # Logging.
    show_log: bool = False

    # Model directories; all default to None.
    det_model_dir: Optional[str] = None
    rec_model_dir: Optional[str] = None
    cls_model_dir: Optional[str] = None
@dataclass
class PipelineConfig:
    """
    OCR processing pipeline configuration.

    Attributes:
        roi: ROI configuration. When not supplied, every instance gets its
            own newly created ROIConfig.
    """

    roi: ROIConfig = field(default_factory=ROIConfig)
@dataclass
class VisualizeConfig:
    """
    Visualization configuration.

    Attributes:
        show_window: Whether to show the visualization window.
        window_name: Name of the window.
        box_color: Text box color as (B, G, R).
        box_thickness: Line width of the text box.
        text_color: Text color as (B, G, R).
        text_scale: Text scale factor.
        text_thickness: Line width of the text.
        show_confidence: Whether to show the confidence next to the text.
        font_path: Path of a Chinese font; None means the OpenCV default
            font is used.
    """

    # Window.
    show_window: bool = False
    window_name: str = "OCR Result"

    # Text box drawing.
    box_color: Tuple[int, int, int] = (0, 255, 0)
    box_thickness: int = 2

    # Text drawing.
    text_color: Tuple[int, int, int] = (0, 0, 255)
    text_scale: float = 0.6
    text_thickness: int = 1
    show_confidence: bool = True
    font_path: Optional[str] = None
@dataclass
class OutputConfig:
    """
    Output configuration.

    Attributes:
        output_dir: Path of the output directory.
        save_json: Whether to save the JSON result.
        save_image: Whether to save the annotated image.
        json_filename: JSON file name template.
        image_suffix: Suffix for annotated images.
    """

    # Destination.
    output_dir: Optional[str] = None

    # What to save.
    save_json: bool = True
    save_image: bool = False

    # Naming.
    json_filename: str = "ocr_result.json"
    image_suffix: str = "_ocr"
@dataclass
class Config:
    """
    Global configuration class that aggregates all configuration sections.

    Attributes:
        input: Input configuration; None unless one is provided.
        ocr: OCR engine configuration.
        pipeline: Processing pipeline configuration.
        visualize: Visualization configuration.
        output: Output configuration.
    """

    input: Optional[InputConfig] = None
    ocr: OCRConfig = field(default_factory=OCRConfig)
    pipeline: PipelineConfig = field(default_factory=PipelineConfig)
    visualize: VisualizeConfig = field(default_factory=VisualizeConfig)
    output: OutputConfig = field(default_factory=OutputConfig)

    @classmethod
    def _with_input(cls, **input_kwargs) -> "Config":
        """Create a default configuration and attach a new InputConfig.

        The default configuration is created first; the InputConfig is then
        built from ``input_kwargs`` and stored in ``input``.
        """
        instance = cls()
        instance.input = InputConfig(**input_kwargs)
        return instance

    @classmethod
    def default(cls) -> "Config":
        """Create a configuration with every section at its default."""
        return cls()

    @classmethod
    def for_single_image(cls, image_path: str) -> "Config":
        """
        Create a configuration for single-image mode.

        Args:
            image_path: Path of the image.
        """
        return cls._with_input(mode=InputMode.SINGLE, image_path=image_path)

    @classmethod
    def for_directory(cls, directory: str, pattern: Optional[str] = None, recursive: bool = False) -> "Config":
        """
        Create a configuration for directory batch mode.

        Args:
            directory: Path of the directory.
            pattern: File matching pattern.
            recursive: Whether to search recursively.
        """
        return cls._with_input(
            mode=InputMode.DIRECTORY,
            directory=directory,
            pattern=pattern,
            recursive=recursive,
        )

    @classmethod
    def for_batch(cls, image_paths: List[str]) -> "Config":
        """
        Create a configuration for batch (image list) mode.

        Args:
            image_paths: List of image paths.
        """
        return cls._with_input(mode=InputMode.BATCH, image_paths=image_paths)