You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

552 lines
16 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
OCR 图片识别系统 - 主入口
支持单张图片、多张图片和目录批量处理
"""
import os
from pathlib import Path
# Set the PaddleOCR model path before any of the remaining imports run.
# Works around path problems on Windows when the user name contains Chinese characters.
# Project root is the directory that contains this file.
_PROJECT_ROOT = Path(__file__).parent
# Models are kept in a "models" folder next to this file.
_MODELS_DIR = _PROJECT_ROOT / "models"
# Create the folder on first run; an existing folder is left untouched.
_MODELS_DIR.mkdir(exist_ok=True)
# NOTE(review): presumably read by PaddleOCR when it is imported below - confirm.
os.environ["PADDLEOCR_HOME"] = str(_MODELS_DIR)
import argparse
import json
import sys
import cv2
from typing import Optional, List, Generator
from input.loader import ImageLoader, ImageInfo
from ocr.pipeline import OCRPipeline, OCRResult
from visualize.draw import OCRVisualizer
from utils.config import (
Config,
InputConfig,
InputMode,
OCRConfig,
PipelineConfig,
VisualizeConfig,
OutputConfig,
ROIConfig
)
class OCRApplication:
    """
    Main OCR application.

    Coordinates the image loader, the OCR pipeline and the visualizer to
    recognise text in a single image, a list of images, or a directory.
    """

    def __init__(
        self,
        config: "Config",
        express_mode: bool = False
    ):
        """
        Store the configuration; components are created by ``initialize``.

        Args:
            config: Global configuration.
            express_mode: Whether to enable express-waybill parsing mode.
        """
        self._config = config
        self._loader: Optional[ImageLoader] = None
        self._pipeline: Optional[OCRPipeline] = None
        self._visualizer: Optional[OCRVisualizer] = None
        self._express_mode = express_mode
        self._all_results: List[dict] = []

    def initialize(self) -> bool:
        """
        Create the loader, pipeline and visualizer and preload the OCR model.

        Returns:
            True once every component has been created. Errors raised while
            building a component or loading the model are not caught here.
        """
        print("[INFO] 正在初始化 OCR 系统...")
        # Image loader
        self._loader = ImageLoader()
        # OCR pipeline
        self._pipeline = OCRPipeline(
            ocr_config=self._config.ocr,
            pipeline_config=self._config.pipeline
        )
        # Visualizer
        self._visualizer = OCRVisualizer(self._config.visualize)
        # Initialise the pipeline up front so the model is loaded before the
        # first image is processed.
        print("[INFO] 正在加载 OCR 模型...")
        self._pipeline.initialize()
        print("[INFO] OCR 模型加载完成")
        return True

    def _get_images(self) -> "Generator[ImageInfo, None, None]":
        """
        Yield the images selected by the input configuration.

        Yields:
            ImageInfo objects produced by the loader. Nothing is yielded when
            no input configuration is present or a single image fails to load.
        """
        input_config = self._config.input
        if input_config is None:
            return
        if input_config.mode == InputMode.SINGLE:
            info = self._loader.load(input_config.image_path)
            if info:
                yield info
        elif input_config.mode == InputMode.BATCH:
            yield from self._loader.load_batch(input_config.image_paths)
        elif input_config.mode == InputMode.DIRECTORY:
            yield from self._loader.load_directory(
                input_config.directory,
                input_config.pattern,
                input_config.recursive
            )

    def _build_record(self, result: "OCRResult", express_info=None) -> dict:
        """
        Build the JSON-serialisable summary record for one OCR result.

        Args:
            result: OCR result of one image.
            express_info: Already parsed express information; parsed from
                ``result`` when omitted. Only used in express mode.

        Returns:
            ``result.to_dict()`` in normal mode, otherwise a dict holding the
            structured express information, the merged text and the text
            block count.
        """
        if not self._express_mode:
            return result.to_dict()
        if express_info is None:
            express_info = result.parse_express()
        return {
            "image_index": result.image_index,
            "image_path": result.image_path,
            "processing_time_ms": result.processing_time_ms,
            # Needed by _export_summary; without it the express-mode total
            # of text blocks was always reported as 0.
            "text_count": result.text_count,
            "express_info": express_info.to_dict(),
            "merged_text": result.merge_text()
        }

    def run(self) -> None:
        """
        Process every configured image: recognise, print, visualise, export.

        The summary export and the resource cleanup run even when processing
        is interrupted with Ctrl+C or by pressing ``q`` in the preview window.
        """
        if self._loader is None or self._pipeline is None or self._visualizer is None:
            print("[ERROR] 系统未初始化")
            return
        self._all_results = []
        print("[INFO] 开始 OCR 识别...")
        try:
            for image_info in self._get_images():
                print(f"\n[INFO] 处理图片: {image_info.filename}")
                result = self._pipeline.process(image_info.image, image_info.path)
                if not result:
                    continue
                # Parse the waybill once and reuse it for the record and the
                # console output.
                express_info = result.parse_express() if self._express_mode else None
                self._all_results.append(self._build_record(result, express_info))
                if result.text_count > 0:
                    if self._express_mode:
                        self._print_express_result(result, express_info)
                    else:
                        self._print_result(result)
                self._handle_visualization(image_info, result)
        except KeyboardInterrupt:
            print("\n[INFO] 收到中断信号,正在退出...")
        finally:
            # Nested so that a failing export can no longer skip the cleanup.
            try:
                self._export_summary()
            finally:
                self.cleanup()

    def _print_result(self, result: "OCRResult") -> None:
        """
        Print the recognised text blocks of one image to the console.

        Args:
            result: OCR result.
        """
        print(f" 识别到 {result.text_count} 个文本块 (耗时: {result.processing_time_ms:.1f}ms)")
        for index, block in enumerate(result.text_blocks, start=1):
            print(f" [{index}] {block.text} (置信度: {block.confidence:.3f})")

    def _print_express_result(self, result: "OCRResult", express_info=None) -> None:
        """
        Print the parsed express-waybill fields of one image to the console.

        Args:
            result: OCR result.
            express_info: Already parsed express information; parsed from
                ``result`` when omitted.
        """
        if express_info is None:
            express_info = result.parse_express()
        print(f" 快递单解析结果 (耗时: {result.processing_time_ms:.1f}ms)")
        labelled_fields = (
            ("快递公司", express_info.courier_company),
            ("运单号", express_info.tracking_number),
            ("收件人", express_info.receiver_name),
            ("收件电话", express_info.receiver_phone),
            ("收件地址", express_info.receiver_address),
            ("寄件人", express_info.sender_name),
            ("寄件电话", express_info.sender_phone),
            ("寄件地址", express_info.sender_address),
        )
        # Only fields that were actually recognised are printed.
        for label, value in labelled_fields:
            if value:
                print(f" {label}: {value}")
        if not express_info.is_valid:
            print(" [未识别到有效快递单信息]")
            print(f" 合并文本: {result.merge_text()}")

    def _handle_visualization(self, image_info: "ImageInfo", result: "OCRResult") -> None:
        """
        Draw the OCR result and optionally show and/or save the annotated image.

        Args:
            image_info: Image information.
            result: OCR result.

        Raises:
            KeyboardInterrupt: When the user presses ``q``/``Q`` in the
                preview window; ``run`` treats this as a request to stop.
        """
        display_image = self._visualizer.draw_result(image_info.image, result)
        if self._config.visualize.show_window:
            key = self._visualizer.show(display_image, wait_key=0)
            if key in (ord('q'), ord('Q')):
                raise KeyboardInterrupt()
        if self._config.output.save_image:
            self._save_annotated_image(image_info, display_image)

    def _save_annotated_image(self, image_info: "ImageInfo", annotated_image) -> None:
        """
        Save the annotated image next to the source image or in the output dir.

        Args:
            image_info: Original image information.
            annotated_image: Annotated image to encode and write.
        """
        output_config = self._config.output
        original_path = Path(image_info.path)
        # Fall back to the source image's folder when no output dir is set.
        if output_config.output_dir:
            output_dir = Path(output_config.output_dir)
        else:
            output_dir = original_path.parent
        output_dir.mkdir(parents=True, exist_ok=True)
        output_filename = f"{original_path.stem}{output_config.image_suffix}{original_path.suffix}"
        output_path = output_dir / output_filename
        # Encode in memory and write the bytes ourselves so that paths with
        # Chinese characters work.
        _, ext = os.path.splitext(str(output_path))
        success, encoded = cv2.imencode(ext, annotated_image)
        if not success:
            # Previously an encoding failure was dropped without any message.
            print(f" [WARN] 标注图片保存失败: {output_path}")
            return
        with open(output_path, 'wb') as f:
            f.write(encoded.tobytes())
        print(f" [INFO] 标注图片已保存: {output_path}")

    def _export_summary(self) -> None:
        """
        Export all collected results to one summary JSON file.

        Nothing is written when there are no results or JSON output is
        disabled. The file goes to the configured output dir, or to the
        current working directory when none is set.
        """
        if not self._all_results:
            print("[INFO] 没有识别结果需要导出")
            return
        output_config = self._config.output
        if not output_config.save_json:
            return
        if output_config.output_dir:
            output_dir = Path(output_config.output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            output_path = output_dir / output_config.json_filename
        else:
            output_path = Path(output_config.json_filename)
        # NOTE(review): assumes OCRResult.to_dict() provides "text_count" for
        # normal-mode records - confirm; records without the key count as 0.
        summary = {
            "total_images": len(self._all_results),
            "total_text_blocks": sum(r.get("text_count", 0) for r in self._all_results),
            "results": self._all_results
        }
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
        print(f"[INFO] 汇总结果已导出到 {output_path},共 {summary['total_images']} 张图片")

    def cleanup(self) -> None:
        """Close the visualizer, if one was created."""
        if self._visualizer:
            self._visualizer.close()
        print("[INFO] 资源已释放")
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse and validate the command-line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which keeps existing ``parse_args()`` calls
            working.

    Returns:
        The parsed arguments.

    Raises:
        SystemExit: On invalid arguments, including a ``--drop-score``
            outside [0, 1] or ``--roi`` values that are not normalized
            coordinates with a positive width and height.
    """
    parser = argparse.ArgumentParser(
        description="OCR 图片识别系统",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# 识别单张图片
python main.py --image path/to/image.jpg
# 识别目录中的所有图片
python main.py --dir path/to/images/
# 识别目录中的特定格式图片
python main.py --dir path/to/images/ --pattern "*.png"
# 递归搜索子目录
python main.py --dir path/to/images/ --recursive
# 启用快递单解析模式
python main.py --image express.jpg --express
# 保存标注后的图片
python main.py --image test.jpg --save-image
# 指定输出目录
python main.py --dir images/ --output-dir results/
# 启用 ROI 裁剪(画面中央 60% 区域)
python main.py --image test.jpg --roi 0.2 0.2 0.6 0.6
# 使用 GPU 加速
python main.py --image test.jpg --gpu
"""
    )
    # Input source: exactly one of --image / --dir is required.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--image", "-i",
        type=str,
        help="单张图片路径"
    )
    input_group.add_argument(
        "--dir", "-d",
        type=str,
        help="图片目录路径"
    )
    # Directory-mode options
    parser.add_argument(
        "--pattern", "-p",
        type=str,
        default=None,
        help="文件匹配模式(如 '*.jpg')"
    )
    parser.add_argument(
        "--recursive", "-r",
        action="store_true",
        help="递归搜索子目录"
    )
    # OCR options
    parser.add_argument(
        "--lang", "-l",
        type=str,
        default="ch",
        help="OCR 语言(默认: ch)"
    )
    parser.add_argument(
        "--gpu",
        action="store_true",
        help="启用 GPU 加速"
    )
    parser.add_argument(
        "--no-angle-cls",
        action="store_true",
        help="禁用方向分类"
    )
    parser.add_argument(
        "--drop-score",
        type=float,
        default=0.5,
        help="置信度阈值(默认: 0.5)"
    )
    # ROI options
    parser.add_argument(
        "--roi",
        type=float,
        nargs=4,
        metavar=("X", "Y", "W", "H"),
        help="ROI 区域(归一化坐标: x y width height)"
    )
    # Visualisation options
    parser.add_argument(
        "--show-window",
        action="store_true",
        help="显示可视化窗口(默认不显示)"
    )
    parser.add_argument(
        "--no-confidence",
        action="store_true",
        help="不显示置信度"
    )
    # Output options
    parser.add_argument(
        "--output-dir", "-o",
        type=str,
        default=None,
        help="输出目录路径"
    )
    parser.add_argument(
        "--save-image",
        action="store_true",
        help="保存标注后的图片"
    )
    parser.add_argument(
        "--no-json",
        action="store_true",
        help="不保存 JSON 结果"
    )
    parser.add_argument(
        "--json-filename",
        type=str,
        default="ocr_result.json",
        help="JSON 结果文件名(默认: ocr_result.json)"
    )
    # Express-waybill parsing mode
    parser.add_argument(
        "--express", "-e",
        action="store_true",
        help="启用快递单解析模式,自动合并文本并提取结构化信息"
    )
    args = parser.parse_args(argv)

    # Reject values that cannot be meaningful before any model is loaded.
    # The chained comparison is also False for NaN.
    if not 0.0 <= args.drop_score <= 1.0:
        parser.error("--drop-score 必须在 0 到 1 之间")
    if args.roi is not None:
        _, _, roi_width, roi_height = args.roi
        in_unit_range = all(0.0 <= value <= 1.0 for value in args.roi)
        if not in_unit_range or roi_width <= 0 or roi_height <= 0:
            parser.error("--roi 的取值必须是 0 到 1 之间的归一化坐标,且宽高必须大于 0")
    return args
def build_config(args: argparse.Namespace) -> Config:
    """
    Translate parsed command-line arguments into a ``Config`` object.

    Args:
        args: Parsed command-line arguments.

    Returns:
        The assembled configuration.
    """
    # Input: a single image when --image was given, otherwise a directory.
    if args.image:
        source = InputConfig(mode=InputMode.SINGLE, image_path=args.image)
    else:
        source = InputConfig(
            mode=InputMode.DIRECTORY,
            directory=args.dir,
            pattern=args.pattern,
            recursive=args.recursive,
        )

    # OCR: point at the project-local model folders (works around the
    # Windows Chinese-user-name path problem), but only when all three are
    # present; otherwise every model dir is passed as None.
    model_dirs = {
        "det_model_dir": str(_MODELS_DIR / "ch_PP-OCRv4_det_infer"),
        "rec_model_dir": str(_MODELS_DIR / "ch_PP-OCRv4_rec_infer"),
        "cls_model_dir": str(_MODELS_DIR / "ch_ppocr_mobile_v2.0_cls_infer"),
    }
    if not all(Path(folder).exists() for folder in model_dirs.values()):
        model_dirs = dict.fromkeys(model_dirs)
    recognition = OCRConfig(
        lang=args.lang,
        use_angle_cls=not args.no_angle_cls,
        use_gpu=args.gpu,
        drop_score=args.drop_score,
        **model_dirs,
    )

    # ROI: disabled unless --roi supplied the four ratios.
    region = ROIConfig(enabled=False)
    if args.roi:
        ratios = args.roi
        region = ROIConfig(
            enabled=True,
            x_ratio=ratios[0],
            y_ratio=ratios[1],
            width_ratio=ratios[2],
            height_ratio=ratios[3],
        )

    # Pipeline, visualisation and output sections.
    processing = PipelineConfig(roi=region)
    display = VisualizeConfig(
        show_window=args.show_window,
        show_confidence=not args.no_confidence,
    )
    export = OutputConfig(
        output_dir=args.output_dir,
        save_json=not args.no_json,
        save_image=args.save_image,
        json_filename=args.json_filename,
    )

    return Config(
        input=source,
        ocr=recognition,
        pipeline=processing,
        visualize=display,
        output=export,
    )
def main() -> int:
    """
    Program entry point: parse arguments, build the config, run the app.

    Returns:
        Process exit code: 0 after a run, 1 when initialisation fails.
    """
    cli_args = parse_args()
    app_config = build_config(cli_args)

    # A missing det_model_dir means the project-local models were not found.
    if app_config.ocr.det_model_dir is None:
        notices = (
            "[WARN] 模型未在项目目录中找到",
            "[WARN] 对于 Windows 中文用户名用户,请先运行:",
            "[WARN] python download_models.py",
            "[INFO] 回退到默认 PaddleOCR 模型路径...",
        )
        for notice in notices:
            print(notice)

    application = OCRApplication(app_config, express_mode=cli_args.express)
    if application.initialize():
        application.run()
        return 0
    return 1
# Run the CLI only when executed as a script and hand main()'s status to the shell.
if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)