commit 7570e1314dce2b5d9a42369d92efa96b3d45ff57 Author: Harden <1915702192@qq.com> Date: Wed Jan 7 17:43:44 2026 +0800 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fee36c5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +venv/ +ENV/ +env/ +.venv/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Model archives (keep extracted folders only) +models/*.tar + +# Logs +*.log +log.txt + +# OCR output +ocr_*.json +*_result.json + +# OS +.DS_Store +Thumbs.db diff --git a/README.md b/README.md new file mode 100644 index 0000000..e11e492 --- /dev/null +++ b/README.md @@ -0,0 +1,605 @@ +# Vision-OCR: 图片 OCR 识别系统 + +基于 PaddleOCR 的图片 OCR 识别系统,支持单张图片、批量图片和目录扫描,提供文本检测、识别、方向分类,输出结构化识别结果。 + +## 功能特性 + +- **多输入模式**: 单张图片、多张图片列表、目录批量扫描 +- **完整 OCR 能力**: 文本检测 + 文本识别 + 方向分类 +- **结构化输出**: 文字内容、置信度、位置信息(4 点坐标) +- **快递单解析**: 自动合并分散文本块,提取运单号、收寄件人等结构化信息 +- **可视化展示**: 在图片上绘制文本框和识别结果 +- **结果导出**: 支持 JSON 结果导出和标注图片保存 +- **ROI 裁剪**: 支持只识别图片指定区域 +- **模块化设计**: 图片加载与 OCR 逻辑完全解耦,便于扩展 +- **全本地运行**: 不依赖任何云服务 + +## 项目结构 + +``` +vision-ocr/ +├── input/ # 图片输入模块 +│ ├── __init__.py +│ └── loader.py # 图片加载器 +├── ocr/ # OCR 处理模块 +│ ├── __init__.py +│ ├── engine.py # PaddleOCR 引擎封装 +│ ├── pipeline.py # OCR 处理管道 +│ └── express_parser.py # 快递单解析器 +├── visualize/ # 可视化模块 +│ ├── __init__.py +│ └── draw.py # 结果绘制器 +├── utils/ # 工具模块 +│ ├── __init__.py +│ └── config.py # 配置管理 +├── models/ # 模型文件目录(运行 download_models.py 后生成) +├── main.py # 主入口 +├── download_models.py # 模型下载脚本 +├── requirements.txt # 依赖清单 +└── README.md +``` + +## 环境要求 + +- Python 3.9+ +- 支持的操作系统: Windows / Linux / macOS + +## 安装 + +### 1. 克隆项目 + +```bash +git clone +cd vision-ocr +``` + +### 2. 创建虚拟环境(推荐) + +```bash +python -m venv venv + +# Windows +venv\Scripts\activate + +# Linux/macOS +source venv/bin/activate +``` + +### 3. 安装依赖 + +```bash +pip install -r requirements.txt +``` + +### 4. 模型说明 + +本项目已内置 PaddleOCR 模型文件(位于 `models/` 目录),clone 后即可直接使用,无需额外下载。 + +> **备用方案**:如果模型文件缺失或需要更新,可运行 `python download_models.py` 重新下载。 + +#### 模型详情 + +本项目使用 PaddleOCR 的 PP-OCRv4 系列模型,包含 3 个模型协同工作: + +| 模型类型 | 模型名称 | 作用 | 大小 | +|---------|---------|------|------| +| **det (检测模型)** | ch_PP-OCRv4_det_infer | 定位图像中所有文本区域的位置,输出每个文本块的 4 点边界框坐标 | ~4.7MB | +| **rec (识别模型)** | ch_PP-OCRv4_rec_infer | 将检测到的文本区域图像转换为实际文字内容,输出文本和置信度 | ~10MB | +| **cls (方向分类模型)** | ch_ppocr_mobile_v2.0_cls_infer | 判断文本是正向(0度)还是倒置(180度),用于矫正倒置文本后再识别 | ~1.4MB | + +**OCR 处理流程:** + +``` +输入图像 -> [det 检测] -> 文本区域 -> [cls 分类] -> 方向矫正 -> [rec 识别] -> 文字结果 +``` + +**模型下载地址:** + +- 检测模型: https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar +- 识别模型: https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar +- 方向分类模型: https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar + +**注意:** 可通过 `--no-angle-cls` 参数禁用方向分类模型,适用于文本方向固定的场景,可略微提升处理速度。 + +### 5. GPU 加速(可选) + +如需使用 GPU 加速,请安装对应 CUDA 版本的 PaddlePaddle: + +```bash +# CUDA 11.8 +pip install paddlepaddle-gpu==2.5.2.post118 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html + +# CUDA 12.0 +pip install paddlepaddle-gpu==2.5.2.post120 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html +``` + +## 使用方法 + +### 基础用法 + +```bash +# 识别单张图片 +python main.py --image path/to/image.jpg + +# 识别目录中的所有图片 +python main.py --dir path/to/images/ + +# 识别目录中的特定格式图片 +python main.py --dir path/to/images/ --pattern "*.png" + +# 递归搜索子目录 +python main.py --dir path/to/images/ --recursive +``` + +### 高级选项 + +```bash +# 启用 GPU 加速 +python main.py --image test.jpg --gpu + +# 启用 ROI 区域裁剪(只识别画面中央 60% 区域) +python main.py --image test.jpg --roi 0.2 0.2 0.6 0.6 + +# 调整置信度阈值(过滤低置信度结果) +python main.py --image test.jpg --drop-score 0.7 + +# 切换识别语言 +python main.py --image test.jpg --lang en # 英文 +python main.py --image test.jpg --lang ch # 中文(默认) + +# 禁用方向分类(轻微提升速度) +python main.py --image test.jpg --no-angle-cls + +# 显示可视化窗口 +python main.py --image test.jpg --show-window + +# 保存标注后的图片 +python main.py --image test.jpg --save-image + +# 指定输出目录 +python main.py --dir images/ --output-dir results/ +``` + +### 完整参数列表 + +| 参数 | 简写 | 说明 | 默认值 | +|------|------|------|--------| +| `--image` | `-i` | 单张图片路径 | - | +| `--dir` | `-d` | 图片目录路径 | - | +| `--pattern` | `-p` | 文件匹配模式 | - | +| `--recursive` | `-r` | 递归搜索子目录 | False | +| `--lang` | `-l` | OCR 语言 | ch | +| `--gpu` | - | 启用 GPU 加速 | False | +| `--no-angle-cls` | - | 禁用方向分类 | False | +| `--drop-score` | - | 置信度阈值 | 0.5 | +| `--roi` | - | ROI 区域 (x y w h) | - | +| `--show-window` | - | 显示可视化窗口 | False | +| `--no-confidence` | - | 不显示置信度 | False | +| `--output-dir` | `-o` | 输出目录路径 | - | +| `--save-image` | - | 保存标注后的图片 | False | +| `--no-json` | - | 不保存 JSON 结果 | False | +| `--json-filename` | - | JSON 结果文件名 | ocr_result.json | +| `--express` | `-e` | 启用快递单解析模式 | False | + +### 运行时快捷键 + +| 按键 | 功能 | +|------|------| +| `q` | 退出程序 | +| 任意键 | 处理下一张图片 | + +### 结果导出 + +程序处理完成后会自动将所有识别结果导出到 JSON 文件: + +- 默认输出文件:`ocr_result.json` +- 可通过 `--json-filename` 参数指定文件名 +- 可通过 `--output-dir` 参数指定输出目录 + +**汇总 JSON 格式:** + +```json +{ + "total_images": 10, + "total_text_blocks": 45, + "results": [ + { + "image_index": 1, + "image_path": "path/to/image.jpg", + "timestamp": 1704355200.123, + "processing_time_ms": 45.2, + "text_count": 3, + "average_confidence": 0.92, + "roi_applied": false, + "roi_rect": null, + "text_blocks": [ + { + "text": "识别的文本", + "confidence": 0.95, + "bbox": [[100, 50], [200, 50], [200, 80], [100, 80]], + "bbox_with_offset": [[100, 50], [200, 50], [200, 80], [100, 80]], + "center": [150, 65], + "width": 100, + "height": 30 + } + ] + } + ] +} +``` + +#### JSON 字段说明 + +**顶层字段** + +| 字段 | 类型 | 说明 | +|------|------|------| +| `total_images` | int | 处理的图片总数 | +| `total_text_blocks` | int | 所有图片识别出的文本块总数 | +| `results` | array | 每张图片的识别结果数组 | + +**单张图片结果字段** + +| 字段 | 类型 | 说明 | +|------|------|------| +| `image_index` | int | 图片索引,从 1 开始递增 | +| `image_path` | string | 图片文件的完整路径 | +| `timestamp` | float | 处理完成时的 Unix 时间戳(秒) | +| `processing_time_ms` | float | OCR 处理耗时(毫秒) | +| `text_count` | int | 该图片识别出的文本块数量 | +| `average_confidence` | float | 所有文本块的平均置信度 (0.0 ~ 1.0) | +| `roi_applied` | bool | 是否应用了 ROI 区域裁剪 | +| `roi_rect` | array\|null | ROI 矩形区域 `[x, y, width, height]`,未应用时为 `null` | +| `text_blocks` | array | 识别出的文本块数组 | + +**文本块 (text_blocks) 字段** + +| 字段 | 类型 | 说明 | +|------|------|------| +| `text` | string | 识别出的文本内容 | +| `confidence` | float | 识别置信度 (0.0 ~ 1.0),越高表示识别结果越可靠 | +| `bbox` | array | 文本边界框的 4 个顶点坐标 `[[x1,y1], [x2,y2], [x3,y3], [x4,y4]]`,顺序为左上、右上、右下、左下。如果启用了 ROI,坐标相对于 ROI 区域 | +| `bbox_with_offset` | array | 带偏移的边界框坐标,已还原到原图坐标系。格式同 `bbox` | +| `center` | array | 文本块中心点坐标 `[cx, cy]` | +| `width` | float | 文本块宽度(像素) | +| `height` | float | 文本块高度(像素) | + +## 快递单解析模式 + +使用 `--express` 参数启用快递单解析模式,系统会自动: + +1. **合并分散文本块**: 基于位置信息将同一行的文本块合并 +2. **提取结构化信息**: 运单号、快递公司、收/寄件人姓名、电话、地址 + +### 使用方式 + +```bash +# 单张快递单图片 +python main.py --image express.jpg --express + +# 批量处理快递单图片 +python main.py --dir express_images/ --express --output-dir results/ +``` + +### 输出格式 + +快递单模式下的 JSON 输出格式: + +```json +{ + "total_images": 5, + "total_text_blocks": 50, + "results": [ + { + "image_index": 1, + "image_path": "express.jpg", + "processing_time_ms": 45.2, + "express_info": { + "tracking_number": "SF1234567890", + "sender": { + "name": "张三", + "phone": "13800138000", + "address": "北京市朝阳区xxx路" + }, + "receiver": { + "name": "李四", + "phone": "13900139000", + "address": "上海市浦东新区xxx路" + }, + "courier_company": "顺丰速运", + "confidence": 0.95, + "extra_fields": {}, + "raw_text": "顺丰速运\n运单号:SF1234567890\n..." + }, + "merged_text": "顺丰速运\n运单号:SF1234567890\n收件人:李四 13900139000\n..." + } + ] +} +``` + +#### 快递单模式 JSON 字段说明 + +**单张图片结果字段(快递单模式)** + +| 字段 | 类型 | 说明 | +|------|------|------| +| `image_index` | int | 图片索引,从 1 开始递增 | +| `image_path` | string | 图片文件的完整路径 | +| `processing_time_ms` | float | OCR 处理耗时(毫秒) | +| `express_info` | object | 解析出的快递单结构化信息 | +| `merged_text` | string | 基于位置信息智能合并后的完整文本,同一行的文本块会被合并 | + +**快递单信息 (express_info) 字段** + +| 字段 | 类型 | 说明 | +|------|------|------| +| `tracking_number` | string\|null | 运单号/快递单号 | +| `sender` | object | 寄件人信息 | +| `sender.name` | string\|null | 寄件人姓名 | +| `sender.phone` | string\|null | 寄件人电话(11位手机号) | +| `sender.address` | string\|null | 寄件人地址 | +| `receiver` | object | 收件人信息 | +| `receiver.name` | string\|null | 收件人姓名 | +| `receiver.phone` | string\|null | 收件人电话(11位手机号) | +| `receiver.address` | string\|null | 收件人地址 | +| `courier_company` | string\|null | 快递公司名称(如:顺丰速运、圆通速递等) | +| `confidence` | float | 所有文本块的平均置信度 (0.0 ~ 1.0) | +| `extra_fields` | object | 其他识别到的额外字段(键值对形式) | +| `raw_text` | string | 原始合并文本,用于调试和验证 | + +### 支持的快递公司 + +顺丰、圆通、中通、韵达、申通、极兔、京东、邮政、EMS、百世、德邦、天天、宅急送 + +### 编程接口 + +```python +from ocr import OCRPipeline, ExpressParser + +# 使用 OCRResult 的 parse_express() 方法 +result = pipeline.process(image) +if result and result.text_count > 0: + # 解析快递单信息 + express_info = result.parse_express() + print(f"运单号: {express_info.tracking_number}") + print(f"收件人: {express_info.receiver_name}") + print(f"收件电话: {express_info.receiver_phone}") + + # 获取合并后的完整文本 + merged_text = result.merge_text() + print(f"合并文本: {merged_text}") +``` + +## 编程接口 + +### 作为模块使用 + +```python +from input import ImageLoader +from ocr import OCRPipeline +from visualize import OCRVisualizer +from utils import Config, InputConfig, InputMode + +# 创建配置 +config = Config.for_single_image("path/to/image.jpg") + +# 创建组件 +loader = ImageLoader() +pipeline = OCRPipeline(config.ocr, config.pipeline) +visualizer = OCRVisualizer(config.visualize) + +# 初始化 +pipeline.initialize() + +# 加载并处理图片 +image_info = loader.load("path/to/image.jpg") +if image_info: + result = pipeline.process(image_info.image, image_info.path) + + if result and result.text_count > 0: + # 获取识别结果 + for block in result.text_blocks: + print(f"文本: {block.text}") + print(f"置信度: {block.confidence}") + print(f"位置: {block.bbox}") + + # 导出为 JSON + json_data = result.to_dict() + + # 可视化 + display_image = visualizer.draw_result(image_info.image, result) + visualizer.show(display_image, wait_key=0) + +# 清理资源 +visualizer.close() +``` + +### 批量处理 + +```python +from input import ImageLoader +from ocr import OCRPipeline +from utils import Config + +# 创建配置 +config = Config.for_directory("path/to/images/", pattern="*.jpg") + +# 创建组件 +loader = ImageLoader() +pipeline = OCRPipeline(config.ocr) +pipeline.initialize() + +# 批量处理 +for image_info in loader.load_directory("path/to/images/"): + result = pipeline.process(image_info.image, image_info.path) + print(f"{image_info.filename}: 识别到 {result.text_count} 个文本块") +``` + +### OCRResult 数据结构 + +```python +{ + "image_index": 1, + "image_path": "path/to/image.jpg", + "timestamp": 1704355200.123, + "processing_time_ms": 45.6, + "text_count": 3, + "average_confidence": 0.92, + "roi_applied": False, + "roi_rect": None, + "text_blocks": [ + { + "text": "识别的文本", + "confidence": 0.95, + "bbox": [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], + "bbox_with_offset": [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], + "center": [cx, cy], + "width": 120.0, + "height": 30.0 + } + ] +} +``` + +## 模块说明 + +### input/loader.py - 图片加载模块 + +提供图片加载功能,支持单张、批量和目录加载。 + +- `ImageLoader`: 图片加载器类 +- `ImageInfo`: 图片信息数据结构 +- `load_image()`: 便捷函数,加载单张图片 +- `load_images()`: 便捷函数,批量加载图片 + +### ocr/engine.py - OCR 引擎模块 + +封装 PaddleOCR,提供简洁的 OCR 调用接口。 + +- `OCREngine`: OCR 引擎类 +- `TextBlock`: 文本块数据结构 + +### ocr/pipeline.py - OCR 处理管道 + +串联图片加载、ROI 裁剪、OCR 识别、结果封装。 + +- `OCRPipeline`: 处理管道类 +- `OCRResult`: OCR 结果数据结构 + +### visualize/draw.py - 可视化模块 + +在图像上绘制 OCR 识别结果。 + +- `OCRVisualizer`: 可视化器类 + +### utils/config.py - 配置管理模块 + +集中管理所有可配置参数。 + +- `Config`: 全局配置聚合类 +- `InputConfig`: 输入配置 +- `OCRConfig`: OCR 引擎配置 +- `PipelineConfig`: 管道配置 +- `VisualizeConfig`: 可视化配置 +- `OutputConfig`: 输出配置 +- `ROIConfig`: ROI 区域配置 + +## 扩展开发 + +### 添加图片预处理器 + +```python +import cv2 + +def denoise_preprocessor(image): + """降噪预处理""" + return cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21) + +pipeline.add_preprocessor(denoise_preprocessor) +``` + +### 添加结果后处理器 + +```python +def filter_short_text(result): + """过滤短文本""" + result.text_blocks = [ + block for block in result.text_blocks + if len(block.text) >= 3 + ] + return result + +pipeline.add_postprocessor(filter_short_text) +``` + +## 性能优化建议 + +1. **启用 ROI**: 使用 `--roi` 参数只处理感兴趣区域 +2. **使用 GPU**: 使用 `--gpu` 参数启用 GPU 加速 +3. **禁用方向分类**: 如果文本方向固定,使用 `--no-angle-cls` +4. **提高置信度阈值**: 使用 `--drop-score` 过滤低质量结果 +5. **批量处理**: 使用目录模式批量处理多张图片 + +## 常见问题 + +### Q: Windows 中文用户名导致模型加载失败? + +A: PaddlePaddle 的 C++ 推理引擎无法正确处理包含中文字符的路径。请运行以下命令将模型下载到项目目录: + +```bash +python download_models.py +``` + +程序会自动检测并使用 `models/` 目录中的模型。 + +### Q: 中文无法正常显示? + +A: OpenCV 默认字体不支持中文。可以在 `VisualizeConfig` 中配置 `font_path` 指向系统中文字体文件: + +```python +config.visualize.font_path = "C:/Windows/Fonts/simhei.ttf" # Windows +config.visualize.font_path = "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc" # Linux +``` + +### Q: OCR 速度慢? + +A: 参考上方「性能优化建议」部分。 + +### Q: 支持哪些图片格式? + +A: 支持以下格式:`.jpg`, `.jpeg`, `.png`, `.bmp`, `.tiff`, `.tif`, `.webp` + +## 贡献指南 + +欢迎提交 Issue 和 Pull Request。 + +### 开发流程 + +1. Fork 本仓库 +2. 创建功能分支: `git checkout -b feature/your-feature` +3. 提交更改: `git commit -m "Add your feature"` +4. 推送分支: `git push origin feature/your-feature` +5. 创建 Pull Request + +### 代码规范 + +- 遵循 PEP 8 代码风格 +- 所有公共类和函数需要添加文档字符串 +- 新功能需要添加相应的类型注解 +- 提交前确保代码可正常运行 + +### 目录结构规范 + +- `input/`: 仅包含图片加载相关代码 +- `ocr/`: 仅包含 OCR 处理相关代码 +- `visualize/`: 仅包含可视化相关代码 +- `utils/`: 通用工具和配置 + +## 许可证 + +MIT License + +## 致谢 + +- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) - 强大的 OCR 工具库 +- [OpenCV](https://opencv.org/) - 计算机视觉库 diff --git a/data/csb.png b/data/csb.png new file mode 100644 index 0000000..6b26b54 Binary files /dev/null and b/data/csb.png differ diff --git a/data/ems.png b/data/ems.png new file mode 100644 index 0000000..f58713a Binary files /dev/null and b/data/ems.png differ diff --git a/data/img.png b/data/img.png new file mode 100644 index 0000000..4ea115e Binary files /dev/null and b/data/img.png differ diff --git a/data/invert.png b/data/invert.png new file mode 100644 index 0000000..4c8c108 Binary files /dev/null and b/data/invert.png differ diff --git a/data/jd.png b/data/jd.png new file mode 100644 index 0000000..6d43427 Binary files /dev/null and b/data/jd.png differ diff --git a/data/sf.png b/data/sf.png new file mode 100644 index 0000000..95ac96d Binary files /dev/null and b/data/sf.png differ diff --git a/data/st.png b/data/st.png new file mode 100644 index 0000000..3d8f285 Binary files /dev/null and b/data/st.png differ diff --git a/data/test.png b/data/test.png new file mode 100644 index 0000000..ae09390 Binary files /dev/null and b/data/test.png differ diff --git a/data/zt.png b/data/zt.png new file mode 100644 index 0000000..176956a Binary files /dev/null and b/data/zt.png differ diff --git a/download_models.py b/download_models.py new file mode 100644 index 0000000..d06a5ab --- /dev/null +++ b/download_models.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +""" +模型下载脚本 +将 PaddleOCR 模型下载到项目目录,避免中文路径问题 +""" + +import os +import tarfile +import urllib.request +from pathlib import Path + +# 模型下载地址 +MODELS = { + "det": { + "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar", + "dir": "ch_PP-OCRv4_det_infer" + }, + "rec": { + "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar", + "dir": "ch_PP-OCRv4_rec_infer" + }, + "cls": { + "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar", + "dir": "ch_ppocr_mobile_v2.0_cls_infer" + } +} + +def download_and_extract(url: str, save_dir: Path, model_name: str) -> Path: + """ + 下载并解压模型 + + Args: + url: 下载地址 + save_dir: 保存目录 + model_name: 模型名称 + + Returns: + 解压后的模型目录路径 + """ + save_dir.mkdir(parents=True, exist_ok=True) + tar_path = save_dir / f"{model_name}.tar" + + # 下载 + if not tar_path.exists(): + print(f"[INFO] Downloading {model_name}...") + urllib.request.urlretrieve(url, tar_path) + print(f"[INFO] Downloaded to {tar_path}") + else: + print(f"[INFO] {model_name} already exists, skipping download") + + # 解压 + extract_dir = save_dir / model_name + if not extract_dir.exists(): + print(f"[INFO] Extracting {model_name}...") + with tarfile.open(tar_path, "r") as tar: + tar.extractall(save_dir) + print(f"[INFO] Extracted to {extract_dir}") + + return extract_dir + + +def main(): + """下载所有模型""" + project_root = Path(__file__).parent + models_dir = project_root / "models" + + print(f"[INFO] Models will be saved to: {models_dir}") + + for model_type, info in MODELS.items(): + model_dir = download_and_extract( + url=info["url"], + save_dir=models_dir, + model_name=info["dir"] + ) + print(f"[INFO] {model_type} model ready at: {model_dir}") + + print("\n[INFO] All models downloaded successfully!") + print("[INFO] You can now run: python main.py") + + +if __name__ == "__main__": + main() diff --git a/input/__init__.py b/input/__init__.py new file mode 100644 index 0000000..ece3862 --- /dev/null +++ b/input/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +""" +图片输入模块 +提供图片加载功能,支持单张图片、多张图片和目录批量加载 +""" + +from .loader import ImageLoader, ImageInfo, load_image, load_images + +__all__ = ["ImageLoader", "ImageInfo", "load_image", "load_images"] diff --git a/input/loader.py b/input/loader.py new file mode 100644 index 0000000..9aad039 --- /dev/null +++ b/input/loader.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- +""" +图片加载模块 +提供图片加载功能,支持单张图片、多张图片和目录批量加载 +""" + +import cv2 +import numpy as np +from pathlib import Path +from typing import List, Optional, Generator, Union, Tuple +from dataclasses import dataclass + + +# 支持的图片格式 +SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'} + + +@dataclass +class ImageInfo: + """ + 图片信息数据结构 + + Attributes: + path: 图片文件路径 + image: 图片数据 (numpy array, BGR 格式) + width: 图片宽度 + height: 图片高度 + filename: 文件名(不含路径) + """ + path: str + image: np.ndarray + width: int + height: int + filename: str + + @classmethod + def from_image(cls, path: str, image: np.ndarray) -> "ImageInfo": + """ + 从图片数据创建 ImageInfo + + Args: + path: 图片路径 + image: 图片数据 + + Returns: + ImageInfo 实例 + """ + h, w = image.shape[:2] + return cls( + path=path, + image=image, + width=w, + height=h, + filename=Path(path).name + ) + + +class ImageLoader: + """ + 图片加载器 + 支持单张图片、多张图片列表和目录批量加载 + """ + + def __init__(self, supported_extensions: Optional[set] = None): + """ + 初始化图片加载器 + + Args: + supported_extensions: 支持的图片扩展名集合,默认使用 SUPPORTED_EXTENSIONS + """ + self._extensions = supported_extensions or SUPPORTED_EXTENSIONS + + def load(self, path: Union[str, Path]) -> Optional[ImageInfo]: + """ + 加载单张图片 + + Args: + path: 图片文件路径 + + Returns: + ImageInfo 对象,加载失败返回 None + """ + path = Path(path) + + if not path.exists(): + print(f"[ERROR] 文件不存在: {path}") + return None + + if not path.is_file(): + print(f"[ERROR] 路径不是文件: {path}") + return None + + if path.suffix.lower() not in self._extensions: + print(f"[ERROR] 不支持的图片格式: {path.suffix}") + return None + + # 使用 cv2.imdecode 处理中文路径 + image = self._read_image(str(path)) + + if image is None: + print(f"[ERROR] 无法读取图片: {path}") + return None + + return ImageInfo.from_image(str(path), image) + + def load_batch(self, paths: List[Union[str, Path]]) -> Generator[ImageInfo, None, None]: + """ + 批量加载多张图片 + + Args: + paths: 图片路径列表 + + Yields: + ImageInfo 对象 + """ + for path in paths: + info = self.load(path) + if info is not None: + yield info + + def load_directory( + self, + dir_path: Union[str, Path], + pattern: Optional[str] = None, + recursive: bool = False + ) -> Generator[ImageInfo, None, None]: + """ + 加载目录中的所有图片 + + Args: + dir_path: 目录路径 + pattern: 文件匹配模式(如 "*.jpg"),None 表示加载所有支持的格式 + recursive: 是否递归搜索子目录 + + Yields: + ImageInfo 对象 + """ + dir_path = Path(dir_path) + + if not dir_path.exists(): + print(f"[ERROR] 目录不存在: {dir_path}") + return + + if not dir_path.is_dir(): + print(f"[ERROR] 路径不是目录: {dir_path}") + return + + # 获取文件列表 + if pattern: + if recursive: + files = list(dir_path.rglob(pattern)) + else: + files = list(dir_path.glob(pattern)) + else: + # 加载所有支持的格式 + files = [] + for ext in self._extensions: + if recursive: + files.extend(dir_path.rglob(f"*{ext}")) + files.extend(dir_path.rglob(f"*{ext.upper()}")) + else: + files.extend(dir_path.glob(f"*{ext}")) + files.extend(dir_path.glob(f"*{ext.upper()}")) + + # 按文件名排序 + files = sorted(set(files), key=lambda p: p.name) + + print(f"[INFO] 在目录 {dir_path} 中找到 {len(files)} 张图片") + + for file_path in files: + info = self.load(file_path) + if info is not None: + yield info + + def get_image_paths( + self, + dir_path: Union[str, Path], + pattern: Optional[str] = None, + recursive: bool = False + ) -> List[str]: + """ + 获取目录中所有图片的路径列表(不加载图片) + + Args: + dir_path: 目录路径 + pattern: 文件匹配模式 + recursive: 是否递归搜索 + + Returns: + 图片路径列表 + """ + dir_path = Path(dir_path) + + if not dir_path.exists() or not dir_path.is_dir(): + return [] + + if pattern: + if recursive: + files = list(dir_path.rglob(pattern)) + else: + files = list(dir_path.glob(pattern)) + else: + files = [] + for ext in self._extensions: + if recursive: + files.extend(dir_path.rglob(f"*{ext}")) + files.extend(dir_path.rglob(f"*{ext.upper()}")) + else: + files.extend(dir_path.glob(f"*{ext}")) + files.extend(dir_path.glob(f"*{ext.upper()}")) + + return sorted([str(f) for f in set(files)], key=lambda p: Path(p).name) + + def _read_image(self, path: str) -> Optional[np.ndarray]: + """ + 读取图片,支持中文路径 + + Args: + path: 图片路径 + + Returns: + 图片数据,读取失败返回 None + """ + # 使用 numpy 和 imdecode 处理中文路径 + try: + with open(path, 'rb') as f: + data = np.frombuffer(f.read(), dtype=np.uint8) + image = cv2.imdecode(data, cv2.IMREAD_COLOR) + return image + except Exception as e: + print(f"[ERROR] 读取图片失败: {path}, 错误: {e}") + return None + + @property + def supported_extensions(self) -> set: + """获取支持的图片扩展名""" + return self._extensions.copy() + + +def load_image(path: Union[str, Path]) -> Optional[np.ndarray]: + """ + 便捷函数:加载单张图片 + + Args: + path: 图片路径 + + Returns: + 图片数据 (numpy array, BGR 格式),加载失败返回 None + """ + loader = ImageLoader() + info = loader.load(path) + return info.image if info else None + + +def load_images( + paths: Optional[List[Union[str, Path]]] = None, + directory: Optional[Union[str, Path]] = None, + pattern: Optional[str] = None, + recursive: bool = False +) -> Generator[ImageInfo, None, None]: + """ + 便捷函数:批量加载图片 + + Args: + paths: 图片路径列表 + directory: 图片目录 + pattern: 文件匹配模式 + recursive: 是否递归搜索 + + Yields: + ImageInfo 对象 + """ + loader = ImageLoader() + + if paths: + yield from loader.load_batch(paths) + elif directory: + yield from loader.load_directory(directory, pattern, recursive) diff --git a/main.py b/main.py new file mode 100644 index 0000000..e2d3153 --- /dev/null +++ b/main.py @@ -0,0 +1,551 @@ +# -*- coding: utf-8 -*- +""" +OCR 图片识别系统 - 主入口 +支持单张图片、多张图片和目录批量处理 +""" + +import os +from pathlib import Path + +# 在所有其他导入之前设置 PaddleOCR 模型路径 +# 解决 Windows 中文用户名路径问题 +_PROJECT_ROOT = Path(__file__).parent +_MODELS_DIR = _PROJECT_ROOT / "models" +_MODELS_DIR.mkdir(exist_ok=True) +os.environ["PADDLEOCR_HOME"] = str(_MODELS_DIR) + +import argparse +import json +import sys +import cv2 +from typing import Optional, List, Generator + +from input.loader import ImageLoader, ImageInfo +from ocr.pipeline import OCRPipeline, OCRResult +from visualize.draw import OCRVisualizer +from utils.config import ( + Config, + InputConfig, + InputMode, + OCRConfig, + PipelineConfig, + VisualizeConfig, + OutputConfig, + ROIConfig +) + + +class OCRApplication: + """ + OCR 应用主类 + 协调各模块完成图片 OCR 识别 + """ + + def __init__( + self, + config: Config, + express_mode: bool = False + ): + """ + 初始化应用 + + Args: + config: 全局配置 + express_mode: 是否启用快递单解析模式 + """ + self._config = config + self._loader: Optional[ImageLoader] = None + self._pipeline: Optional[OCRPipeline] = None + self._visualizer: Optional[OCRVisualizer] = None + self._express_mode = express_mode + self._all_results: List[dict] = [] + + def initialize(self) -> bool: + """ + 初始化所有组件 + + Returns: + 是否初始化成功 + """ + print("[INFO] 正在初始化 OCR 系统...") + + # 创建图片加载器 + self._loader = ImageLoader() + + # 创建 OCR 管道 + self._pipeline = OCRPipeline( + ocr_config=self._config.ocr, + pipeline_config=self._config.pipeline + ) + + # 创建可视化器 + self._visualizer = OCRVisualizer(self._config.visualize) + + # 初始化 OCR 管道(预加载模型) + print("[INFO] 正在加载 OCR 模型...") + self._pipeline.initialize() + print("[INFO] OCR 模型加载完成") + + return True + + def _get_images(self) -> Generator[ImageInfo, None, None]: + """ + 根据配置获取图片 + + Yields: + ImageInfo 对象 + """ + input_config = self._config.input + if input_config is None: + return + + if input_config.mode == InputMode.SINGLE: + info = self._loader.load(input_config.image_path) + if info: + yield info + + elif input_config.mode == InputMode.BATCH: + yield from self._loader.load_batch(input_config.image_paths) + + elif input_config.mode == InputMode.DIRECTORY: + yield from self._loader.load_directory( + input_config.directory, + input_config.pattern, + input_config.recursive + ) + + def run(self) -> None: + """运行图片处理""" + if self._loader is None or self._pipeline is None or self._visualizer is None: + print("[ERROR] 系统未初始化") + return + + self._all_results = [] + print("[INFO] 开始 OCR 识别...") + + try: + for image_info in self._get_images(): + print(f"\n[INFO] 处理图片: {image_info.filename}") + + # OCR 处理 + result = self._pipeline.process(image_info.image, image_info.path) + + # 收集结果 + if result: + if self._express_mode: + # 快递单模式:解析并收集结构化结果 + express_info = result.parse_express() + self._all_results.append({ + "image_index": result.image_index, + "image_path": result.image_path, + "processing_time_ms": result.processing_time_ms, + "express_info": express_info.to_dict(), + "merged_text": result.merge_text() + }) + else: + self._all_results.append(result.to_dict()) + + # 打印结果 + if result and result.text_count > 0: + if self._express_mode: + self._print_express_result(result) + else: + self._print_result(result) + + # 可视化并保存 + if result: + self._handle_visualization(image_info, result) + + except KeyboardInterrupt: + print("\n[INFO] 收到中断信号,正在退出...") + + finally: + # 导出汇总结果 + self._export_summary() + self.cleanup() + + def _print_result(self, result: OCRResult) -> None: + """ + 打印 OCR 结果到控制台 + + Args: + result: OCR 结果 + """ + print(f" 识别到 {result.text_count} 个文本块 (耗时: {result.processing_time_ms:.1f}ms)") + for i, block in enumerate(result.text_blocks): + print(f" [{i+1}] {block.text} (置信度: {block.confidence:.3f})") + + def _print_express_result(self, result: OCRResult) -> None: + """ + 打印快递单解析结果到控制台 + + Args: + result: OCR 结果 + """ + express_info = result.parse_express() + print(f" 快递单解析结果 (耗时: {result.processing_time_ms:.1f}ms)") + + if express_info.courier_company: + print(f" 快递公司: {express_info.courier_company}") + if express_info.tracking_number: + print(f" 运单号: {express_info.tracking_number}") + if express_info.receiver_name: + print(f" 收件人: {express_info.receiver_name}") + if express_info.receiver_phone: + print(f" 收件电话: {express_info.receiver_phone}") + if express_info.receiver_address: + print(f" 收件地址: {express_info.receiver_address}") + if express_info.sender_name: + print(f" 寄件人: {express_info.sender_name}") + if express_info.sender_phone: + print(f" 寄件电话: {express_info.sender_phone}") + if express_info.sender_address: + print(f" 寄件地址: {express_info.sender_address}") + + if not express_info.is_valid: + print(" [未识别到有效快递单信息]") + print(f" 合并文本: {result.merge_text()}") + + def _handle_visualization(self, image_info: ImageInfo, result: OCRResult) -> None: + """ + 处理可视化和图片保存 + + Args: + image_info: 图片信息 + result: OCR 结果 + """ + # 绘制结果 + display_image = self._visualizer.draw_result(image_info.image, result) + + # 显示窗口 + if self._config.visualize.show_window: + key = self._visualizer.show(display_image, wait_key=0) + if key == ord('q') or key == ord('Q'): + raise KeyboardInterrupt() + + # 保存标注后的图片 + if self._config.output.save_image: + self._save_annotated_image(image_info, display_image) + + def _save_annotated_image(self, image_info: ImageInfo, annotated_image) -> None: + """ + 保存标注后的图片 + + Args: + image_info: 原始图片信息 + annotated_image: 标注后的图片 + """ + output_config = self._config.output + + # 确定输出目录 + if output_config.output_dir: + output_dir = Path(output_config.output_dir) + else: + output_dir = Path(image_info.path).parent + + output_dir.mkdir(parents=True, exist_ok=True) + + # 生成输出文件名 + original_path = Path(image_info.path) + output_filename = f"{original_path.stem}{output_config.image_suffix}{original_path.suffix}" + output_path = output_dir / output_filename + + # 保存图片(支持中文路径) + _, ext = os.path.splitext(str(output_path)) + success, encoded = cv2.imencode(ext, annotated_image) + if success: + with open(output_path, 'wb') as f: + f.write(encoded.tobytes()) + print(f" [INFO] 标注图片已保存: {output_path}") + + def _export_summary(self) -> None: + """ + 导出所有识别结果到汇总 JSON 文件 + """ + if not self._all_results: + print("[INFO] 没有识别结果需要导出") + return + + output_config = self._config.output + if not output_config.save_json: + return + + # 确定输出路径 + if output_config.output_dir: + output_dir = Path(output_config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / output_config.json_filename + else: + output_path = Path(output_config.json_filename) + + # 构建汇总数据 + summary = { + "total_images": len(self._all_results), + "total_text_blocks": sum(r.get("text_count", 0) for r in self._all_results), + "results": self._all_results + } + + # 写入文件 + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(summary, f, ensure_ascii=False, indent=2) + + print(f"[INFO] 汇总结果已导出到 {output_path},共 {summary['total_images']} 张图片") + + def cleanup(self) -> None: + """清理资源""" + if self._visualizer: + self._visualizer.close() + print("[INFO] 资源已释放") + + +def parse_args() -> argparse.Namespace: + """解析命令行参数""" + parser = argparse.ArgumentParser( + description="OCR 图片识别系统", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例: + # 识别单张图片 + python main.py --image path/to/image.jpg + + # 识别目录中的所有图片 + python main.py --dir path/to/images/ + + # 识别目录中的特定格式图片 + python main.py --dir path/to/images/ --pattern "*.png" + + # 递归搜索子目录 + python main.py --dir path/to/images/ --recursive + + # 启用快递单解析模式 + python main.py --image express.jpg --express + + # 保存标注后的图片 + python main.py --image test.jpg --save-image + + # 指定输出目录 + python main.py --dir images/ --output-dir results/ + + # 启用 ROI 裁剪(画面中央 60% 区域) + python main.py --image test.jpg --roi 0.2 0.2 0.6 0.6 + + # 使用 GPU 加速 + python main.py --image test.jpg --gpu + """ + ) + + # 输入源(互斥) + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( + "--image", "-i", + type=str, + help="单张图片路径" + ) + input_group.add_argument( + "--dir", "-d", + type=str, + help="图片目录路径" + ) + + # 目录模式选项 + parser.add_argument( + "--pattern", "-p", + type=str, + default=None, + help="文件匹配模式(如 '*.jpg')" + ) + parser.add_argument( + "--recursive", "-r", + action="store_true", + help="递归搜索子目录" + ) + + # OCR 配置 + parser.add_argument( + "--lang", "-l", + type=str, + default="ch", + help="OCR 语言(默认: ch)" + ) + parser.add_argument( + "--gpu", + action="store_true", + help="启用 GPU 加速" + ) + parser.add_argument( + "--no-angle-cls", + action="store_true", + help="禁用方向分类" + ) + parser.add_argument( + "--drop-score", + type=float, + default=0.5, + help="置信度阈值(默认: 0.5)" + ) + + # ROI 配置 + parser.add_argument( + "--roi", + type=float, + nargs=4, + metavar=("X", "Y", "W", "H"), + help="ROI 区域(归一化坐标: x y width height)" + ) + + # 可视化配置 + parser.add_argument( + "--show-window", + action="store_true", + help="显示可视化窗口(默认不显示)" + ) + parser.add_argument( + "--no-confidence", + action="store_true", + help="不显示置信度" + ) + + # 输出配置 + parser.add_argument( + "--output-dir", "-o", + type=str, + default=None, + help="输出目录路径" + ) + parser.add_argument( + "--save-image", + action="store_true", + help="保存标注后的图片" + ) + parser.add_argument( + "--no-json", + action="store_true", + help="不保存 JSON 结果" + ) + parser.add_argument( + "--json-filename", + type=str, + default="ocr_result.json", + help="JSON 结果文件名(默认: ocr_result.json)" + ) + + # 快递单解析模式 + parser.add_argument( + "--express", "-e", + action="store_true", + help="启用快递单解析模式,自动合并文本并提取结构化信息" + ) + + return parser.parse_args() + + +def build_config(args: argparse.Namespace) -> Config: + """ + 根据命令行参数构建配置 + + Args: + args: 命令行参数 + + Returns: + 配置对象 + """ + # 输入配置 + if args.image: + input_config = InputConfig( + mode=InputMode.SINGLE, + image_path=args.image + ) + else: + input_config = InputConfig( + mode=InputMode.DIRECTORY, + directory=args.dir, + pattern=args.pattern, + recursive=args.recursive + ) + + # OCR 配置 + # 设置模型目录(解决 Windows 中文用户名路径问题) + det_model_dir = str(_MODELS_DIR / "ch_PP-OCRv4_det_infer") + rec_model_dir = str(_MODELS_DIR / "ch_PP-OCRv4_rec_infer") + cls_model_dir = str(_MODELS_DIR / "ch_ppocr_mobile_v2.0_cls_infer") + + # 检查模型是否已下载 + models_exist = ( + Path(det_model_dir).exists() and + Path(rec_model_dir).exists() and + Path(cls_model_dir).exists() + ) + + ocr_config = OCRConfig( + lang=args.lang, + use_angle_cls=not args.no_angle_cls, + use_gpu=args.gpu, + drop_score=args.drop_score, + det_model_dir=det_model_dir if models_exist else None, + rec_model_dir=rec_model_dir if models_exist else None, + cls_model_dir=cls_model_dir if models_exist else None + ) + + # ROI 配置 + roi_config = ROIConfig(enabled=False) + if args.roi: + roi_config = ROIConfig( + enabled=True, + x_ratio=args.roi[0], + y_ratio=args.roi[1], + width_ratio=args.roi[2], + height_ratio=args.roi[3] + ) + + # 管道配置 + pipeline_config = PipelineConfig(roi=roi_config) + + # 可视化配置 + visualize_config = VisualizeConfig( + show_window=args.show_window, + show_confidence=not args.no_confidence + ) + + # 输出配置 + output_config = OutputConfig( + output_dir=args.output_dir, + save_json=not args.no_json, + save_image=args.save_image, + json_filename=args.json_filename + ) + + return Config( + input=input_config, + ocr=ocr_config, + pipeline=pipeline_config, + visualize=visualize_config, + output=output_config + ) + + +def main() -> int: + """主函数""" + args = parse_args() + config = build_config(args) + + # 检查模型是否已下载 + if config.ocr.det_model_dir is None: + print("[WARN] 模型未在项目目录中找到") + print("[WARN] 对于 Windows 中文用户名用户,请先运行:") + print("[WARN] python download_models.py") + print("[INFO] 回退到默认 PaddleOCR 模型路径...") + + app = OCRApplication( + config, + express_mode=args.express + ) + + if not app.initialize(): + return 1 + + app.run() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/models/ch_PP-OCRv4_det_infer/inference.pdiparams b/models/ch_PP-OCRv4_det_infer/inference.pdiparams new file mode 100644 index 0000000..089594a Binary files /dev/null and b/models/ch_PP-OCRv4_det_infer/inference.pdiparams differ diff --git a/models/ch_PP-OCRv4_det_infer/inference.pdiparams.info b/models/ch_PP-OCRv4_det_infer/inference.pdiparams.info new file mode 100644 index 0000000..082c148 Binary files /dev/null and b/models/ch_PP-OCRv4_det_infer/inference.pdiparams.info differ diff --git a/models/ch_PP-OCRv4_det_infer/inference.pdmodel b/models/ch_PP-OCRv4_det_infer/inference.pdmodel new file mode 100644 index 0000000..223b861 Binary files /dev/null and b/models/ch_PP-OCRv4_det_infer/inference.pdmodel differ diff --git a/models/ch_PP-OCRv4_rec_infer/inference.pdiparams b/models/ch_PP-OCRv4_rec_infer/inference.pdiparams new file mode 100644 index 0000000..4c3d9e9 Binary files /dev/null and b/models/ch_PP-OCRv4_rec_infer/inference.pdiparams differ diff --git a/models/ch_PP-OCRv4_rec_infer/inference.pdiparams.info b/models/ch_PP-OCRv4_rec_infer/inference.pdiparams.info new file mode 100644 index 0000000..923329f Binary files /dev/null and b/models/ch_PP-OCRv4_rec_infer/inference.pdiparams.info differ diff --git a/models/ch_PP-OCRv4_rec_infer/inference.pdmodel b/models/ch_PP-OCRv4_rec_infer/inference.pdmodel new file mode 100644 index 0000000..dccddcc Binary files /dev/null and b/models/ch_PP-OCRv4_rec_infer/inference.pdmodel differ diff --git a/models/ch_ppocr_mobile_v2.0_cls_infer/._inference.pdmodel b/models/ch_ppocr_mobile_v2.0_cls_infer/._inference.pdmodel new file mode 100644 index 0000000..87503bf Binary files /dev/null and b/models/ch_ppocr_mobile_v2.0_cls_infer/._inference.pdmodel differ diff --git a/models/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams b/models/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams new file mode 100644 index 0000000..3449efb Binary files /dev/null and b/models/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams differ diff --git a/models/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams.info b/models/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams.info new file mode 100644 index 0000000..f31a157 Binary files /dev/null and b/models/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams.info differ diff --git a/models/ch_ppocr_mobile_v2.0_cls_infer/inference.pdmodel b/models/ch_ppocr_mobile_v2.0_cls_infer/inference.pdmodel new file mode 100644 index 0000000..b90c155 Binary files /dev/null and b/models/ch_ppocr_mobile_v2.0_cls_infer/inference.pdmodel differ diff --git a/ocr/__init__.py b/ocr/__init__.py new file mode 100644 index 0000000..2b1fb32 --- /dev/null +++ b/ocr/__init__.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +""" +OCR 模块 +提供 OCR 引擎和处理管道 +""" + +from .engine import OCREngine, TextBlock +from .pipeline import OCRPipeline, OCRResult +from .express_parser import ExpressParser, ExpressInfo, TextLine + +__all__ = [ + "OCREngine", + "TextBlock", + "OCRPipeline", + "OCRResult", + "ExpressParser", + "ExpressInfo", + "TextLine" +] diff --git a/ocr/engine.py b/ocr/engine.py new file mode 100644 index 0000000..ad07568 --- /dev/null +++ b/ocr/engine.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +""" +OCR 引擎模块 +封装 PaddleOCR,提供统一的 OCR 接口 +""" + +import os +from pathlib import Path + +# 在导入 PaddleOCR 之前设置环境变量 +# 解决 Windows 中文用户名路径问题 +_PROJECT_ROOT = Path(__file__).parent.parent +_MODELS_DIR = _PROJECT_ROOT / "models" +_MODELS_DIR.mkdir(exist_ok=True) +os.environ["PADDLEOCR_HOME"] = str(_MODELS_DIR) + +import numpy as np +from typing import List, Optional, Any +from dataclasses import dataclass +from paddleocr import PaddleOCR + +from utils.config import OCRConfig + + +@dataclass +class TextBlock: + """ + 文本块数据结构 + 表示 OCR 识别出的单个文本区域 + + Attributes: + text: 识别出的文本内容 + confidence: 置信度 (0.0 ~ 1.0) + bbox: 边界框,4 个点的坐标 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] + bbox_offset: ROI 偏移量,用于还原到原图坐标 + """ + text: str + confidence: float + bbox: List[List[float]] + bbox_offset: tuple = (0, 0) + + @property + def bbox_with_offset(self) -> List[List[float]]: + """获取带偏移的边界框(还原到原图坐标)""" + offset_x, offset_y = self.bbox_offset + return [[p[0] + offset_x, p[1] + offset_y] for p in self.bbox] + + @property + def center(self) -> tuple: + """获取文本块中心点""" + x_coords = [p[0] for p in self.bbox] + y_coords = [p[1] for p in self.bbox] + return (sum(x_coords) / 4, sum(y_coords) / 4) + + @property + def width(self) -> float: + """获取文本块宽度""" + x_coords = [p[0] for p in self.bbox] + return max(x_coords) - min(x_coords) + + @property + def height(self) -> float: + """获取文本块高度""" + y_coords = [p[1] for p in self.bbox] + return max(y_coords) - min(y_coords) + + def to_dict(self) -> dict: + """转换为字典格式""" + return { + "text": self.text, + "confidence": self.confidence, + "bbox": self.bbox, + "bbox_with_offset": self.bbox_with_offset, + "center": self.center, + "width": self.width, + "height": self.height + } + + +class OCREngine: + """ + OCR 引擎类 + 封装 PaddleOCR,提供简洁的 OCR 调用接口 + """ + + def __init__(self, config: OCRConfig): + """ + 初始化 OCR 引擎 + + Args: + config: OCR 配置 + """ + self._config = config + self._ocr: Optional[PaddleOCR] = None + + def initialize(self) -> None: + """ + 初始化 PaddleOCR 实例 + 延迟初始化,避免在导入时加载模型 + 适配 PaddleOCR 2.x API + """ + if self._ocr is not None: + return + + # 构建参数 + params = { + "lang": self._config.lang, + "use_angle_cls": self._config.use_angle_cls, + "use_gpu": self._config.use_gpu, + "det_db_thresh": self._config.det_db_thresh, + "det_db_box_thresh": self._config.det_db_box_thresh, + "drop_score": self._config.drop_score, + "show_log": self._config.show_log + } + + # 如果指定了模型目录,则使用自定义路径(解决中文路径问题) + if self._config.det_model_dir: + params["det_model_dir"] = self._config.det_model_dir + if self._config.rec_model_dir: + params["rec_model_dir"] = self._config.rec_model_dir + if self._config.cls_model_dir: + params["cls_model_dir"] = self._config.cls_model_dir + + # PaddleOCR 2.x API + self._ocr = PaddleOCR(**params) + + def recognize( + self, + image: np.ndarray, + roi_offset: tuple = (0, 0) + ) -> List[TextBlock]: + """ + 对图像进行 OCR 识别 + + Args: + image: 输入图像 (numpy array, BGR 或灰度图) + roi_offset: ROI 偏移量 (x, y),用于还原坐标 + + Returns: + 识别结果列表 + """ + # 确保引擎已初始化 + if self._ocr is None: + self.initialize() + + # 执行 OCR (PaddleOCR 2.x API) + result = self._ocr.ocr(image, cls=self._config.use_angle_cls) + + # 解析结果 + text_blocks: List[TextBlock] = [] + + # PaddleOCR 返回格式: [[line1, line2, ...]] 或 None + if result is None or len(result) == 0: + return text_blocks + + # 遍历每一行结果 + for line in result: + if line is None: + continue + for item in line: + if item is None or len(item) < 2: + continue + + bbox = item[0] # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] + text_info = item[1] # (text, confidence) + + if len(text_info) < 2: + continue + + text = text_info[0] + confidence = float(text_info[1]) + + # 过滤低置信度结果 + if confidence < self._config.drop_score: + continue + + text_block = TextBlock( + text=text, + confidence=confidence, + bbox=bbox, + bbox_offset=roi_offset + ) + text_blocks.append(text_block) + + return text_blocks + + def recognize_batch( + self, + images: List[np.ndarray] + ) -> List[List[TextBlock]]: + """ + 批量 OCR 识别 + + Args: + images: 输入图像列表 + + Returns: + 每张图像的识别结果列表 + """ + return [self.recognize(img) for img in images] + + @property + def config(self) -> OCRConfig: + """获取当前配置""" + return self._config + + def update_config(self, **kwargs) -> None: + """ + 更新配置并重新初始化引擎 + + Args: + **kwargs: 要更新的配置项 + """ + for key, value in kwargs.items(): + if hasattr(self._config, key): + setattr(self._config, key, value) + + # 重新初始化 + self._ocr = None + self.initialize() diff --git a/ocr/express_parser.py b/ocr/express_parser.py new file mode 100644 index 0000000..2feae2c --- /dev/null +++ b/ocr/express_parser.py @@ -0,0 +1,404 @@ +# -*- coding: utf-8 -*- +""" +快递单解析模块 +将分散的 OCR 文本块合并并解析成结构化的快递单信息 +""" + +import re +from dataclasses import dataclass, field +from typing import List, Optional, Dict, Any +from .engine import TextBlock + + +@dataclass +class ExpressInfo: + """ + 快递单结构化信息 + + Attributes: + tracking_number: 运单号 + sender_name: 寄件人姓名 + sender_phone: 寄件人电话 + sender_address: 寄件人地址 + receiver_name: 收件人姓名 + receiver_phone: 收件人电话 + receiver_address: 收件人地址 + courier_company: 快递公司 + raw_text: 原始合并文本(用于调试) + confidence: 平均置信度 + extra_fields: 其他识别到的字段 + """ + tracking_number: Optional[str] = None + sender_name: Optional[str] = None + sender_phone: Optional[str] = None + sender_address: Optional[str] = None + receiver_name: Optional[str] = None + receiver_phone: Optional[str] = None + receiver_address: Optional[str] = None + courier_company: Optional[str] = None + raw_text: str = "" + confidence: float = 0.0 + extra_fields: Dict[str, str] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式""" + return { + "tracking_number": self.tracking_number, + "sender": { + "name": self.sender_name, + "phone": self.sender_phone, + "address": self.sender_address + }, + "receiver": { + "name": self.receiver_name, + "phone": self.receiver_phone, + "address": self.receiver_address + }, + "courier_company": self.courier_company, + "confidence": self.confidence, + "extra_fields": self.extra_fields, + "raw_text": self.raw_text + } + + @property + def is_valid(self) -> bool: + """检查是否包含有效的快递单信息""" + # 至少需要运单号或收件人信息 + return bool(self.tracking_number or self.receiver_name or self.receiver_phone) + + +@dataclass +class TextLine: + """ + 合并后的文本行 + + Attributes: + text: 合并后的文本 + blocks: 原始文本块列表 + y_center: 行中心 Y 坐标 + x_min: 行起始 X 坐标 + """ + text: str + blocks: List[TextBlock] + y_center: float + x_min: float + + @property + def confidence(self) -> float: + """计算平均置信度""" + if not self.blocks: + return 0.0 + return sum(b.confidence for b in self.blocks) / len(self.blocks) + + +class ExpressParser: + """ + 快递单解析器 + 将分散的文本块合并成行,并提取结构化信息 + """ + + # 快递公司关键词 + COURIER_KEYWORDS = { + "顺丰": "顺丰速运", + "SF": "顺丰速运", + "圆通": "圆通速递", + "中通": "中通快递", + "韵达": "韵达快递", + "申通": "申通快递", + "极兔": "极兔速递", + "京东": "京东物流", + "JD": "京东物流", + "邮政": "中国邮政", + "EMS": "中国邮政EMS", + "百世": "百世快递", + "德邦": "德邦快递", + "天天": "天天快递", + "宅急送": "宅急送", + } + + # 字段关键词模式 + FIELD_PATTERNS = { + "tracking_number": [ + r"运单号[::]\s*(\w+)", + r"单号[::]\s*(\w+)", + r"快递单号[::]\s*(\w+)", + r"物流单号[::]\s*(\w+)", + r"^(\d{10,20})$", # 纯数字运单号 + r"^([A-Z]{2}\d{9,13}[A-Z]{2})$", # 国际快递单号格式 + ], + "receiver_name": [ + r"收件人[::]\s*(.+?)(?:\s|电话|手机|地址|$)", + r"收货人[::]\s*(.+?)(?:\s|电话|手机|地址|$)", + r"收[::]\s*(.+?)(?:\s|电话|手机|地址|$)", + ], + "receiver_phone": [ + r"收件人.*?电话[::]\s*(\d{11})", + r"收件人.*?手机[::]\s*(\d{11})", + r"收.*?(\d{11})", + r"电话[::]\s*(\d{11})", + r"手机[::]\s*(\d{11})", + r"(? ExpressInfo: + """ + 解析文本块列表,提取快递单信息 + + Args: + text_blocks: OCR 识别的文本块列表 + + Returns: + 结构化的快递单信息 + """ + if not text_blocks: + return ExpressInfo() + + # 1. 合并文本块为行 + lines = self._merge_blocks_to_lines(text_blocks) + + # 2. 生成完整文本(用于正则匹配) + full_text = self._lines_to_text(lines) + + # 3. 提取结构化信息 + info = self._extract_info(full_text, lines) + + # 4. 计算平均置信度 + info.confidence = sum(b.confidence for b in text_blocks) / len(text_blocks) + info.raw_text = full_text + + return info + + def _merge_blocks_to_lines(self, blocks: List[TextBlock]) -> List[TextLine]: + """ + 将文本块按位置合并为行 + + 基于 Y 坐标将相近的文本块合并到同一行, + 然后按 X 坐标排序合并文本 + """ + if not blocks: + return [] + + # 按 Y 坐标排序 + sorted_blocks = sorted(blocks, key=lambda b: b.center[1]) + + lines: List[TextLine] = [] + current_line_blocks: List[TextBlock] = [sorted_blocks[0]] + current_y = sorted_blocks[0].center[1] + + for block in sorted_blocks[1:]: + block_y = block.center[1] + block_height = block.height + + # 判断是否属于同一行(Y 坐标差值小于阈值) + threshold = block_height * self._line_merge_threshold + if abs(block_y - current_y) <= threshold: + current_line_blocks.append(block) + else: + # 完成当前行,开始新行 + line = self._create_line(current_line_blocks) + lines.append(line) + current_line_blocks = [block] + current_y = block_y + + # 处理最后一行 + if current_line_blocks: + line = self._create_line(current_line_blocks) + lines.append(line) + + return lines + + def _create_line(self, blocks: List[TextBlock]) -> TextLine: + """ + 从文本块列表创建文本行 + + 按 X 坐标排序,根据间距决定是否添加空格 + """ + # 按 X 坐标排序 + sorted_blocks = sorted(blocks, key=lambda b: b.center[0]) + + # 合并文本 + text_parts = [] + prev_block = None + + for block in sorted_blocks: + if prev_block is not None: + # 计算水平间距 + prev_right = max(p[0] for p in prev_block.bbox) + curr_left = min(p[0] for p in block.bbox) + gap = curr_left - prev_right + + # 计算平均字符宽度 + avg_char_width = prev_block.width / max(len(prev_block.text), 1) + + # 如果间距较大,添加空格 + if gap > avg_char_width * self._horizontal_gap_threshold: + text_parts.append(" ") + + text_parts.append(block.text) + prev_block = block + + merged_text = "".join(text_parts) + y_center = sum(b.center[1] for b in sorted_blocks) / len(sorted_blocks) + x_min = min(min(p[0] for p in b.bbox) for b in sorted_blocks) + + return TextLine( + text=merged_text, + blocks=sorted_blocks, + y_center=y_center, + x_min=x_min + ) + + def _lines_to_text(self, lines: List[TextLine]) -> str: + """将文本行列表转换为完整文本""" + return "\n".join(line.text for line in lines) + + def _extract_info(self, full_text: str, lines: List[TextLine]) -> ExpressInfo: + """ + 从文本中提取快递单信息 + + Args: + full_text: 完整文本 + lines: 文本行列表 + + Returns: + 结构化的快递单信息 + """ + info = ExpressInfo() + + # 提取快递公司 + info.courier_company = self._extract_courier_company(full_text) + + # 提取各字段 + for field_name, patterns in self.FIELD_PATTERNS.items(): + value = self._extract_field(full_text, patterns) + if value: + setattr(info, field_name, value) + + # 尝试从上下文推断地址 + if not info.receiver_address: + info.receiver_address = self._extract_address_from_context(lines, "收") + + if not info.sender_address: + info.sender_address = self._extract_address_from_context(lines, "寄") + + return info + + def _extract_courier_company(self, text: str) -> Optional[str]: + """提取快递公司名称""" + text_upper = text.upper() + for keyword, company in self.COURIER_KEYWORDS.items(): + if keyword.upper() in text_upper: + return company + return None + + def _extract_field(self, text: str, patterns: List[str]) -> Optional[str]: + """ + 使用正则表达式列表提取字段值 + + Args: + text: 待匹配文本 + patterns: 正则表达式列表 + + Returns: + 匹配到的字段值,或 None + """ + for pattern in patterns: + match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE) + if match: + value = match.group(1).strip() + # 清理常见的干扰字符 + value = re.sub(r'[【】\[\]()()]', '', value) + if value: + return value + return None + + def _extract_address_from_context( + self, + lines: List[TextLine], + context_keyword: str + ) -> Optional[str]: + """ + 从上下文中提取地址 + + 查找包含省/市/区/县/街/路等关键词的行 + """ + address_keywords = ["省", "市", "区", "县", "镇", "村", "街", "路", "号", "栋", "楼", "室"] + + # 查找包含上下文关键词的行索引 + context_line_idx = -1 + for i, line in enumerate(lines): + if context_keyword in line.text: + context_line_idx = i + break + + # 在上下文行附近查找地址 + search_range = range( + max(0, context_line_idx), + min(len(lines), context_line_idx + 3 if context_line_idx >= 0 else len(lines)) + ) + + address_parts = [] + for i in search_range: + line_text = lines[i].text + # 检查是否包含地址关键词 + if any(kw in line_text for kw in address_keywords): + # 清理行首的标签(如 "地址:") + cleaned = re.sub(r'^[^::]*[::]\s*', '', line_text) + if cleaned and cleaned != line_text: + address_parts.append(cleaned) + elif any(kw in line_text for kw in address_keywords[:4]): # 省/市/区/县 + address_parts.append(line_text) + + if address_parts: + return "".join(address_parts) + + return None + + def merge_text_blocks(self, text_blocks: List[TextBlock]) -> str: + """ + 仅合并文本块,不进行字段提取 + + 用于获取完整的合并文本 + + Args: + text_blocks: 文本块列表 + + Returns: + 合并后的完整文本 + """ + lines = self._merge_blocks_to_lines(text_blocks) + return self._lines_to_text(lines) diff --git a/ocr/pipeline.py b/ocr/pipeline.py new file mode 100644 index 0000000..bca7031 --- /dev/null +++ b/ocr/pipeline.py @@ -0,0 +1,303 @@ +# -*- coding: utf-8 -*- +""" +OCR 处理管道模块 +提供图片 OCR 识别和结果解析的完整处理流程 +""" + +import time +import numpy as np +from typing import List, Optional, Dict, Any, Callable, TYPE_CHECKING +from dataclasses import dataclass, field + +from ocr.engine import OCREngine, TextBlock +from utils.config import PipelineConfig, OCRConfig + +if TYPE_CHECKING: + from ocr.express_parser import ExpressInfo, ExpressParser + + +@dataclass +class OCRResult: + """ + OCR 处理结果数据结构 + + Attributes: + image_index: 图片索引(批量处理时使用) + image_path: 图片路径 + timestamp: 处理时间戳 + processing_time_ms: 处理耗时(毫秒) + text_blocks: 识别出的文本块列表 + roi_applied: 是否应用了 ROI 裁剪 + roi_rect: ROI 矩形 (x, y, w, h),如果应用了 ROI + """ + image_index: int + image_path: Optional[str] + timestamp: float + processing_time_ms: float + text_blocks: List[TextBlock] + roi_applied: bool = False + roi_rect: Optional[tuple] = None + + @property + def text_count(self) -> int: + """识别出的文本数量""" + return len(self.text_blocks) + + @property + def all_texts(self) -> List[str]: + """获取所有识别出的文本""" + return [block.text for block in self.text_blocks] + + @property + def full_text(self) -> str: + """获取所有文本拼接结果""" + return "\n".join(self.all_texts) + + @property + def average_confidence(self) -> float: + """获取平均置信度""" + if not self.text_blocks: + return 0.0 + return sum(b.confidence for b in self.text_blocks) / len(self.text_blocks) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式,便于 JSON 序列化""" + return { + "image_index": self.image_index, + "image_path": self.image_path, + "timestamp": self.timestamp, + "processing_time_ms": self.processing_time_ms, + "text_count": self.text_count, + "average_confidence": self.average_confidence, + "roi_applied": self.roi_applied, + "roi_rect": self.roi_rect, + "text_blocks": [block.to_dict() for block in self.text_blocks] + } + + def filter_by_confidence(self, min_confidence: float) -> "OCRResult": + """ + 按置信度过滤结果 + + Args: + min_confidence: 最小置信度阈值 + + Returns: + 过滤后的 OCRResult + """ + filtered_blocks = [ + block for block in self.text_blocks + if block.confidence >= min_confidence + ] + return OCRResult( + image_index=self.image_index, + image_path=self.image_path, + timestamp=self.timestamp, + processing_time_ms=self.processing_time_ms, + text_blocks=filtered_blocks, + roi_applied=self.roi_applied, + roi_rect=self.roi_rect + ) + + def parse_express(self) -> "ExpressInfo": + """ + 解析快递单信息 + + 将分散的文本块合并并提取结构化的快递单信息 + + Returns: + 结构化的快递单信息 + """ + from ocr.express_parser import ExpressParser + parser = ExpressParser() + return parser.parse(self.text_blocks) + + def merge_text(self) -> str: + """ + 合并文本块为完整文本 + + 基于位置信息智能合并,同一行的文本会被合并 + + Returns: + 合并后的完整文本 + """ + from ocr.express_parser import ExpressParser + parser = ExpressParser() + return parser.merge_text_blocks(self.text_blocks) + + +class OCRPipeline: + """ + OCR 处理管道 + 负责 ROI 裁剪、OCR 调用、结果封装 + """ + + def __init__( + self, + ocr_config: OCRConfig, + pipeline_config: Optional[PipelineConfig] = None + ): + """ + 初始化 OCR 管道 + + Args: + ocr_config: OCR 引擎配置 + pipeline_config: 管道配置(可选) + """ + self._ocr_config = ocr_config + self._pipeline_config = pipeline_config or PipelineConfig() + self._engine = OCREngine(ocr_config) + self._image_counter: int = 0 + + # 预留扩展点:图片预处理回调 + self._image_preprocessors: List[Callable[[np.ndarray], np.ndarray]] = [] + + # 预留扩展点:结果后处理回调 + self._result_postprocessors: List[Callable[[OCRResult], OCRResult]] = [] + + def initialize(self) -> None: + """初始化管道(预加载 OCR 模型)""" + self._engine.initialize() + + def add_preprocessor( + self, + preprocessor: Callable[[np.ndarray], np.ndarray] + ) -> None: + """ + 添加图片预处理器 + + Args: + preprocessor: 预处理函数,接收图像返回处理后的图像 + """ + self._image_preprocessors.append(preprocessor) + + def add_postprocessor( + self, + postprocessor: Callable[[OCRResult], OCRResult] + ) -> None: + """ + 添加结果后处理器 + + Args: + postprocessor: 后处理函数,接收 OCRResult 返回处理后的结果 + """ + self._result_postprocessors.append(postprocessor) + + def _apply_roi( + self, + image: np.ndarray + ) -> tuple: + """ + 应用 ROI 裁剪 + + Args: + image: 原始图片 + + Returns: + (裁剪后的图像, ROI 偏移量, ROI 矩形) + """ + roi_config = self._pipeline_config.roi + + if not roi_config.enabled: + return image, (0, 0), None + + h, w = image.shape[:2] + x, y, roi_w, roi_h = roi_config.get_roi_rect(w, h) + + # 边界检查 + x = max(0, min(x, w - 1)) + y = max(0, min(y, h - 1)) + roi_w = min(roi_w, w - x) + roi_h = min(roi_h, h - y) + + cropped = image[y:y + roi_h, x:x + roi_w] + return cropped, (x, y), (x, y, roi_w, roi_h) + + def _preprocess_image(self, image: np.ndarray) -> np.ndarray: + """ + 执行图片预处理 + + Args: + image: 原始图片 + + Returns: + 预处理后的图片 + """ + processed = image + for preprocessor in self._image_preprocessors: + processed = preprocessor(processed) + return processed + + def _postprocess_result(self, result: OCRResult) -> OCRResult: + """ + 执行结果后处理 + + Args: + result: 原始结果 + + Returns: + 后处理后的结果 + """ + processed = result + for postprocessor in self._result_postprocessors: + processed = postprocessor(processed) + return processed + + def process( + self, + image: np.ndarray, + image_path: Optional[str] = None + ) -> OCRResult: + """ + 处理单张图片 + + Args: + image: 输入图片 (numpy array, BGR 格式) + image_path: 图片路径(可选,用于结果记录) + + Returns: + OCR 结果 + """ + self._image_counter += 1 + start_time = time.time() + + # 应用 ROI 裁剪 + cropped_image, roi_offset, roi_rect = self._apply_roi(image) + + # 图片预处理 + processed_image = self._preprocess_image(cropped_image) + + # 执行 OCR + text_blocks = self._engine.recognize(processed_image, roi_offset) + + # 计算处理耗时 + processing_time_ms = (time.time() - start_time) * 1000 + + # 构建结果 + result = OCRResult( + image_index=self._image_counter, + image_path=image_path, + timestamp=time.time(), + processing_time_ms=processing_time_ms, + text_blocks=text_blocks, + roi_applied=self._pipeline_config.roi.enabled, + roi_rect=roi_rect + ) + + # 结果后处理 + result = self._postprocess_result(result) + + return result + + def reset_counter(self) -> None: + """重置图片计数器""" + self._image_counter = 0 + + @property + def image_counter(self) -> int: + """获取已处理的图片计数""" + return self._image_counter + + @property + def config(self) -> PipelineConfig: + """获取管道配置""" + return self._pipeline_config diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..cf485d4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +# Vision-OCR Dependencies + +# Core dependencies +paddlepaddle>=2.5.0,<3.0.0 +paddleocr>=2.7.0,<3.0.0 +opencv-python>=4.8.0 +numpy>=1.24.0,<2.0 + +# Optional dependencies (for Chinese font rendering) +Pillow>=10.0.0 diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..de2f941 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +""" +工具模块 +提供配置管理等通用功能 +""" + +from .config import ( + Config, + OCRConfig, + InputConfig, + VisualizeConfig, + PipelineConfig, + ROIConfig, + OutputConfig, + InputMode +) + +__all__ = [ + "Config", + "OCRConfig", + "InputConfig", + "VisualizeConfig", + "PipelineConfig", + "ROIConfig", + "OutputConfig", + "InputMode" +] diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000..9b54aaa --- /dev/null +++ b/utils/config.py @@ -0,0 +1,242 @@ +# -*- coding: utf-8 -*- +""" +配置管理模块 +集中管理所有可配置参数,便于维护和扩展 +""" + +from dataclasses import dataclass, field +from typing import Optional, Tuple, List +from enum import Enum + + +class InputMode(Enum): + """输入模式枚举""" + SINGLE = "single" # 单张图片 + BATCH = "batch" # 多张图片列表 + DIRECTORY = "directory" # 目录批量 + + +@dataclass +class InputConfig: + """ + 图片输入配置 + + Attributes: + mode: 输入模式 + image_path: 单张图片路径 + image_paths: 多张图片路径列表 + directory: 图片目录路径 + pattern: 文件匹配模式(如 "*.jpg") + recursive: 是否递归搜索子目录 + """ + mode: InputMode = InputMode.SINGLE + image_path: Optional[str] = None + image_paths: Optional[List[str]] = None + directory: Optional[str] = None + pattern: Optional[str] = None + recursive: bool = False + + def __post_init__(self): + """参数校验""" + if self.mode == InputMode.SINGLE and not self.image_path: + raise ValueError("单张图片模式下必须指定 image_path") + if self.mode == InputMode.BATCH and not self.image_paths: + raise ValueError("批量模式下必须指定 image_paths") + if self.mode == InputMode.DIRECTORY and not self.directory: + raise ValueError("目录模式下必须指定 directory") + + +@dataclass +class ROIConfig: + """ + 感兴趣区域(ROI)配置 + 使用归一化坐标 (0.0 ~ 1.0),便于适配不同分辨率 + + Attributes: + enabled: 是否启用 ROI 裁剪 + x_ratio: ROI 左上角 x 坐标比例 + y_ratio: ROI 左上角 y 坐标比例 + width_ratio: ROI 宽度比例 + height_ratio: ROI 高度比例 + """ + enabled: bool = False + x_ratio: float = 0.1 + y_ratio: float = 0.1 + width_ratio: float = 0.8 + height_ratio: float = 0.8 + + def get_roi_rect(self, frame_width: int, frame_height: int) -> Tuple[int, int, int, int]: + """ + 根据帧尺寸计算实际 ROI 矩形 + + Args: + frame_width: 帧宽度 + frame_height: 帧高度 + + Returns: + (x, y, width, height) 像素坐标 + """ + x = int(frame_width * self.x_ratio) + y = int(frame_height * self.y_ratio) + width = int(frame_width * self.width_ratio) + height = int(frame_height * self.height_ratio) + return x, y, width, height + + +@dataclass +class OCRConfig: + """ + OCR 引擎配置 (适配 PaddleOCR 2.x API) + + Attributes: + lang: 识别语言,支持 "ch"(中文), "en"(英文) 等 + use_angle_cls: 是否启用方向分类器 + use_gpu: 是否使用 GPU 加速 + det_db_thresh: 文本检测阈值 + det_db_box_thresh: 检测框阈值 + drop_score: 低于此置信度的结果将被过滤 + show_log: 是否显示 PaddleOCR 日志 + det_model_dir: 检测模型目录路径 + rec_model_dir: 识别模型目录路径 + cls_model_dir: 分类模型目录路径 + """ + lang: str = "ch" + use_angle_cls: bool = True + use_gpu: bool = False + det_db_thresh: float = 0.3 + det_db_box_thresh: float = 0.5 + drop_score: float = 0.5 + show_log: bool = False + det_model_dir: Optional[str] = None + rec_model_dir: Optional[str] = None + cls_model_dir: Optional[str] = None + + +@dataclass +class PipelineConfig: + """ + OCR 处理管道配置 + + Attributes: + roi: ROI 配置 + """ + roi: ROIConfig = field(default_factory=ROIConfig) + + +@dataclass +class VisualizeConfig: + """ + 可视化配置 + + Attributes: + show_window: 是否显示可视化窗口 + window_name: 窗口名称 + box_color: 文本框颜色 (B, G, R) + box_thickness: 文本框线宽 + text_color: 文字颜色 (B, G, R) + text_scale: 文字缩放比例 + text_thickness: 文字线宽 + show_confidence: 是否在文字旁显示置信度 + font_path: 中文字体路径,None 则使用 OpenCV 默认字体 + """ + show_window: bool = False + window_name: str = "OCR Result" + box_color: Tuple[int, int, int] = (0, 255, 0) + box_thickness: int = 2 + text_color: Tuple[int, int, int] = (0, 0, 255) + text_scale: float = 0.6 + text_thickness: int = 1 + show_confidence: bool = True + font_path: Optional[str] = None + + +@dataclass +class OutputConfig: + """ + 输出配置 + + Attributes: + output_dir: 输出目录路径 + save_json: 是否保存 JSON 结果 + save_image: 是否保存标注后的图片 + json_filename: JSON 文件名模板 + image_suffix: 标注图片后缀 + """ + output_dir: Optional[str] = None + save_json: bool = True + save_image: bool = False + json_filename: str = "ocr_result.json" + image_suffix: str = "_ocr" + + +@dataclass +class Config: + """ + 全局配置类,聚合所有配置模块 + + Attributes: + input: 输入配置 + ocr: OCR 引擎配置 + pipeline: 处理管道配置 + visualize: 可视化配置 + output: 输出配置 + """ + input: Optional[InputConfig] = None + ocr: OCRConfig = field(default_factory=OCRConfig) + pipeline: PipelineConfig = field(default_factory=PipelineConfig) + visualize: VisualizeConfig = field(default_factory=VisualizeConfig) + output: OutputConfig = field(default_factory=OutputConfig) + + @classmethod + def default(cls) -> "Config": + """创建默认配置""" + return cls() + + @classmethod + def for_single_image(cls, image_path: str) -> "Config": + """ + 创建单张图片模式的配置 + + Args: + image_path: 图片路径 + """ + config = cls() + config.input = InputConfig( + mode=InputMode.SINGLE, + image_path=image_path + ) + return config + + @classmethod + def for_directory(cls, directory: str, pattern: Optional[str] = None, recursive: bool = False) -> "Config": + """ + 创建目录批量模式的配置 + + Args: + directory: 目录路径 + pattern: 文件匹配模式 + recursive: 是否递归搜索 + """ + config = cls() + config.input = InputConfig( + mode=InputMode.DIRECTORY, + directory=directory, + pattern=pattern, + recursive=recursive + ) + return config + + @classmethod + def for_batch(cls, image_paths: List[str]) -> "Config": + """ + 创建批量图片模式的配置 + + Args: + image_paths: 图片路径列表 + """ + config = cls() + config.input = InputConfig( + mode=InputMode.BATCH, + image_paths=image_paths + ) + return config diff --git a/visualize/__init__.py b/visualize/__init__.py new file mode 100644 index 0000000..c97b8b5 --- /dev/null +++ b/visualize/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +""" +可视化模块 +提供 OCR 结果的可视化绘制功能 +""" + +from .draw import OCRVisualizer + +__all__ = ["OCRVisualizer"] diff --git a/visualize/draw.py b/visualize/draw.py new file mode 100644 index 0000000..297960a --- /dev/null +++ b/visualize/draw.py @@ -0,0 +1,368 @@ +# -*- coding: utf-8 -*- +""" +可视化模块 +在图像上绘制 OCR 识别结果 +""" + +import os +import sys +import cv2 +import numpy as np +from typing import List, Optional, Tuple + +from ocr.engine import TextBlock +from ocr.pipeline import OCRResult +from utils.config import VisualizeConfig + + +# Windows 系统常用中文字体列表(按优先级排序) +_WINDOWS_CHINESE_FONTS = [ + "msyh.ttc", # 微软雅黑 + "msyhbd.ttc", # 微软雅黑粗体 + "simhei.ttf", # 黑体 + "simsun.ttc", # 宋体 + "simkai.ttf", # 楷体 +] + +# Linux 系统常用中文字体路径 +_LINUX_CHINESE_FONTS = [ + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", + "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf", +] + + +def _find_system_chinese_font() -> Optional[str]: + """ + 自动查找系统中文字体 + + Returns: + 字体文件路径,未找到返回 None + """ + if sys.platform == "win32": + # Windows 字体目录 + fonts_dir = os.path.join(os.environ.get("WINDIR", "C:\\Windows"), "Fonts") + for font_name in _WINDOWS_CHINESE_FONTS: + font_path = os.path.join(fonts_dir, font_name) + if os.path.exists(font_path): + return font_path + else: + # Linux/macOS + for font_path in _LINUX_CHINESE_FONTS: + if os.path.exists(font_path): + return font_path + + return None + + +class OCRVisualizer: + """ + OCR 结果可视化器 + 在图像上绘制文本框和识别结果 + """ + + def __init__(self, config: VisualizeConfig): + """ + 初始化可视化器 + + Args: + config: 可视化配置 + """ + self._config = config + self._font = cv2.FONT_HERSHEY_SIMPLEX + + # 尝试加载中文字体 + self._pil_font = None + self._use_pil = False + + # 确定字体路径:优先使用配置的路径,否则自动检测系统字体 + font_path = config.font_path or _find_system_chinese_font() + + if font_path: + try: + from PIL import ImageFont + self._pil_font = ImageFont.truetype(font_path, 20) + self._use_pil = True + except Exception: + # 字体加载失败,使用 OpenCV 默认字体 + pass + + def draw_text_blocks( + self, + frame: np.ndarray, + text_blocks: List[TextBlock], + copy: bool = True + ) -> np.ndarray: + """ + 在帧上绘制文本块 + + Args: + frame: 输入帧 + text_blocks: 文本块列表 + copy: 是否复制帧(避免修改原帧) + + Returns: + 绘制后的帧 + """ + if copy: + frame = frame.copy() + + for block in text_blocks: + self._draw_single_block(frame, block) + + return frame + + def draw_result( + self, + frame: np.ndarray, + result: Optional[OCRResult], + copy: bool = True + ) -> np.ndarray: + """ + 在帧上绘制 OCR 结果 + + Args: + frame: 输入帧 + result: OCR 结果 + copy: 是否复制帧 + + Returns: + 绘制后的帧 + """ + if copy: + frame = frame.copy() + + if result is None: + return frame + + # 绘制 ROI 区域(如果启用) + if result.roi_applied and result.roi_rect: + self._draw_roi(frame, result.roi_rect) + + # 绘制所有文本块 + for block in result.text_blocks: + self._draw_single_block(frame, block) + + # 绘制状态信息 + self._draw_status(frame, result) + + return frame + + def _draw_single_block( + self, + frame: np.ndarray, + block: TextBlock + ) -> None: + """ + 绘制单个文本块 + + Args: + frame: 帧 + block: 文本块 + """ + # 获取带偏移的边界框坐标 + bbox = block.bbox_with_offset + points = np.array(bbox, dtype=np.int32) + + # 绘制多边形边框 + cv2.polylines( + frame, + [points], + isClosed=True, + color=self._config.box_color, + thickness=self._config.box_thickness + ) + + # 准备显示文本 + display_text = block.text + if self._config.show_confidence: + display_text = f"{block.text} ({block.confidence:.2f})" + + # 计算文本位置(在边界框左上角上方) + text_x = int(min(p[0] for p in bbox)) + text_y = int(min(p[1] for p in bbox)) - 5 + + # 确保文本不超出画面 + text_y = max(text_y, 20) + + # 绘制文本 + if self._use_pil and self._pil_font: + self._draw_text_pil(frame, display_text, (text_x, text_y)) + else: + self._draw_text_cv2(frame, display_text, (text_x, text_y)) + + def _draw_text_cv2( + self, + frame: np.ndarray, + text: str, + position: Tuple[int, int] + ) -> None: + """ + 使用 OpenCV 绘制文本(不支持中文,会显示方块) + + Args: + frame: 帧 + text: 文本 + position: 位置 (x, y) + """ + # 绘制文本背景(提高可读性) + (text_width, text_height), baseline = cv2.getTextSize( + text, + self._font, + self._config.text_scale, + self._config.text_thickness + ) + + x, y = position + cv2.rectangle( + frame, + (x, y - text_height - 5), + (x + text_width + 5, y + 5), + (255, 255, 255), + -1 + ) + + # 绘制文本 + cv2.putText( + frame, + text, + position, + self._font, + self._config.text_scale, + self._config.text_color, + self._config.text_thickness, + cv2.LINE_AA + ) + + def _draw_text_pil( + self, + frame: np.ndarray, + text: str, + position: Tuple[int, int] + ) -> None: + """ + 使用 PIL 绘制文本(支持中文) + + Args: + frame: 帧 + text: 文本 + position: 位置 (x, y) + """ + from PIL import Image, ImageDraw + + # OpenCV 图像转 PIL + pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_image) + + # 获取文本尺寸 + bbox = draw.textbbox(position, text, font=self._pil_font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + + x, y = position + + # 绘制背景 + draw.rectangle( + [x - 2, y - text_height - 2, x + text_width + 2, y + 2], + fill=(255, 255, 255) + ) + + # 绘制文本 + text_color_rgb = ( + self._config.text_color[2], + self._config.text_color[1], + self._config.text_color[0] + ) + draw.text(position, text, font=self._pil_font, fill=text_color_rgb) + + # PIL 图像转回 OpenCV + result = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR) + np.copyto(frame, result) + + def _draw_roi( + self, + frame: np.ndarray, + roi_rect: Tuple[int, int, int, int] + ) -> None: + """ + 绘制 ROI 区域 + + Args: + frame: 帧 + roi_rect: ROI 矩形 (x, y, width, height) + """ + x, y, w, h = roi_rect + cv2.rectangle( + frame, + (x, y), + (x + w, y + h), + (255, 255, 0), # 青色 + 2, + cv2.LINE_AA + ) + + def _draw_status( + self, + frame: np.ndarray, + result: OCRResult + ) -> None: + """ + 绘制状态信息 + + Args: + frame: 帧 + result: OCR 结果 + """ + h, w = frame.shape[:2] + + # 状态文本 + status_lines = [ + f"Image: {result.image_index}", + f"Texts: {result.text_count}", + f"Time: {result.processing_time_ms:.1f}ms" + ] + + y_offset = 25 + for line in status_lines: + cv2.putText( + frame, + line, + (10, y_offset), + self._font, + 0.5, + (0, 255, 0), + 1, + cv2.LINE_AA + ) + y_offset += 20 + + def show( + self, + frame: np.ndarray, + wait_key: int = 1 + ) -> int: + """ + 显示帧并等待按键 + + Args: + frame: 帧 + wait_key: 等待时间(毫秒) + + Returns: + 按下的键码 + """ + if not self._config.show_window: + return -1 + + cv2.imshow(self._config.window_name, frame) + return cv2.waitKey(wait_key) + + def close(self) -> None: + """关闭所有窗口""" + cv2.destroyAllWindows() + + @property + def config(self) -> VisualizeConfig: + """获取配置""" + return self._config