|
|
# -*- coding: utf-8 -*-
|
|
|
"""
|
|
|
OCR 图片识别系统 - 主入口
|
|
|
支持单张图片、多张图片和目录批量处理
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
|
|
|
# Point PaddleOCR at a model directory inside the project before the
# remaining (OCR-related) imports below are executed. This works around
# model paths under a Windows user profile whose name contains Chinese
# characters.
_PROJECT_ROOT = Path(__file__).parent
_MODELS_DIR = _PROJECT_ROOT.joinpath("models")
_MODELS_DIR.mkdir(exist_ok=True)
os.environ.update(PADDLEOCR_HOME=str(_MODELS_DIR))
|
|
|
|
|
|
import argparse
|
|
|
import json
|
|
|
import sys
|
|
|
import cv2
|
|
|
from typing import Optional, List, Generator
|
|
|
|
|
|
from input.loader import ImageLoader, ImageInfo
|
|
|
from ocr.pipeline import OCRPipeline, OCRResult
|
|
|
from visualize.draw import OCRVisualizer
|
|
|
from utils.config import (
|
|
|
Config,
|
|
|
InputConfig,
|
|
|
InputMode,
|
|
|
OCRConfig,
|
|
|
PipelineConfig,
|
|
|
VisualizeConfig,
|
|
|
OutputConfig,
|
|
|
ROIConfig
|
|
|
)
|
|
|
|
|
|
|
|
|
class OCRApplication:
    """
    Main OCR application.

    Coordinates the image loader, the OCR pipeline and the visualizer:
    runs OCR on every configured image, prints the results, optionally
    shows/saves annotated images and finally exports a JSON summary.
    """

    def __init__(
        self,
        config: Config,
        express_mode: bool = False
    ):
        """
        Initialize the application (components are created in initialize()).

        Args:
            config: Global configuration.
            express_mode: Whether to parse each result as an express waybill
                (structured fields) instead of plain text blocks.
        """
        self._config = config
        self._loader: Optional[ImageLoader] = None
        self._pipeline: Optional[OCRPipeline] = None
        self._visualizer: Optional[OCRVisualizer] = None
        self._express_mode = express_mode
        # One JSON-serializable record per processed image, filled by run().
        self._all_results: List[dict] = []

    def initialize(self) -> bool:
        """
        Create all components and preload the OCR models.

        Returns:
            True on success; False if any component could not be created or
            the OCR models could not be loaded (the error is printed).
        """
        print("[INFO] 正在初始化 OCR 系统...")

        try:
            # Image loader
            self._loader = ImageLoader()

            # OCR pipeline
            self._pipeline = OCRPipeline(
                ocr_config=self._config.ocr,
                pipeline_config=self._config.pipeline
            )

            # Visualizer
            self._visualizer = OCRVisualizer(self._config.visualize)

            # Preload the models so the first image does not pay for loading.
            print("[INFO] 正在加载 OCR 模型...")
            self._pipeline.initialize()
        except Exception as exc:
            # Initialization boundary: report through the boolean contract so
            # main() can exit with a non-zero status instead of a traceback.
            print(f"[ERROR] OCR 系统初始化失败: {type(exc).__name__}: {exc}")
            # Leave the app in the "not initialized" state so run() refuses.
            self._pipeline = None
            return False

        print("[INFO] OCR 模型加载完成")
        return True

    def _get_images(self) -> Generator[ImageInfo, None, None]:
        """
        Yield the images selected by the input configuration.

        Yields:
            ImageInfo objects; nothing when no input is configured.
        """
        input_config = self._config.input
        if input_config is None:
            return

        if input_config.mode == InputMode.SINGLE:
            info = self._loader.load(input_config.image_path)
            if info:
                yield info
        elif input_config.mode == InputMode.BATCH:
            yield from self._loader.load_batch(input_config.image_paths)
        elif input_config.mode == InputMode.DIRECTORY:
            yield from self._loader.load_directory(
                input_config.directory,
                input_config.pattern,
                input_config.recursive
            )

    def run(self) -> None:
        """
        Process every configured image: OCR, print, visualize and save.

        The JSON summary is exported and resources are released in a
        ``finally`` clause, i.e. also after Ctrl+C, after 'q' in the preview
        window, or when an error interrupts processing.
        """
        if self._loader is None or self._pipeline is None or self._visualizer is None:
            print("[ERROR] 系统未初始化")
            return

        self._all_results = []
        print("[INFO] 开始 OCR 识别...")

        try:
            for image_info in self._get_images():
                print(f"\n[INFO] 处理图片: {image_info.filename}")

                # OCR
                result = self._pipeline.process(image_info.image, image_info.path)
                if not result:
                    continue

                # Collect the result. In express mode the waybill is parsed
                # once here and reused for the console output below.
                express_info = None
                if self._express_mode:
                    express_info = result.parse_express()
                    self._all_results.append({
                        "image_index": result.image_index,
                        "image_path": result.image_path,
                        "processing_time_ms": result.processing_time_ms,
                        # Needed by _export_summary(), which totals the
                        # "text_count" key of every record.
                        "text_count": result.text_count,
                        "express_info": express_info.to_dict(),
                        "merged_text": result.merge_text()
                    })
                else:
                    self._all_results.append(result.to_dict())

                # Print
                if result.text_count > 0:
                    if self._express_mode:
                        self._print_express_result(result, express_info)
                    else:
                        self._print_result(result)

                # Visualize / save
                self._handle_visualization(image_info, result)

        except KeyboardInterrupt:
            print("\n[INFO] 收到中断信号,正在退出...")

        finally:
            # Export the summary of whatever was processed so far.
            self._export_summary()
            self.cleanup()

    def _print_result(self, result: OCRResult) -> None:
        """
        Print the recognized text blocks of one image to the console.

        Args:
            result: OCR result.
        """
        print(f" 识别到 {result.text_count} 个文本块 (耗时: {result.processing_time_ms:.1f}ms)")
        for i, block in enumerate(result.text_blocks):
            print(f" [{i+1}] {block.text} (置信度: {block.confidence:.3f})")

    def _print_express_result(self, result: OCRResult, express_info=None) -> None:
        """
        Print the parsed express-waybill fields of one image to the console.

        Args:
            result: OCR result.
            express_info: Value previously returned by
                ``result.parse_express()``; parsed here when omitted.
        """
        if express_info is None:
            express_info = result.parse_express()

        print(f" 快递单解析结果 (耗时: {result.processing_time_ms:.1f}ms)")

        # Only fields that were actually extracted are printed.
        fields = (
            ("快递公司", express_info.courier_company),
            ("运单号", express_info.tracking_number),
            ("收件人", express_info.receiver_name),
            ("收件电话", express_info.receiver_phone),
            ("收件地址", express_info.receiver_address),
            ("寄件人", express_info.sender_name),
            ("寄件电话", express_info.sender_phone),
            ("寄件地址", express_info.sender_address),
        )
        for label, value in fields:
            if value:
                print(f" {label}: {value}")

        if not express_info.is_valid:
            # Fall back to the raw merged text so the user still sees output.
            print(" [未识别到有效快递单信息]")
            print(f" 合并文本: {result.merge_text()}")

    def _handle_visualization(self, image_info: ImageInfo, result: OCRResult) -> None:
        """
        Draw the result, optionally show it in a window and/or save it.

        Args:
            image_info: Source image information.
            result: OCR result.

        Raises:
            KeyboardInterrupt: When the user presses 'q'/'Q' in the preview
                window, so run() stops exactly as on Ctrl+C.
        """
        show_window = self._config.visualize.show_window
        save_image = self._config.output.save_image
        if not (show_window or save_image):
            # Nobody consumes the annotated image; skip the drawing work.
            return

        display_image = self._visualizer.draw_result(image_info.image, result)

        if show_window:
            key = self._visualizer.show(display_image, wait_key=0)
            if key in (ord('q'), ord('Q')):
                raise KeyboardInterrupt()

        if save_image:
            self._save_annotated_image(image_info, display_image)

    def _save_annotated_image(self, image_info: ImageInfo, annotated_image) -> None:
        """
        Save the annotated image as "<stem><image_suffix><ext>".

        The file goes to the configured output directory, or next to the
        source image when none is configured. Failures are reported as a
        warning so one bad image does not abort a batch.

        Args:
            image_info: Source image information.
            annotated_image: Annotated image to encode and write.
        """
        output_config = self._config.output
        original_path = Path(image_info.path)

        # Output directory
        if output_config.output_dir:
            output_dir = Path(output_config.output_dir)
        else:
            output_dir = original_path.parent
        output_dir.mkdir(parents=True, exist_ok=True)

        # Output file name
        output_filename = f"{original_path.stem}{output_config.image_suffix}{original_path.suffix}"
        output_path = output_dir / output_filename

        # Encode in memory and write with Python's file API rather than
        # cv2.imwrite, so paths containing Chinese characters work.
        _, ext = os.path.splitext(str(output_path))
        try:
            success, encoded = cv2.imencode(ext, annotated_image)
            if not success:
                print(f" [WARN] 标注图片编码失败: {output_path}")
                return
            with open(output_path, 'wb') as f:
                f.write(encoded.tobytes())
        except (cv2.error, OSError) as exc:
            print(f" [WARN] 标注图片保存失败: {output_path} ({exc})")
            return

        print(f" [INFO] 标注图片已保存: {output_path}")

    def _export_summary(self) -> None:
        """
        Export all collected results to one summary JSON file.

        Does nothing when JSON output is disabled. The file is written to
        ``output_dir/json_filename``, or ``json_filename`` relative to the
        current directory when no output directory is configured.
        """
        output_config = self._config.output
        if not output_config.save_json:
            return

        if not self._all_results:
            print("[INFO] 没有识别结果需要导出")
            return

        # Output path
        if output_config.output_dir:
            output_path = Path(output_config.output_dir) / output_config.json_filename
        else:
            output_path = Path(output_config.json_filename)
        # Also covers json_filename values that contain sub-directories.
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Summary data
        summary = {
            "total_images": len(self._all_results),
            "total_text_blocks": sum(r.get("text_count", 0) for r in self._all_results),
            "results": self._all_results
        }

        # ensure_ascii=False keeps Chinese text readable in the file.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        print(f"[INFO] 汇总结果已导出到 {output_path},共 {summary['total_images']} 张图片")

    def cleanup(self) -> None:
        """Release resources held by the visualizer."""
        if self._visualizer:
            self._visualizer.close()
        print("[INFO] 资源已释放")
|
|
|
|
|
|
|
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse and validate command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` (the default) means
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace.

    Raises:
        SystemExit: On invalid arguments (argparse prints a usage message and
            exits with status 2), including an out-of-range ``--drop-score``
            or ``--roi``.
    """
    parser = argparse.ArgumentParser(
        description="OCR 图片识别系统",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# 识别单张图片
python main.py --image path/to/image.jpg

# 识别目录中的所有图片
python main.py --dir path/to/images/

# 识别目录中的特定格式图片
python main.py --dir path/to/images/ --pattern "*.png"

# 递归搜索子目录
python main.py --dir path/to/images/ --recursive

# 启用快递单解析模式
python main.py --image express.jpg --express

# 保存标注后的图片
python main.py --image test.jpg --save-image

# 指定输出目录
python main.py --dir images/ --output-dir results/

# 启用 ROI 裁剪(画面中央 60% 区域)
python main.py --image test.jpg --roi 0.2 0.2 0.6 0.6

# 使用 GPU 加速
python main.py --image test.jpg --gpu
"""
    )

    # Input source (exactly one of --image / --dir)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--image", "-i",
        type=str,
        help="单张图片路径"
    )
    input_group.add_argument(
        "--dir", "-d",
        type=str,
        help="图片目录路径"
    )

    # Directory-mode options
    parser.add_argument(
        "--pattern", "-p",
        type=str,
        default=None,
        help="文件匹配模式(如 '*.jpg')"
    )
    parser.add_argument(
        "--recursive", "-r",
        action="store_true",
        help="递归搜索子目录"
    )

    # OCR options
    parser.add_argument(
        "--lang", "-l",
        type=str,
        default="ch",
        help="OCR 语言(默认: ch)"
    )
    parser.add_argument(
        "--gpu",
        action="store_true",
        help="启用 GPU 加速"
    )
    parser.add_argument(
        "--no-angle-cls",
        action="store_true",
        help="禁用方向分类"
    )
    parser.add_argument(
        "--drop-score",
        type=float,
        default=0.5,
        help="置信度阈值(默认: 0.5)"
    )

    # ROI options
    parser.add_argument(
        "--roi",
        type=float,
        nargs=4,
        metavar=("X", "Y", "W", "H"),
        help="ROI 区域(归一化坐标: x y width height)"
    )

    # Visualization options
    parser.add_argument(
        "--show-window",
        action="store_true",
        help="显示可视化窗口(默认不显示)"
    )
    parser.add_argument(
        "--no-confidence",
        action="store_true",
        help="不显示置信度"
    )

    # Output options
    parser.add_argument(
        "--output-dir", "-o",
        type=str,
        default=None,
        help="输出目录路径"
    )
    parser.add_argument(
        "--save-image",
        action="store_true",
        help="保存标注后的图片"
    )
    parser.add_argument(
        "--no-json",
        action="store_true",
        help="不保存 JSON 结果"
    )
    parser.add_argument(
        "--json-filename",
        type=str,
        default="ocr_result.json",
        help="JSON 结果文件名(默认: ocr_result.json)"
    )

    # Express-waybill parsing mode
    parser.add_argument(
        "--express", "-e",
        action="store_true",
        help="启用快递单解析模式,自动合并文本并提取结构化信息"
    )

    args = parser.parse_args(argv)

    # Reject values outside their documented normalized ranges up front with
    # a usage message, instead of misbehaving later in the pipeline.
    # The chained comparison is also False for NaN, so NaN is rejected.
    if not 0.0 <= args.drop_score <= 1.0:
        parser.error("--drop-score 必须在 [0, 1] 范围内")
    if args.roi is not None:
        width, height = args.roi[2], args.roi[3]
        if not all(0.0 <= v <= 1.0 for v in args.roi) or width <= 0 or height <= 0:
            parser.error("--roi 的取值必须是 [0, 1] 内的归一化坐标,且宽高必须大于 0")

    return args
|
|
|
|
|
|
|
|
|
def build_config(args: argparse.Namespace) -> Config:
    """
    Build the global configuration from parsed command-line arguments.

    Project-local PaddleOCR models under ``_MODELS_DIR`` are used when all
    three model directories exist. Otherwise the model directories are left
    as ``None`` (main() checks ``det_model_dir`` to warn about this).

    Args:
        args: Namespace returned by parse_args().

    Returns:
        The assembled Config.
    """
    # Input: --image and --dir are mutually exclusive in parse_args().
    if args.image:
        input_config = InputConfig(
            mode=InputMode.SINGLE,
            image_path=args.image
        )
    else:
        input_config = InputConfig(
            mode=InputMode.DIRECTORY,
            directory=args.dir,
            pattern=args.pattern,
            recursive=args.recursive
        )

    # OCR: prefer models stored inside the project directory, which avoids
    # model paths under a Windows user profile with a Chinese user name.
    det_model_dir = _MODELS_DIR / "ch_PP-OCRv4_det_infer"
    rec_model_dir = _MODELS_DIR / "ch_PP-OCRv4_rec_infer"
    cls_model_dir = _MODELS_DIR / "ch_ppocr_mobile_v2.0_cls_infer"
    models_exist = all(
        model_dir.exists()
        for model_dir in (det_model_dir, rec_model_dir, cls_model_dir)
    )
    # The bundled recognition model is the Chinese ("ch") one. Forcing it for
    # another --lang would presumably pair it with that language's character
    # dictionary (assuming OCRConfig forwards these settings to PaddleOCR),
    # so for other languages leave rec_model_dir unset and let PaddleOCR
    # choose the matching recognizer.
    use_local_rec = models_exist and args.lang == "ch"

    ocr_config = OCRConfig(
        lang=args.lang,
        use_angle_cls=not args.no_angle_cls,
        use_gpu=args.gpu,
        drop_score=args.drop_score,
        det_model_dir=str(det_model_dir) if models_exist else None,
        rec_model_dir=str(rec_model_dir) if use_local_rec else None,
        cls_model_dir=str(cls_model_dir) if models_exist else None
    )

    # ROI: disabled unless --roi x y w h (normalized) was given.
    roi_config = ROIConfig(enabled=False)
    if args.roi:
        roi_config = ROIConfig(
            enabled=True,
            x_ratio=args.roi[0],
            y_ratio=args.roi[1],
            width_ratio=args.roi[2],
            height_ratio=args.roi[3]
        )

    # Pipeline
    pipeline_config = PipelineConfig(roi=roi_config)

    # Visualization
    visualize_config = VisualizeConfig(
        show_window=args.show_window,
        show_confidence=not args.no_confidence
    )

    # Output
    output_config = OutputConfig(
        output_dir=args.output_dir,
        save_json=not args.no_json,
        save_image=args.save_image,
        json_filename=args.json_filename
    )

    return Config(
        input=input_config,
        ocr=ocr_config,
        pipeline=pipeline_config,
        visualize=visualize_config,
        output=output_config
    )
|
|
|
|
|
|
|
|
|
def main() -> int:
    """
    Command-line entry point.

    Returns:
        Process exit code: 1 when OCRApplication.initialize() reports
        failure, otherwise 0 after processing has run.
    """
    args = parse_args()
    config = build_config(args)

    # build_config() leaves det_model_dir as None when the project-local
    # models are missing; the default PaddleOCR model location is used then.
    if config.ocr.det_model_dir is None:
        for message in (
            "[WARN] 模型未在项目目录中找到",
            "[WARN] 对于 Windows 中文用户名用户,请先运行:",
            "[WARN] python download_models.py",
            "[INFO] 回退到默认 PaddleOCR 模型路径...",
        ):
            print(message)

    app = OCRApplication(config, express_mode=args.express)
    if app.initialize():
        app.run()
        return 0
    return 1
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
|