master
蒋尚宏 4 weeks ago
commit 7570e1314d

48
.gitignore vendored

@ -0,0 +1,48 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual Environment
venv/
ENV/
env/
.venv/
# IDE
.idea/
.vscode/
*.swp
*.swo
# Model archives (keep extracted folders only)
models/*.tar
# Logs
*.log
log.txt
# OCR output
ocr_*.json
*_result.json
# OS
.DS_Store
Thumbs.db

@ -0,0 +1,605 @@
# Vision-OCR: 图片 OCR 识别系统
基于 PaddleOCR 的图片 OCR 识别系统,支持单张图片、批量图片和目录扫描,提供文本检测、识别、方向分类,输出结构化识别结果。
## 功能特性
- **多输入模式**: 单张图片、多张图片列表、目录批量扫描
- **完整 OCR 能力**: 文本检测 + 文本识别 + 方向分类
- **结构化输出**: 文字内容、置信度、位置信息4 点坐标)
- **快递单解析**: 自动合并分散文本块,提取运单号、收寄件人等结构化信息
- **可视化展示**: 在图片上绘制文本框和识别结果
- **结果导出**: 支持 JSON 结果导出和标注图片保存
- **ROI 裁剪**: 支持只识别图片指定区域
- **模块化设计**: 图片加载与 OCR 逻辑完全解耦,便于扩展
- **全本地运行**: 不依赖任何云服务
## 项目结构
```
vision-ocr/
├── input/ # 图片输入模块
│ ├── __init__.py
│ └── loader.py # 图片加载器
├── ocr/ # OCR 处理模块
│ ├── __init__.py
│ ├── engine.py # PaddleOCR 引擎封装
│ ├── pipeline.py # OCR 处理管道
│ └── express_parser.py # 快递单解析器
├── visualize/ # 可视化模块
│ ├── __init__.py
│ └── draw.py # 结果绘制器
├── utils/ # 工具模块
│ ├── __init__.py
│ └── config.py # 配置管理
├── models/ # 模型文件目录(运行 download_models.py 后生成)
├── main.py # 主入口
├── download_models.py # 模型下载脚本
├── requirements.txt # 依赖清单
└── README.md
```
## 环境要求
- Python 3.9+
- 支持的操作系统: Windows / Linux / macOS
## 安装
### 1. 克隆项目
```bash
git clone <repository-url>
cd vision-ocr
```
### 2. 创建虚拟环境(推荐)
```bash
python -m venv venv
# Windows
venv\Scripts\activate
# Linux/macOS
source venv/bin/activate
```
### 3. 安装依赖
```bash
pip install -r requirements.txt
```
### 4. 模型说明
本项目已内置 PaddleOCR 模型文件(位于 `models/` 目录clone 后即可直接使用,无需额外下载。
> **备用方案**:如果模型文件缺失或需要更新,可运行 `python download_models.py` 重新下载。
#### 模型详情
本项目使用 PaddleOCR 的 PP-OCRv4 系列模型,包含 3 个模型协同工作:
| 模型类型 | 模型名称 | 作用 | 大小 |
|---------|---------|------|------|
| **det (检测模型)** | ch_PP-OCRv4_det_infer | 定位图像中所有文本区域的位置,输出每个文本块的 4 点边界框坐标 | ~4.7MB |
| **rec (识别模型)** | ch_PP-OCRv4_rec_infer | 将检测到的文本区域图像转换为实际文字内容,输出文本和置信度 | ~10MB |
| **cls (方向分类模型)** | ch_ppocr_mobile_v2.0_cls_infer | 判断文本是正向(0度)还是倒置(180度),用于矫正倒置文本后再识别 | ~1.4MB |
**OCR 处理流程:**
```
输入图像 -> [det 检测] -> 文本区域 -> [cls 分类] -> 方向矫正 -> [rec 识别] -> 文字结果
```
**模型下载地址:**
- 检测模型: https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar
- 识别模型: https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar
- 方向分类模型: https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
**注意:** 可通过 `--no-angle-cls` 参数禁用方向分类模型,适用于文本方向固定的场景,可略微提升处理速度。
### 5. GPU 加速(可选)
如需使用 GPU 加速,请安装对应 CUDA 版本的 PaddlePaddle
```bash
# CUDA 11.8
pip install paddlepaddle-gpu==2.5.2.post118 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
# CUDA 12.0
pip install paddlepaddle-gpu==2.5.2.post120 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```
## 使用方法
### 基础用法
```bash
# 识别单张图片
python main.py --image path/to/image.jpg
# 识别目录中的所有图片
python main.py --dir path/to/images/
# 识别目录中的特定格式图片
python main.py --dir path/to/images/ --pattern "*.png"
# 递归搜索子目录
python main.py --dir path/to/images/ --recursive
```
### 高级选项
```bash
# 启用 GPU 加速
python main.py --image test.jpg --gpu
# 启用 ROI 区域裁剪(只识别画面中央 60% 区域)
python main.py --image test.jpg --roi 0.2 0.2 0.6 0.6
# 调整置信度阈值(过滤低置信度结果)
python main.py --image test.jpg --drop-score 0.7
# 切换识别语言
python main.py --image test.jpg --lang en # 英文
python main.py --image test.jpg --lang ch # 中文(默认)
# 禁用方向分类(轻微提升速度)
python main.py --image test.jpg --no-angle-cls
# 显示可视化窗口
python main.py --image test.jpg --show-window
# 保存标注后的图片
python main.py --image test.jpg --save-image
# 指定输出目录
python main.py --dir images/ --output-dir results/
```
### 完整参数列表
| 参数 | 简写 | 说明 | 默认值 |
|------|------|------|--------|
| `--image` | `-i` | 单张图片路径 | - |
| `--dir` | `-d` | 图片目录路径 | - |
| `--pattern` | `-p` | 文件匹配模式 | - |
| `--recursive` | `-r` | 递归搜索子目录 | False |
| `--lang` | `-l` | OCR 语言 | ch |
| `--gpu` | - | 启用 GPU 加速 | False |
| `--no-angle-cls` | - | 禁用方向分类 | False |
| `--drop-score` | - | 置信度阈值 | 0.5 |
| `--roi` | - | ROI 区域 (x y w h) | - |
| `--show-window` | - | 显示可视化窗口 | False |
| `--no-confidence` | - | 不显示置信度 | False |
| `--output-dir` | `-o` | 输出目录路径 | - |
| `--save-image` | - | 保存标注后的图片 | False |
| `--no-json` | - | 不保存 JSON 结果 | False |
| `--json-filename` | - | JSON 结果文件名 | ocr_result.json |
| `--express` | `-e` | 启用快递单解析模式 | False |
### 运行时快捷键
| 按键 | 功能 |
|------|------|
| `q` | 退出程序 |
| 任意键 | 处理下一张图片 |
### 结果导出
程序处理完成后会自动将所有识别结果导出到 JSON 文件:
- 默认输出文件:`ocr_result.json`
- 可通过 `--json-filename` 参数指定文件名
- 可通过 `--output-dir` 参数指定输出目录
**汇总 JSON 格式:**
```json
{
"total_images": 10,
"total_text_blocks": 45,
"results": [
{
"image_index": 1,
"image_path": "path/to/image.jpg",
"timestamp": 1704355200.123,
"processing_time_ms": 45.2,
"text_count": 3,
"average_confidence": 0.92,
"roi_applied": false,
"roi_rect": null,
"text_blocks": [
{
"text": "识别的文本",
"confidence": 0.95,
"bbox": [[100, 50], [200, 50], [200, 80], [100, 80]],
"bbox_with_offset": [[100, 50], [200, 50], [200, 80], [100, 80]],
"center": [150, 65],
"width": 100,
"height": 30
}
]
}
]
}
```
#### JSON 字段说明
**顶层字段**
| 字段 | 类型 | 说明 |
|------|------|------|
| `total_images` | int | 处理的图片总数 |
| `total_text_blocks` | int | 所有图片识别出的文本块总数 |
| `results` | array | 每张图片的识别结果数组 |
**单张图片结果字段**
| 字段 | 类型 | 说明 |
|------|------|------|
| `image_index` | int | 图片索引,从 1 开始递增 |
| `image_path` | string | 图片文件的完整路径 |
| `timestamp` | float | 处理完成时的 Unix 时间戳(秒) |
| `processing_time_ms` | float | OCR 处理耗时(毫秒) |
| `text_count` | int | 该图片识别出的文本块数量 |
| `average_confidence` | float | 所有文本块的平均置信度 (0.0 ~ 1.0) |
| `roi_applied` | bool | 是否应用了 ROI 区域裁剪 |
| `roi_rect` | array\|null | ROI 矩形区域 `[x, y, width, height]`,未应用时为 `null` |
| `text_blocks` | array | 识别出的文本块数组 |
**文本块 (text_blocks) 字段**
| 字段 | 类型 | 说明 |
|------|------|------|
| `text` | string | 识别出的文本内容 |
| `confidence` | float | 识别置信度 (0.0 ~ 1.0),越高表示识别结果越可靠 |
| `bbox` | array | 文本边界框的 4 个顶点坐标 `[[x1,y1], [x2,y2], [x3,y3], [x4,y4]]`,顺序为左上、右上、右下、左下。如果启用了 ROI坐标相对于 ROI 区域 |
| `bbox_with_offset` | array | 带偏移的边界框坐标,已还原到原图坐标系。格式同 `bbox` |
| `center` | array | 文本块中心点坐标 `[cx, cy]` |
| `width` | float | 文本块宽度(像素) |
| `height` | float | 文本块高度(像素) |
## 快递单解析模式
使用 `--express` 参数启用快递单解析模式,系统会自动:
1. **合并分散文本块**: 基于位置信息将同一行的文本块合并
2. **提取结构化信息**: 运单号、快递公司、收/寄件人姓名、电话、地址
### 使用方式
```bash
# 单张快递单图片
python main.py --image express.jpg --express
# 批量处理快递单图片
python main.py --dir express_images/ --express --output-dir results/
```
### 输出格式
快递单模式下的 JSON 输出格式:
```json
{
"total_images": 5,
"total_text_blocks": 50,
"results": [
{
"image_index": 1,
"image_path": "express.jpg",
"processing_time_ms": 45.2,
"express_info": {
"tracking_number": "SF1234567890",
"sender": {
"name": "张三",
"phone": "13800138000",
"address": "北京市朝阳区xxx路"
},
"receiver": {
"name": "李四",
"phone": "13900139000",
"address": "上海市浦东新区xxx路"
},
"courier_company": "顺丰速运",
"confidence": 0.95,
"extra_fields": {},
"raw_text": "顺丰速运\n运单号SF1234567890\n..."
},
"merged_text": "顺丰速运\n运单号SF1234567890\n收件人李四 13900139000\n..."
}
]
}
```
#### 快递单模式 JSON 字段说明
**单张图片结果字段(快递单模式)**
| 字段 | 类型 | 说明 |
|------|------|------|
| `image_index` | int | 图片索引,从 1 开始递增 |
| `image_path` | string | 图片文件的完整路径 |
| `processing_time_ms` | float | OCR 处理耗时(毫秒) |
| `express_info` | object | 解析出的快递单结构化信息 |
| `merged_text` | string | 基于位置信息智能合并后的完整文本,同一行的文本块会被合并 |
**快递单信息 (express_info) 字段**
| 字段 | 类型 | 说明 |
|------|------|------|
| `tracking_number` | string\|null | 运单号/快递单号 |
| `sender` | object | 寄件人信息 |
| `sender.name` | string\|null | 寄件人姓名 |
| `sender.phone` | string\|null | 寄件人电话11位手机号 |
| `sender.address` | string\|null | 寄件人地址 |
| `receiver` | object | 收件人信息 |
| `receiver.name` | string\|null | 收件人姓名 |
| `receiver.phone` | string\|null | 收件人电话11位手机号 |
| `receiver.address` | string\|null | 收件人地址 |
| `courier_company` | string\|null | 快递公司名称(如:顺丰速运、圆通速递等) |
| `confidence` | float | 所有文本块的平均置信度 (0.0 ~ 1.0) |
| `extra_fields` | object | 其他识别到的额外字段(键值对形式) |
| `raw_text` | string | 原始合并文本,用于调试和验证 |
### 支持的快递公司
顺丰、圆通、中通、韵达、申通、极兔、京东、邮政、EMS、百世、德邦、天天、宅急送
### 编程接口
```python
from ocr import OCRPipeline, ExpressParser
# 使用 OCRResult 的 parse_express() 方法
result = pipeline.process(image)
if result and result.text_count > 0:
# 解析快递单信息
express_info = result.parse_express()
print(f"运单号: {express_info.tracking_number}")
print(f"收件人: {express_info.receiver_name}")
print(f"收件电话: {express_info.receiver_phone}")
# 获取合并后的完整文本
merged_text = result.merge_text()
print(f"合并文本: {merged_text}")
```
## 编程接口
### 作为模块使用
```python
from input import ImageLoader
from ocr import OCRPipeline
from visualize import OCRVisualizer
from utils import Config, InputConfig, InputMode
# 创建配置
config = Config.for_single_image("path/to/image.jpg")
# 创建组件
loader = ImageLoader()
pipeline = OCRPipeline(config.ocr, config.pipeline)
visualizer = OCRVisualizer(config.visualize)
# 初始化
pipeline.initialize()
# 加载并处理图片
image_info = loader.load("path/to/image.jpg")
if image_info:
result = pipeline.process(image_info.image, image_info.path)
if result and result.text_count > 0:
# 获取识别结果
for block in result.text_blocks:
print(f"文本: {block.text}")
print(f"置信度: {block.confidence}")
print(f"位置: {block.bbox}")
# 导出为 JSON
json_data = result.to_dict()
# 可视化
display_image = visualizer.draw_result(image_info.image, result)
visualizer.show(display_image, wait_key=0)
# 清理资源
visualizer.close()
```
### 批量处理
```python
from input import ImageLoader
from ocr import OCRPipeline
from utils import Config
# 创建配置
config = Config.for_directory("path/to/images/", pattern="*.jpg")
# 创建组件
loader = ImageLoader()
pipeline = OCRPipeline(config.ocr)
pipeline.initialize()
# 批量处理
for image_info in loader.load_directory("path/to/images/"):
result = pipeline.process(image_info.image, image_info.path)
print(f"{image_info.filename}: 识别到 {result.text_count} 个文本块")
```
### OCRResult 数据结构
```python
{
"image_index": 1,
"image_path": "path/to/image.jpg",
"timestamp": 1704355200.123,
"processing_time_ms": 45.6,
"text_count": 3,
"average_confidence": 0.92,
"roi_applied": False,
"roi_rect": None,
"text_blocks": [
{
"text": "识别的文本",
"confidence": 0.95,
"bbox": [[x1, y1], [x2, y2], [x3, y3], [x4, y4]],
"bbox_with_offset": [[x1, y1], [x2, y2], [x3, y3], [x4, y4]],
"center": [cx, cy],
"width": 120.0,
"height": 30.0
}
]
}
```
## 模块说明
### input/loader.py - 图片加载模块
提供图片加载功能,支持单张、批量和目录加载。
- `ImageLoader`: 图片加载器类
- `ImageInfo`: 图片信息数据结构
- `load_image()`: 便捷函数,加载单张图片
- `load_images()`: 便捷函数,批量加载图片
### ocr/engine.py - OCR 引擎模块
封装 PaddleOCR提供简洁的 OCR 调用接口。
- `OCREngine`: OCR 引擎类
- `TextBlock`: 文本块数据结构
### ocr/pipeline.py - OCR 处理管道
串联图片加载、ROI 裁剪、OCR 识别、结果封装。
- `OCRPipeline`: 处理管道类
- `OCRResult`: OCR 结果数据结构
### visualize/draw.py - 可视化模块
在图像上绘制 OCR 识别结果。
- `OCRVisualizer`: 可视化器类
### utils/config.py - 配置管理模块
集中管理所有可配置参数。
- `Config`: 全局配置聚合类
- `InputConfig`: 输入配置
- `OCRConfig`: OCR 引擎配置
- `PipelineConfig`: 管道配置
- `VisualizeConfig`: 可视化配置
- `OutputConfig`: 输出配置
- `ROIConfig`: ROI 区域配置
## 扩展开发
### 添加图片预处理器
```python
import cv2
def denoise_preprocessor(image):
"""降噪预处理"""
return cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
pipeline.add_preprocessor(denoise_preprocessor)
```
### 添加结果后处理器
```python
def filter_short_text(result):
"""过滤短文本"""
result.text_blocks = [
block for block in result.text_blocks
if len(block.text) >= 3
]
return result
pipeline.add_postprocessor(filter_short_text)
```
## 性能优化建议
1. **启用 ROI**: 使用 `--roi` 参数只处理感兴趣区域
2. **使用 GPU**: 使用 `--gpu` 参数启用 GPU 加速
3. **禁用方向分类**: 如果文本方向固定,使用 `--no-angle-cls`
4. **提高置信度阈值**: 使用 `--drop-score` 过滤低质量结果
5. **批量处理**: 使用目录模式批量处理多张图片
## 常见问题
### Q: Windows 中文用户名导致模型加载失败?
A: PaddlePaddle 的 C++ 推理引擎无法正确处理包含中文字符的路径。请运行以下命令将模型下载到项目目录:
```bash
python download_models.py
```
程序会自动检测并使用 `models/` 目录中的模型。
### Q: 中文无法正常显示?
A: OpenCV 默认字体不支持中文。可以在 `VisualizeConfig` 中配置 `font_path` 指向系统中文字体文件:
```python
config.visualize.font_path = "C:/Windows/Fonts/simhei.ttf" # Windows
config.visualize.font_path = "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc" # Linux
```
### Q: OCR 速度慢?
A: 参考上方「性能优化建议」部分。
### Q: 支持哪些图片格式?
A: 支持以下格式:`.jpg`, `.jpeg`, `.png`, `.bmp`, `.tiff`, `.tif`, `.webp`
## 贡献指南
欢迎提交 Issue 和 Pull Request。
### 开发流程
1. Fork 本仓库
2. 创建功能分支: `git checkout -b feature/your-feature`
3. 提交更改: `git commit -m "Add your feature"`
4. 推送分支: `git push origin feature/your-feature`
5. 创建 Pull Request
### 代码规范
- 遵循 PEP 8 代码风格
- 所有公共类和函数需要添加文档字符串
- 新功能需要添加相应的类型注解
- 提交前确保代码可正常运行
### 目录结构规范
- `input/`: 仅包含图片加载相关代码
- `ocr/`: 仅包含 OCR 处理相关代码
- `visualize/`: 仅包含可视化相关代码
- `utils/`: 通用工具和配置
## 许可证
MIT License
## 致谢
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) - 强大的 OCR 工具库
- [OpenCV](https://opencv.org/) - 计算机视觉库

Binary file not shown.

After

Width:  |  Height:  |  Size: 865 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 253 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 853 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 920 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 775 KiB

@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
"""
模型下载脚本
PaddleOCR 模型下载到项目目录避免中文路径问题
"""
import os
import tarfile
import urllib.request
from pathlib import Path
# 模型下载地址
MODELS = {
"det": {
"url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar",
"dir": "ch_PP-OCRv4_det_infer"
},
"rec": {
"url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar",
"dir": "ch_PP-OCRv4_rec_infer"
},
"cls": {
"url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar",
"dir": "ch_ppocr_mobile_v2.0_cls_infer"
}
}
def download_and_extract(url: str, save_dir: Path, model_name: str) -> Path:
"""
下载并解压模型
Args:
url: 下载地址
save_dir: 保存目录
model_name: 模型名称
Returns:
解压后的模型目录路径
"""
save_dir.mkdir(parents=True, exist_ok=True)
tar_path = save_dir / f"{model_name}.tar"
# 下载
if not tar_path.exists():
print(f"[INFO] Downloading {model_name}...")
urllib.request.urlretrieve(url, tar_path)
print(f"[INFO] Downloaded to {tar_path}")
else:
print(f"[INFO] {model_name} already exists, skipping download")
# 解压
extract_dir = save_dir / model_name
if not extract_dir.exists():
print(f"[INFO] Extracting {model_name}...")
with tarfile.open(tar_path, "r") as tar:
tar.extractall(save_dir)
print(f"[INFO] Extracted to {extract_dir}")
return extract_dir
def main():
"""下载所有模型"""
project_root = Path(__file__).parent
models_dir = project_root / "models"
print(f"[INFO] Models will be saved to: {models_dir}")
for model_type, info in MODELS.items():
model_dir = download_and_extract(
url=info["url"],
save_dir=models_dir,
model_name=info["dir"]
)
print(f"[INFO] {model_type} model ready at: {model_dir}")
print("\n[INFO] All models downloaded successfully!")
print("[INFO] You can now run: python main.py")
if __name__ == "__main__":
main()

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
"""
图片输入模块
提供图片加载功能支持单张图片多张图片和目录批量加载
"""
from .loader import ImageLoader, ImageInfo, load_image, load_images
__all__ = ["ImageLoader", "ImageInfo", "load_image", "load_images"]

@ -0,0 +1,278 @@
# -*- coding: utf-8 -*-
"""
图片加载模块
提供图片加载功能支持单张图片多张图片和目录批量加载
"""
import cv2
import numpy as np
from pathlib import Path
from typing import List, Optional, Generator, Union, Tuple
from dataclasses import dataclass
# 支持的图片格式
SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
@dataclass
class ImageInfo:
"""
图片信息数据结构
Attributes:
path: 图片文件路径
image: 图片数据 (numpy array, BGR 格式)
width: 图片宽度
height: 图片高度
filename: 文件名不含路径
"""
path: str
image: np.ndarray
width: int
height: int
filename: str
@classmethod
def from_image(cls, path: str, image: np.ndarray) -> "ImageInfo":
"""
从图片数据创建 ImageInfo
Args:
path: 图片路径
image: 图片数据
Returns:
ImageInfo 实例
"""
h, w = image.shape[:2]
return cls(
path=path,
image=image,
width=w,
height=h,
filename=Path(path).name
)
class ImageLoader:
"""
图片加载器
支持单张图片多张图片列表和目录批量加载
"""
def __init__(self, supported_extensions: Optional[set] = None):
"""
初始化图片加载器
Args:
supported_extensions: 支持的图片扩展名集合默认使用 SUPPORTED_EXTENSIONS
"""
self._extensions = supported_extensions or SUPPORTED_EXTENSIONS
def load(self, path: Union[str, Path]) -> Optional[ImageInfo]:
"""
加载单张图片
Args:
path: 图片文件路径
Returns:
ImageInfo 对象加载失败返回 None
"""
path = Path(path)
if not path.exists():
print(f"[ERROR] 文件不存在: {path}")
return None
if not path.is_file():
print(f"[ERROR] 路径不是文件: {path}")
return None
if path.suffix.lower() not in self._extensions:
print(f"[ERROR] 不支持的图片格式: {path.suffix}")
return None
# 使用 cv2.imdecode 处理中文路径
image = self._read_image(str(path))
if image is None:
print(f"[ERROR] 无法读取图片: {path}")
return None
return ImageInfo.from_image(str(path), image)
def load_batch(self, paths: List[Union[str, Path]]) -> Generator[ImageInfo, None, None]:
"""
批量加载多张图片
Args:
paths: 图片路径列表
Yields:
ImageInfo 对象
"""
for path in paths:
info = self.load(path)
if info is not None:
yield info
def load_directory(
self,
dir_path: Union[str, Path],
pattern: Optional[str] = None,
recursive: bool = False
) -> Generator[ImageInfo, None, None]:
"""
加载目录中的所有图片
Args:
dir_path: 目录路径
pattern: 文件匹配模式 "*.jpg"None 表示加载所有支持的格式
recursive: 是否递归搜索子目录
Yields:
ImageInfo 对象
"""
dir_path = Path(dir_path)
if not dir_path.exists():
print(f"[ERROR] 目录不存在: {dir_path}")
return
if not dir_path.is_dir():
print(f"[ERROR] 路径不是目录: {dir_path}")
return
# 获取文件列表
if pattern:
if recursive:
files = list(dir_path.rglob(pattern))
else:
files = list(dir_path.glob(pattern))
else:
# 加载所有支持的格式
files = []
for ext in self._extensions:
if recursive:
files.extend(dir_path.rglob(f"*{ext}"))
files.extend(dir_path.rglob(f"*{ext.upper()}"))
else:
files.extend(dir_path.glob(f"*{ext}"))
files.extend(dir_path.glob(f"*{ext.upper()}"))
# 按文件名排序
files = sorted(set(files), key=lambda p: p.name)
print(f"[INFO] 在目录 {dir_path} 中找到 {len(files)} 张图片")
for file_path in files:
info = self.load(file_path)
if info is not None:
yield info
def get_image_paths(
self,
dir_path: Union[str, Path],
pattern: Optional[str] = None,
recursive: bool = False
) -> List[str]:
"""
获取目录中所有图片的路径列表不加载图片
Args:
dir_path: 目录路径
pattern: 文件匹配模式
recursive: 是否递归搜索
Returns:
图片路径列表
"""
dir_path = Path(dir_path)
if not dir_path.exists() or not dir_path.is_dir():
return []
if pattern:
if recursive:
files = list(dir_path.rglob(pattern))
else:
files = list(dir_path.glob(pattern))
else:
files = []
for ext in self._extensions:
if recursive:
files.extend(dir_path.rglob(f"*{ext}"))
files.extend(dir_path.rglob(f"*{ext.upper()}"))
else:
files.extend(dir_path.glob(f"*{ext}"))
files.extend(dir_path.glob(f"*{ext.upper()}"))
return sorted([str(f) for f in set(files)], key=lambda p: Path(p).name)
def _read_image(self, path: str) -> Optional[np.ndarray]:
"""
读取图片支持中文路径
Args:
path: 图片路径
Returns:
图片数据读取失败返回 None
"""
# 使用 numpy 和 imdecode 处理中文路径
try:
with open(path, 'rb') as f:
data = np.frombuffer(f.read(), dtype=np.uint8)
image = cv2.imdecode(data, cv2.IMREAD_COLOR)
return image
except Exception as e:
print(f"[ERROR] 读取图片失败: {path}, 错误: {e}")
return None
@property
def supported_extensions(self) -> set:
"""获取支持的图片扩展名"""
return self._extensions.copy()
def load_image(path: Union[str, Path]) -> Optional[np.ndarray]:
"""
便捷函数加载单张图片
Args:
path: 图片路径
Returns:
图片数据 (numpy array, BGR 格式)加载失败返回 None
"""
loader = ImageLoader()
info = loader.load(path)
return info.image if info else None
def load_images(
paths: Optional[List[Union[str, Path]]] = None,
directory: Optional[Union[str, Path]] = None,
pattern: Optional[str] = None,
recursive: bool = False
) -> Generator[ImageInfo, None, None]:
"""
便捷函数批量加载图片
Args:
paths: 图片路径列表
directory: 图片目录
pattern: 文件匹配模式
recursive: 是否递归搜索
Yields:
ImageInfo 对象
"""
loader = ImageLoader()
if paths:
yield from loader.load_batch(paths)
elif directory:
yield from loader.load_directory(directory, pattern, recursive)

@ -0,0 +1,551 @@
# -*- coding: utf-8 -*-
"""
OCR 图片识别系统 - 主入口
支持单张图片多张图片和目录批量处理
"""
import os
from pathlib import Path
# 在所有其他导入之前设置 PaddleOCR 模型路径
# 解决 Windows 中文用户名路径问题
_PROJECT_ROOT = Path(__file__).parent
_MODELS_DIR = _PROJECT_ROOT / "models"
_MODELS_DIR.mkdir(exist_ok=True)
os.environ["PADDLEOCR_HOME"] = str(_MODELS_DIR)
import argparse
import json
import sys
import cv2
from typing import Optional, List, Generator
from input.loader import ImageLoader, ImageInfo
from ocr.pipeline import OCRPipeline, OCRResult
from visualize.draw import OCRVisualizer
from utils.config import (
Config,
InputConfig,
InputMode,
OCRConfig,
PipelineConfig,
VisualizeConfig,
OutputConfig,
ROIConfig
)
class OCRApplication:
"""
OCR 应用主类
协调各模块完成图片 OCR 识别
"""
def __init__(
self,
config: Config,
express_mode: bool = False
):
"""
初始化应用
Args:
config: 全局配置
express_mode: 是否启用快递单解析模式
"""
self._config = config
self._loader: Optional[ImageLoader] = None
self._pipeline: Optional[OCRPipeline] = None
self._visualizer: Optional[OCRVisualizer] = None
self._express_mode = express_mode
self._all_results: List[dict] = []
def initialize(self) -> bool:
"""
初始化所有组件
Returns:
是否初始化成功
"""
print("[INFO] 正在初始化 OCR 系统...")
# 创建图片加载器
self._loader = ImageLoader()
# 创建 OCR 管道
self._pipeline = OCRPipeline(
ocr_config=self._config.ocr,
pipeline_config=self._config.pipeline
)
# 创建可视化器
self._visualizer = OCRVisualizer(self._config.visualize)
# 初始化 OCR 管道(预加载模型)
print("[INFO] 正在加载 OCR 模型...")
self._pipeline.initialize()
print("[INFO] OCR 模型加载完成")
return True
def _get_images(self) -> Generator[ImageInfo, None, None]:
"""
根据配置获取图片
Yields:
ImageInfo 对象
"""
input_config = self._config.input
if input_config is None:
return
if input_config.mode == InputMode.SINGLE:
info = self._loader.load(input_config.image_path)
if info:
yield info
elif input_config.mode == InputMode.BATCH:
yield from self._loader.load_batch(input_config.image_paths)
elif input_config.mode == InputMode.DIRECTORY:
yield from self._loader.load_directory(
input_config.directory,
input_config.pattern,
input_config.recursive
)
def run(self) -> None:
"""运行图片处理"""
if self._loader is None or self._pipeline is None or self._visualizer is None:
print("[ERROR] 系统未初始化")
return
self._all_results = []
print("[INFO] 开始 OCR 识别...")
try:
for image_info in self._get_images():
print(f"\n[INFO] 处理图片: {image_info.filename}")
# OCR 处理
result = self._pipeline.process(image_info.image, image_info.path)
# 收集结果
if result:
if self._express_mode:
# 快递单模式:解析并收集结构化结果
express_info = result.parse_express()
self._all_results.append({
"image_index": result.image_index,
"image_path": result.image_path,
"processing_time_ms": result.processing_time_ms,
"express_info": express_info.to_dict(),
"merged_text": result.merge_text()
})
else:
self._all_results.append(result.to_dict())
# 打印结果
if result and result.text_count > 0:
if self._express_mode:
self._print_express_result(result)
else:
self._print_result(result)
# 可视化并保存
if result:
self._handle_visualization(image_info, result)
except KeyboardInterrupt:
print("\n[INFO] 收到中断信号,正在退出...")
finally:
# 导出汇总结果
self._export_summary()
self.cleanup()
def _print_result(self, result: OCRResult) -> None:
"""
打印 OCR 结果到控制台
Args:
result: OCR 结果
"""
print(f" 识别到 {result.text_count} 个文本块 (耗时: {result.processing_time_ms:.1f}ms)")
for i, block in enumerate(result.text_blocks):
print(f" [{i+1}] {block.text} (置信度: {block.confidence:.3f})")
def _print_express_result(self, result: OCRResult) -> None:
"""
打印快递单解析结果到控制台
Args:
result: OCR 结果
"""
express_info = result.parse_express()
print(f" 快递单解析结果 (耗时: {result.processing_time_ms:.1f}ms)")
if express_info.courier_company:
print(f" 快递公司: {express_info.courier_company}")
if express_info.tracking_number:
print(f" 运单号: {express_info.tracking_number}")
if express_info.receiver_name:
print(f" 收件人: {express_info.receiver_name}")
if express_info.receiver_phone:
print(f" 收件电话: {express_info.receiver_phone}")
if express_info.receiver_address:
print(f" 收件地址: {express_info.receiver_address}")
if express_info.sender_name:
print(f" 寄件人: {express_info.sender_name}")
if express_info.sender_phone:
print(f" 寄件电话: {express_info.sender_phone}")
if express_info.sender_address:
print(f" 寄件地址: {express_info.sender_address}")
if not express_info.is_valid:
print(" [未识别到有效快递单信息]")
print(f" 合并文本: {result.merge_text()}")
def _handle_visualization(self, image_info: ImageInfo, result: OCRResult) -> None:
"""
处理可视化和图片保存
Args:
image_info: 图片信息
result: OCR 结果
"""
# 绘制结果
display_image = self._visualizer.draw_result(image_info.image, result)
# 显示窗口
if self._config.visualize.show_window:
key = self._visualizer.show(display_image, wait_key=0)
if key == ord('q') or key == ord('Q'):
raise KeyboardInterrupt()
# 保存标注后的图片
if self._config.output.save_image:
self._save_annotated_image(image_info, display_image)
def _save_annotated_image(self, image_info: ImageInfo, annotated_image) -> None:
"""
保存标注后的图片
Args:
image_info: 原始图片信息
annotated_image: 标注后的图片
"""
output_config = self._config.output
# 确定输出目录
if output_config.output_dir:
output_dir = Path(output_config.output_dir)
else:
output_dir = Path(image_info.path).parent
output_dir.mkdir(parents=True, exist_ok=True)
# 生成输出文件名
original_path = Path(image_info.path)
output_filename = f"{original_path.stem}{output_config.image_suffix}{original_path.suffix}"
output_path = output_dir / output_filename
# 保存图片(支持中文路径)
_, ext = os.path.splitext(str(output_path))
success, encoded = cv2.imencode(ext, annotated_image)
if success:
with open(output_path, 'wb') as f:
f.write(encoded.tobytes())
print(f" [INFO] 标注图片已保存: {output_path}")
def _export_summary(self) -> None:
"""
导出所有识别结果到汇总 JSON 文件
"""
if not self._all_results:
print("[INFO] 没有识别结果需要导出")
return
output_config = self._config.output
if not output_config.save_json:
return
# 确定输出路径
if output_config.output_dir:
output_dir = Path(output_config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / output_config.json_filename
else:
output_path = Path(output_config.json_filename)
# 构建汇总数据
summary = {
"total_images": len(self._all_results),
"total_text_blocks": sum(r.get("text_count", 0) for r in self._all_results),
"results": self._all_results
}
# 写入文件
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
print(f"[INFO] 汇总结果已导出到 {output_path},共 {summary['total_images']} 张图片")
def cleanup(self) -> None:
"""清理资源"""
if self._visualizer:
self._visualizer.close()
print("[INFO] 资源已释放")
def parse_args() -> argparse.Namespace:
"""解析命令行参数"""
parser = argparse.ArgumentParser(
description="OCR 图片识别系统",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 识别单张图片
python main.py --image path/to/image.jpg
# 识别目录中的所有图片
python main.py --dir path/to/images/
# 识别目录中的特定格式图片
python main.py --dir path/to/images/ --pattern "*.png"
# 递归搜索子目录
python main.py --dir path/to/images/ --recursive
# 启用快递单解析模式
python main.py --image express.jpg --express
# 保存标注后的图片
python main.py --image test.jpg --save-image
# 指定输出目录
python main.py --dir images/ --output-dir results/
# 启用 ROI 裁剪(画面中央 60% 区域)
python main.py --image test.jpg --roi 0.2 0.2 0.6 0.6
# 使用 GPU 加速
python main.py --image test.jpg --gpu
"""
)
# 输入源(互斥)
input_group = parser.add_mutually_exclusive_group(required=True)
input_group.add_argument(
"--image", "-i",
type=str,
help="单张图片路径"
)
input_group.add_argument(
"--dir", "-d",
type=str,
help="图片目录路径"
)
# 目录模式选项
parser.add_argument(
"--pattern", "-p",
type=str,
default=None,
help="文件匹配模式(如 '*.jpg'"
)
parser.add_argument(
"--recursive", "-r",
action="store_true",
help="递归搜索子目录"
)
# OCR 配置
parser.add_argument(
"--lang", "-l",
type=str,
default="ch",
help="OCR 语言(默认: ch"
)
parser.add_argument(
"--gpu",
action="store_true",
help="启用 GPU 加速"
)
parser.add_argument(
"--no-angle-cls",
action="store_true",
help="禁用方向分类"
)
parser.add_argument(
"--drop-score",
type=float,
default=0.5,
help="置信度阈值(默认: 0.5"
)
# ROI 配置
parser.add_argument(
"--roi",
type=float,
nargs=4,
metavar=("X", "Y", "W", "H"),
help="ROI 区域(归一化坐标: x y width height"
)
# 可视化配置
parser.add_argument(
"--show-window",
action="store_true",
help="显示可视化窗口(默认不显示)"
)
parser.add_argument(
"--no-confidence",
action="store_true",
help="不显示置信度"
)
# 输出配置
parser.add_argument(
"--output-dir", "-o",
type=str,
default=None,
help="输出目录路径"
)
parser.add_argument(
"--save-image",
action="store_true",
help="保存标注后的图片"
)
parser.add_argument(
"--no-json",
action="store_true",
help="不保存 JSON 结果"
)
parser.add_argument(
"--json-filename",
type=str,
default="ocr_result.json",
help="JSON 结果文件名(默认: ocr_result.json"
)
# 快递单解析模式
parser.add_argument(
"--express", "-e",
action="store_true",
help="启用快递单解析模式,自动合并文本并提取结构化信息"
)
return parser.parse_args()
def build_config(args: argparse.Namespace) -> Config:
"""
根据命令行参数构建配置
Args:
args: 命令行参数
Returns:
配置对象
"""
# 输入配置
if args.image:
input_config = InputConfig(
mode=InputMode.SINGLE,
image_path=args.image
)
else:
input_config = InputConfig(
mode=InputMode.DIRECTORY,
directory=args.dir,
pattern=args.pattern,
recursive=args.recursive
)
# OCR 配置
# 设置模型目录(解决 Windows 中文用户名路径问题)
det_model_dir = str(_MODELS_DIR / "ch_PP-OCRv4_det_infer")
rec_model_dir = str(_MODELS_DIR / "ch_PP-OCRv4_rec_infer")
cls_model_dir = str(_MODELS_DIR / "ch_ppocr_mobile_v2.0_cls_infer")
# 检查模型是否已下载
models_exist = (
Path(det_model_dir).exists() and
Path(rec_model_dir).exists() and
Path(cls_model_dir).exists()
)
ocr_config = OCRConfig(
lang=args.lang,
use_angle_cls=not args.no_angle_cls,
use_gpu=args.gpu,
drop_score=args.drop_score,
det_model_dir=det_model_dir if models_exist else None,
rec_model_dir=rec_model_dir if models_exist else None,
cls_model_dir=cls_model_dir if models_exist else None
)
# ROI 配置
roi_config = ROIConfig(enabled=False)
if args.roi:
roi_config = ROIConfig(
enabled=True,
x_ratio=args.roi[0],
y_ratio=args.roi[1],
width_ratio=args.roi[2],
height_ratio=args.roi[3]
)
# 管道配置
pipeline_config = PipelineConfig(roi=roi_config)
# 可视化配置
visualize_config = VisualizeConfig(
show_window=args.show_window,
show_confidence=not args.no_confidence
)
# 输出配置
output_config = OutputConfig(
output_dir=args.output_dir,
save_json=not args.no_json,
save_image=args.save_image,
json_filename=args.json_filename
)
return Config(
input=input_config,
ocr=ocr_config,
pipeline=pipeline_config,
visualize=visualize_config,
output=output_config
)
def main() -> int:
"""主函数"""
args = parse_args()
config = build_config(args)
# 检查模型是否已下载
if config.ocr.det_model_dir is None:
print("[WARN] 模型未在项目目录中找到")
print("[WARN] 对于 Windows 中文用户名用户,请先运行:")
print("[WARN] python download_models.py")
print("[INFO] 回退到默认 PaddleOCR 模型路径...")
app = OCRApplication(
config,
express_mode=args.express
)
if not app.initialize():
return 1
app.run()
return 0
if __name__ == "__main__":
sys.exit(main())

@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-
"""
OCR 模块
提供 OCR 引擎和处理管道
"""
from .engine import OCREngine, TextBlock
from .pipeline import OCRPipeline, OCRResult
from .express_parser import ExpressParser, ExpressInfo, TextLine
__all__ = [
"OCREngine",
"TextBlock",
"OCRPipeline",
"OCRResult",
"ExpressParser",
"ExpressInfo",
"TextLine"
]

@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-
"""
OCR 引擎模块
封装 PaddleOCR提供统一的 OCR 接口
"""
import os
from pathlib import Path
# 在导入 PaddleOCR 之前设置环境变量
# 解决 Windows 中文用户名路径问题
_PROJECT_ROOT = Path(__file__).parent.parent
_MODELS_DIR = _PROJECT_ROOT / "models"
_MODELS_DIR.mkdir(exist_ok=True)
os.environ["PADDLEOCR_HOME"] = str(_MODELS_DIR)
import numpy as np
from typing import List, Optional, Any
from dataclasses import dataclass
from paddleocr import PaddleOCR
from utils.config import OCRConfig
@dataclass
class TextBlock:
"""
文本块数据结构
表示 OCR 识别出的单个文本区域
Attributes:
text: 识别出的文本内容
confidence: 置信度 (0.0 ~ 1.0)
bbox: 边界框4 个点的坐标 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
bbox_offset: ROI 偏移量用于还原到原图坐标
"""
text: str
confidence: float
bbox: List[List[float]]
bbox_offset: tuple = (0, 0)
@property
def bbox_with_offset(self) -> List[List[float]]:
"""获取带偏移的边界框(还原到原图坐标)"""
offset_x, offset_y = self.bbox_offset
return [[p[0] + offset_x, p[1] + offset_y] for p in self.bbox]
@property
def center(self) -> tuple:
"""获取文本块中心点"""
x_coords = [p[0] for p in self.bbox]
y_coords = [p[1] for p in self.bbox]
return (sum(x_coords) / 4, sum(y_coords) / 4)
@property
def width(self) -> float:
"""获取文本块宽度"""
x_coords = [p[0] for p in self.bbox]
return max(x_coords) - min(x_coords)
@property
def height(self) -> float:
"""获取文本块高度"""
y_coords = [p[1] for p in self.bbox]
return max(y_coords) - min(y_coords)
def to_dict(self) -> dict:
"""转换为字典格式"""
return {
"text": self.text,
"confidence": self.confidence,
"bbox": self.bbox,
"bbox_with_offset": self.bbox_with_offset,
"center": self.center,
"width": self.width,
"height": self.height
}
class OCREngine:
"""
OCR 引擎类
封装 PaddleOCR提供简洁的 OCR 调用接口
"""
def __init__(self, config: OCRConfig):
"""
初始化 OCR 引擎
Args:
config: OCR 配置
"""
self._config = config
self._ocr: Optional[PaddleOCR] = None
def initialize(self) -> None:
"""
初始化 PaddleOCR 实例
延迟初始化避免在导入时加载模型
适配 PaddleOCR 2.x API
"""
if self._ocr is not None:
return
# 构建参数
params = {
"lang": self._config.lang,
"use_angle_cls": self._config.use_angle_cls,
"use_gpu": self._config.use_gpu,
"det_db_thresh": self._config.det_db_thresh,
"det_db_box_thresh": self._config.det_db_box_thresh,
"drop_score": self._config.drop_score,
"show_log": self._config.show_log
}
# 如果指定了模型目录,则使用自定义路径(解决中文路径问题)
if self._config.det_model_dir:
params["det_model_dir"] = self._config.det_model_dir
if self._config.rec_model_dir:
params["rec_model_dir"] = self._config.rec_model_dir
if self._config.cls_model_dir:
params["cls_model_dir"] = self._config.cls_model_dir
# PaddleOCR 2.x API
self._ocr = PaddleOCR(**params)
def recognize(
self,
image: np.ndarray,
roi_offset: tuple = (0, 0)
) -> List[TextBlock]:
"""
对图像进行 OCR 识别
Args:
image: 输入图像 (numpy array, BGR 或灰度图)
roi_offset: ROI 偏移量 (x, y)用于还原坐标
Returns:
识别结果列表
"""
# 确保引擎已初始化
if self._ocr is None:
self.initialize()
# 执行 OCR (PaddleOCR 2.x API)
result = self._ocr.ocr(image, cls=self._config.use_angle_cls)
# 解析结果
text_blocks: List[TextBlock] = []
# PaddleOCR 返回格式: [[line1, line2, ...]] 或 None
if result is None or len(result) == 0:
return text_blocks
# 遍历每一行结果
for line in result:
if line is None:
continue
for item in line:
if item is None or len(item) < 2:
continue
bbox = item[0] # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
text_info = item[1] # (text, confidence)
if len(text_info) < 2:
continue
text = text_info[0]
confidence = float(text_info[1])
# 过滤低置信度结果
if confidence < self._config.drop_score:
continue
text_block = TextBlock(
text=text,
confidence=confidence,
bbox=bbox,
bbox_offset=roi_offset
)
text_blocks.append(text_block)
return text_blocks
def recognize_batch(
self,
images: List[np.ndarray]
) -> List[List[TextBlock]]:
"""
批量 OCR 识别
Args:
images: 输入图像列表
Returns:
每张图像的识别结果列表
"""
return [self.recognize(img) for img in images]
@property
def config(self) -> OCRConfig:
"""获取当前配置"""
return self._config
def update_config(self, **kwargs) -> None:
"""
更新配置并重新初始化引擎
Args:
**kwargs: 要更新的配置项
"""
for key, value in kwargs.items():
if hasattr(self._config, key):
setattr(self._config, key, value)
# 重新初始化
self._ocr = None
self.initialize()

@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-
"""
快递单解析模块
将分散的 OCR 文本块合并并解析成结构化的快递单信息
"""
import re
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
from .engine import TextBlock
@dataclass
class ExpressInfo:
"""
快递单结构化信息
Attributes:
tracking_number: 运单号
sender_name: 寄件人姓名
sender_phone: 寄件人电话
sender_address: 寄件人地址
receiver_name: 收件人姓名
receiver_phone: 收件人电话
receiver_address: 收件人地址
courier_company: 快递公司
raw_text: 原始合并文本用于调试
confidence: 平均置信度
extra_fields: 其他识别到的字段
"""
tracking_number: Optional[str] = None
sender_name: Optional[str] = None
sender_phone: Optional[str] = None
sender_address: Optional[str] = None
receiver_name: Optional[str] = None
receiver_phone: Optional[str] = None
receiver_address: Optional[str] = None
courier_company: Optional[str] = None
raw_text: str = ""
confidence: float = 0.0
extra_fields: Dict[str, str] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"tracking_number": self.tracking_number,
"sender": {
"name": self.sender_name,
"phone": self.sender_phone,
"address": self.sender_address
},
"receiver": {
"name": self.receiver_name,
"phone": self.receiver_phone,
"address": self.receiver_address
},
"courier_company": self.courier_company,
"confidence": self.confidence,
"extra_fields": self.extra_fields,
"raw_text": self.raw_text
}
@property
def is_valid(self) -> bool:
"""检查是否包含有效的快递单信息"""
# 至少需要运单号或收件人信息
return bool(self.tracking_number or self.receiver_name or self.receiver_phone)
@dataclass
class TextLine:
"""
合并后的文本行
Attributes:
text: 合并后的文本
blocks: 原始文本块列表
y_center: 行中心 Y 坐标
x_min: 行起始 X 坐标
"""
text: str
blocks: List[TextBlock]
y_center: float
x_min: float
@property
def confidence(self) -> float:
"""计算平均置信度"""
if not self.blocks:
return 0.0
return sum(b.confidence for b in self.blocks) / len(self.blocks)
class ExpressParser:
"""
快递单解析器
将分散的文本块合并成行并提取结构化信息
"""
# 快递公司关键词
COURIER_KEYWORDS = {
"顺丰": "顺丰速运",
"SF": "顺丰速运",
"圆通": "圆通速递",
"中通": "中通快递",
"韵达": "韵达快递",
"申通": "申通快递",
"极兔": "极兔速递",
"京东": "京东物流",
"JD": "京东物流",
"邮政": "中国邮政",
"EMS": "中国邮政EMS",
"百世": "百世快递",
"德邦": "德邦快递",
"天天": "天天快递",
"宅急送": "宅急送",
}
# 字段关键词模式
FIELD_PATTERNS = {
"tracking_number": [
r"运单号[:]\s*(\w+)",
r"单号[:]\s*(\w+)",
r"快递单号[:]\s*(\w+)",
r"物流单号[:]\s*(\w+)",
r"^(\d{10,20})$", # 纯数字运单号
r"^([A-Z]{2}\d{9,13}[A-Z]{2})$", # 国际快递单号格式
],
"receiver_name": [
r"收件人[:]\s*(.+?)(?:\s|电话|手机|地址|$)",
r"收货人[:]\s*(.+?)(?:\s|电话|手机|地址|$)",
r"收[:]\s*(.+?)(?:\s|电话|手机|地址|$)",
],
"receiver_phone": [
r"收件人.*?电话[:]\s*(\d{11})",
r"收件人.*?手机[:]\s*(\d{11})",
r"收.*?(\d{11})",
r"电话[:]\s*(\d{11})",
r"手机[:]\s*(\d{11})",
r"(?<![0-9])(\d{11})(?![0-9])", # 独立的11位手机号
],
"receiver_address": [
r"收件地址[:]\s*(.+?)(?:寄件|发件|$)",
r"收货地址[:]\s*(.+?)(?:寄件|发件|$)",
r"地址[:]\s*(.+?)(?:寄件|发件|电话|$)",
],
"sender_name": [
r"寄件人[:]\s*(.+?)(?:\s|电话|手机|地址|$)",
r"发件人[:]\s*(.+?)(?:\s|电话|手机|地址|$)",
r"寄[:]\s*(.+?)(?:\s|电话|手机|地址|$)",
],
"sender_phone": [
r"寄件人.*?电话[:]\s*(\d{11})",
r"寄件人.*?手机[:]\s*(\d{11})",
],
"sender_address": [
r"寄件地址[:]\s*(.+?)(?:收件|$)",
r"发件地址[:]\s*(.+?)(?:收件|$)",
],
}
def __init__(
self,
line_merge_threshold: float = 0.6,
horizontal_gap_threshold: float = 2.0
):
"""
初始化解析器
Args:
line_merge_threshold: 行合并阈值相对于文本高度的比例
horizontal_gap_threshold: 水平间距阈值相对于平均字符宽度的比例
"""
self._line_merge_threshold = line_merge_threshold
self._horizontal_gap_threshold = horizontal_gap_threshold
def parse(self, text_blocks: List[TextBlock]) -> ExpressInfo:
"""
解析文本块列表提取快递单信息
Args:
text_blocks: OCR 识别的文本块列表
Returns:
结构化的快递单信息
"""
if not text_blocks:
return ExpressInfo()
# 1. 合并文本块为行
lines = self._merge_blocks_to_lines(text_blocks)
# 2. 生成完整文本(用于正则匹配)
full_text = self._lines_to_text(lines)
# 3. 提取结构化信息
info = self._extract_info(full_text, lines)
# 4. 计算平均置信度
info.confidence = sum(b.confidence for b in text_blocks) / len(text_blocks)
info.raw_text = full_text
return info
def _merge_blocks_to_lines(self, blocks: List[TextBlock]) -> List[TextLine]:
"""
将文本块按位置合并为行
基于 Y 坐标将相近的文本块合并到同一行
然后按 X 坐标排序合并文本
"""
if not blocks:
return []
# 按 Y 坐标排序
sorted_blocks = sorted(blocks, key=lambda b: b.center[1])
lines: List[TextLine] = []
current_line_blocks: List[TextBlock] = [sorted_blocks[0]]
current_y = sorted_blocks[0].center[1]
for block in sorted_blocks[1:]:
block_y = block.center[1]
block_height = block.height
# 判断是否属于同一行Y 坐标差值小于阈值)
threshold = block_height * self._line_merge_threshold
if abs(block_y - current_y) <= threshold:
current_line_blocks.append(block)
else:
# 完成当前行,开始新行
line = self._create_line(current_line_blocks)
lines.append(line)
current_line_blocks = [block]
current_y = block_y
# 处理最后一行
if current_line_blocks:
line = self._create_line(current_line_blocks)
lines.append(line)
return lines
def _create_line(self, blocks: List[TextBlock]) -> TextLine:
"""
从文本块列表创建文本行
X 坐标排序根据间距决定是否添加空格
"""
# 按 X 坐标排序
sorted_blocks = sorted(blocks, key=lambda b: b.center[0])
# 合并文本
text_parts = []
prev_block = None
for block in sorted_blocks:
if prev_block is not None:
# 计算水平间距
prev_right = max(p[0] for p in prev_block.bbox)
curr_left = min(p[0] for p in block.bbox)
gap = curr_left - prev_right
# 计算平均字符宽度
avg_char_width = prev_block.width / max(len(prev_block.text), 1)
# 如果间距较大,添加空格
if gap > avg_char_width * self._horizontal_gap_threshold:
text_parts.append(" ")
text_parts.append(block.text)
prev_block = block
merged_text = "".join(text_parts)
y_center = sum(b.center[1] for b in sorted_blocks) / len(sorted_blocks)
x_min = min(min(p[0] for p in b.bbox) for b in sorted_blocks)
return TextLine(
text=merged_text,
blocks=sorted_blocks,
y_center=y_center,
x_min=x_min
)
def _lines_to_text(self, lines: List[TextLine]) -> str:
"""将文本行列表转换为完整文本"""
return "\n".join(line.text for line in lines)
def _extract_info(self, full_text: str, lines: List[TextLine]) -> ExpressInfo:
"""
从文本中提取快递单信息
Args:
full_text: 完整文本
lines: 文本行列表
Returns:
结构化的快递单信息
"""
info = ExpressInfo()
# 提取快递公司
info.courier_company = self._extract_courier_company(full_text)
# 提取各字段
for field_name, patterns in self.FIELD_PATTERNS.items():
value = self._extract_field(full_text, patterns)
if value:
setattr(info, field_name, value)
# 尝试从上下文推断地址
if not info.receiver_address:
info.receiver_address = self._extract_address_from_context(lines, "")
if not info.sender_address:
info.sender_address = self._extract_address_from_context(lines, "")
return info
def _extract_courier_company(self, text: str) -> Optional[str]:
"""提取快递公司名称"""
text_upper = text.upper()
for keyword, company in self.COURIER_KEYWORDS.items():
if keyword.upper() in text_upper:
return company
return None
def _extract_field(self, text: str, patterns: List[str]) -> Optional[str]:
"""
使用正则表达式列表提取字段值
Args:
text: 待匹配文本
patterns: 正则表达式列表
Returns:
匹配到的字段值 None
"""
for pattern in patterns:
match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
if match:
value = match.group(1).strip()
# 清理常见的干扰字符
value = re.sub(r'[【】\[\]()]', '', value)
if value:
return value
return None
def _extract_address_from_context(
self,
lines: List[TextLine],
context_keyword: str
) -> Optional[str]:
"""
从上下文中提取地址
查找包含省/////路等关键词的行
"""
address_keywords = ["", "", "", "", "", "", "", "", "", "", "", ""]
# 查找包含上下文关键词的行索引
context_line_idx = -1
for i, line in enumerate(lines):
if context_keyword in line.text:
context_line_idx = i
break
# 在上下文行附近查找地址
search_range = range(
max(0, context_line_idx),
min(len(lines), context_line_idx + 3 if context_line_idx >= 0 else len(lines))
)
address_parts = []
for i in search_range:
line_text = lines[i].text
# 检查是否包含地址关键词
if any(kw in line_text for kw in address_keywords):
# 清理行首的标签(如 "地址:"
cleaned = re.sub(r'^[^:]*[:]\s*', '', line_text)
if cleaned and cleaned != line_text:
address_parts.append(cleaned)
elif any(kw in line_text for kw in address_keywords[:4]): # 省/市/区/县
address_parts.append(line_text)
if address_parts:
return "".join(address_parts)
return None
def merge_text_blocks(self, text_blocks: List[TextBlock]) -> str:
"""
仅合并文本块不进行字段提取
用于获取完整的合并文本
Args:
text_blocks: 文本块列表
Returns:
合并后的完整文本
"""
lines = self._merge_blocks_to_lines(text_blocks)
return self._lines_to_text(lines)

@ -0,0 +1,303 @@
# -*- coding: utf-8 -*-
"""
OCR 处理管道模块
提供图片 OCR 识别和结果解析的完整处理流程
"""
import time
import numpy as np
from typing import List, Optional, Dict, Any, Callable, TYPE_CHECKING
from dataclasses import dataclass, field
from ocr.engine import OCREngine, TextBlock
from utils.config import PipelineConfig, OCRConfig
if TYPE_CHECKING:
from ocr.express_parser import ExpressInfo, ExpressParser
@dataclass
class OCRResult:
"""
OCR 处理结果数据结构
Attributes:
image_index: 图片索引批量处理时使用
image_path: 图片路径
timestamp: 处理时间戳
processing_time_ms: 处理耗时毫秒
text_blocks: 识别出的文本块列表
roi_applied: 是否应用了 ROI 裁剪
roi_rect: ROI 矩形 (x, y, w, h)如果应用了 ROI
"""
image_index: int
image_path: Optional[str]
timestamp: float
processing_time_ms: float
text_blocks: List[TextBlock]
roi_applied: bool = False
roi_rect: Optional[tuple] = None
@property
def text_count(self) -> int:
"""识别出的文本数量"""
return len(self.text_blocks)
@property
def all_texts(self) -> List[str]:
"""获取所有识别出的文本"""
return [block.text for block in self.text_blocks]
@property
def full_text(self) -> str:
"""获取所有文本拼接结果"""
return "\n".join(self.all_texts)
@property
def average_confidence(self) -> float:
"""获取平均置信度"""
if not self.text_blocks:
return 0.0
return sum(b.confidence for b in self.text_blocks) / len(self.text_blocks)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式,便于 JSON 序列化"""
return {
"image_index": self.image_index,
"image_path": self.image_path,
"timestamp": self.timestamp,
"processing_time_ms": self.processing_time_ms,
"text_count": self.text_count,
"average_confidence": self.average_confidence,
"roi_applied": self.roi_applied,
"roi_rect": self.roi_rect,
"text_blocks": [block.to_dict() for block in self.text_blocks]
}
def filter_by_confidence(self, min_confidence: float) -> "OCRResult":
"""
按置信度过滤结果
Args:
min_confidence: 最小置信度阈值
Returns:
过滤后的 OCRResult
"""
filtered_blocks = [
block for block in self.text_blocks
if block.confidence >= min_confidence
]
return OCRResult(
image_index=self.image_index,
image_path=self.image_path,
timestamp=self.timestamp,
processing_time_ms=self.processing_time_ms,
text_blocks=filtered_blocks,
roi_applied=self.roi_applied,
roi_rect=self.roi_rect
)
def parse_express(self) -> "ExpressInfo":
"""
解析快递单信息
将分散的文本块合并并提取结构化的快递单信息
Returns:
结构化的快递单信息
"""
from ocr.express_parser import ExpressParser
parser = ExpressParser()
return parser.parse(self.text_blocks)
def merge_text(self) -> str:
"""
合并文本块为完整文本
基于位置信息智能合并同一行的文本会被合并
Returns:
合并后的完整文本
"""
from ocr.express_parser import ExpressParser
parser = ExpressParser()
return parser.merge_text_blocks(self.text_blocks)
class OCRPipeline:
"""
OCR 处理管道
负责 ROI 裁剪OCR 调用结果封装
"""
def __init__(
self,
ocr_config: OCRConfig,
pipeline_config: Optional[PipelineConfig] = None
):
"""
初始化 OCR 管道
Args:
ocr_config: OCR 引擎配置
pipeline_config: 管道配置可选
"""
self._ocr_config = ocr_config
self._pipeline_config = pipeline_config or PipelineConfig()
self._engine = OCREngine(ocr_config)
self._image_counter: int = 0
# 预留扩展点:图片预处理回调
self._image_preprocessors: List[Callable[[np.ndarray], np.ndarray]] = []
# 预留扩展点:结果后处理回调
self._result_postprocessors: List[Callable[[OCRResult], OCRResult]] = []
def initialize(self) -> None:
"""初始化管道(预加载 OCR 模型)"""
self._engine.initialize()
def add_preprocessor(
self,
preprocessor: Callable[[np.ndarray], np.ndarray]
) -> None:
"""
添加图片预处理器
Args:
preprocessor: 预处理函数接收图像返回处理后的图像
"""
self._image_preprocessors.append(preprocessor)
def add_postprocessor(
self,
postprocessor: Callable[[OCRResult], OCRResult]
) -> None:
"""
添加结果后处理器
Args:
postprocessor: 后处理函数接收 OCRResult 返回处理后的结果
"""
self._result_postprocessors.append(postprocessor)
def _apply_roi(
self,
image: np.ndarray
) -> tuple:
"""
应用 ROI 裁剪
Args:
image: 原始图片
Returns:
(裁剪后的图像, ROI 偏移量, ROI 矩形)
"""
roi_config = self._pipeline_config.roi
if not roi_config.enabled:
return image, (0, 0), None
h, w = image.shape[:2]
x, y, roi_w, roi_h = roi_config.get_roi_rect(w, h)
# 边界检查
x = max(0, min(x, w - 1))
y = max(0, min(y, h - 1))
roi_w = min(roi_w, w - x)
roi_h = min(roi_h, h - y)
cropped = image[y:y + roi_h, x:x + roi_w]
return cropped, (x, y), (x, y, roi_w, roi_h)
def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
"""
执行图片预处理
Args:
image: 原始图片
Returns:
预处理后的图片
"""
processed = image
for preprocessor in self._image_preprocessors:
processed = preprocessor(processed)
return processed
def _postprocess_result(self, result: OCRResult) -> OCRResult:
"""
执行结果后处理
Args:
result: 原始结果
Returns:
后处理后的结果
"""
processed = result
for postprocessor in self._result_postprocessors:
processed = postprocessor(processed)
return processed
def process(
self,
image: np.ndarray,
image_path: Optional[str] = None
) -> OCRResult:
"""
处理单张图片
Args:
image: 输入图片 (numpy array, BGR 格式)
image_path: 图片路径可选用于结果记录
Returns:
OCR 结果
"""
self._image_counter += 1
start_time = time.time()
# 应用 ROI 裁剪
cropped_image, roi_offset, roi_rect = self._apply_roi(image)
# 图片预处理
processed_image = self._preprocess_image(cropped_image)
# 执行 OCR
text_blocks = self._engine.recognize(processed_image, roi_offset)
# 计算处理耗时
processing_time_ms = (time.time() - start_time) * 1000
# 构建结果
result = OCRResult(
image_index=self._image_counter,
image_path=image_path,
timestamp=time.time(),
processing_time_ms=processing_time_ms,
text_blocks=text_blocks,
roi_applied=self._pipeline_config.roi.enabled,
roi_rect=roi_rect
)
# 结果后处理
result = self._postprocess_result(result)
return result
def reset_counter(self) -> None:
"""重置图片计数器"""
self._image_counter = 0
@property
def image_counter(self) -> int:
"""获取已处理的图片计数"""
return self._image_counter
@property
def config(self) -> PipelineConfig:
"""获取管道配置"""
return self._pipeline_config

@ -0,0 +1,10 @@
# Vision-OCR Dependencies
# Core dependencies
paddlepaddle>=2.5.0,<3.0.0
paddleocr>=2.7.0,<3.0.0
opencv-python>=4.8.0
numpy>=1.24.0,<2.0
# Optional dependencies (for Chinese font rendering)
Pillow>=10.0.0

@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""
工具模块
提供配置管理等通用功能
"""
from .config import (
Config,
OCRConfig,
InputConfig,
VisualizeConfig,
PipelineConfig,
ROIConfig,
OutputConfig,
InputMode
)
__all__ = [
"Config",
"OCRConfig",
"InputConfig",
"VisualizeConfig",
"PipelineConfig",
"ROIConfig",
"OutputConfig",
"InputMode"
]

@ -0,0 +1,242 @@
# -*- coding: utf-8 -*-
"""
配置管理模块
集中管理所有可配置参数便于维护和扩展
"""
from dataclasses import dataclass, field
from typing import Optional, Tuple, List
from enum import Enum
class InputMode(Enum):
"""输入模式枚举"""
SINGLE = "single" # 单张图片
BATCH = "batch" # 多张图片列表
DIRECTORY = "directory" # 目录批量
@dataclass
class InputConfig:
"""
图片输入配置
Attributes:
mode: 输入模式
image_path: 单张图片路径
image_paths: 多张图片路径列表
directory: 图片目录路径
pattern: 文件匹配模式 "*.jpg"
recursive: 是否递归搜索子目录
"""
mode: InputMode = InputMode.SINGLE
image_path: Optional[str] = None
image_paths: Optional[List[str]] = None
directory: Optional[str] = None
pattern: Optional[str] = None
recursive: bool = False
def __post_init__(self):
"""参数校验"""
if self.mode == InputMode.SINGLE and not self.image_path:
raise ValueError("单张图片模式下必须指定 image_path")
if self.mode == InputMode.BATCH and not self.image_paths:
raise ValueError("批量模式下必须指定 image_paths")
if self.mode == InputMode.DIRECTORY and not self.directory:
raise ValueError("目录模式下必须指定 directory")
@dataclass
class ROIConfig:
"""
感兴趣区域(ROI)配置
使用归一化坐标 (0.0 ~ 1.0)便于适配不同分辨率
Attributes:
enabled: 是否启用 ROI 裁剪
x_ratio: ROI 左上角 x 坐标比例
y_ratio: ROI 左上角 y 坐标比例
width_ratio: ROI 宽度比例
height_ratio: ROI 高度比例
"""
enabled: bool = False
x_ratio: float = 0.1
y_ratio: float = 0.1
width_ratio: float = 0.8
height_ratio: float = 0.8
def get_roi_rect(self, frame_width: int, frame_height: int) -> Tuple[int, int, int, int]:
"""
根据帧尺寸计算实际 ROI 矩形
Args:
frame_width: 帧宽度
frame_height: 帧高度
Returns:
(x, y, width, height) 像素坐标
"""
x = int(frame_width * self.x_ratio)
y = int(frame_height * self.y_ratio)
width = int(frame_width * self.width_ratio)
height = int(frame_height * self.height_ratio)
return x, y, width, height
@dataclass
class OCRConfig:
"""
OCR 引擎配置 (适配 PaddleOCR 2.x API)
Attributes:
lang: 识别语言支持 "ch"(中文), "en"(英文)
use_angle_cls: 是否启用方向分类器
use_gpu: 是否使用 GPU 加速
det_db_thresh: 文本检测阈值
det_db_box_thresh: 检测框阈值
drop_score: 低于此置信度的结果将被过滤
show_log: 是否显示 PaddleOCR 日志
det_model_dir: 检测模型目录路径
rec_model_dir: 识别模型目录路径
cls_model_dir: 分类模型目录路径
"""
lang: str = "ch"
use_angle_cls: bool = True
use_gpu: bool = False
det_db_thresh: float = 0.3
det_db_box_thresh: float = 0.5
drop_score: float = 0.5
show_log: bool = False
det_model_dir: Optional[str] = None
rec_model_dir: Optional[str] = None
cls_model_dir: Optional[str] = None
@dataclass
class PipelineConfig:
"""
OCR 处理管道配置
Attributes:
roi: ROI 配置
"""
roi: ROIConfig = field(default_factory=ROIConfig)
@dataclass
class VisualizeConfig:
"""
可视化配置
Attributes:
show_window: 是否显示可视化窗口
window_name: 窗口名称
box_color: 文本框颜色 (B, G, R)
box_thickness: 文本框线宽
text_color: 文字颜色 (B, G, R)
text_scale: 文字缩放比例
text_thickness: 文字线宽
show_confidence: 是否在文字旁显示置信度
font_path: 中文字体路径None 则使用 OpenCV 默认字体
"""
show_window: bool = False
window_name: str = "OCR Result"
box_color: Tuple[int, int, int] = (0, 255, 0)
box_thickness: int = 2
text_color: Tuple[int, int, int] = (0, 0, 255)
text_scale: float = 0.6
text_thickness: int = 1
show_confidence: bool = True
font_path: Optional[str] = None
@dataclass
class OutputConfig:
"""
输出配置
Attributes:
output_dir: 输出目录路径
save_json: 是否保存 JSON 结果
save_image: 是否保存标注后的图片
json_filename: JSON 文件名模板
image_suffix: 标注图片后缀
"""
output_dir: Optional[str] = None
save_json: bool = True
save_image: bool = False
json_filename: str = "ocr_result.json"
image_suffix: str = "_ocr"
@dataclass
class Config:
"""
全局配置类聚合所有配置模块
Attributes:
input: 输入配置
ocr: OCR 引擎配置
pipeline: 处理管道配置
visualize: 可视化配置
output: 输出配置
"""
input: Optional[InputConfig] = None
ocr: OCRConfig = field(default_factory=OCRConfig)
pipeline: PipelineConfig = field(default_factory=PipelineConfig)
visualize: VisualizeConfig = field(default_factory=VisualizeConfig)
output: OutputConfig = field(default_factory=OutputConfig)
@classmethod
def default(cls) -> "Config":
"""创建默认配置"""
return cls()
@classmethod
def for_single_image(cls, image_path: str) -> "Config":
"""
创建单张图片模式的配置
Args:
image_path: 图片路径
"""
config = cls()
config.input = InputConfig(
mode=InputMode.SINGLE,
image_path=image_path
)
return config
@classmethod
def for_directory(cls, directory: str, pattern: Optional[str] = None, recursive: bool = False) -> "Config":
"""
创建目录批量模式的配置
Args:
directory: 目录路径
pattern: 文件匹配模式
recursive: 是否递归搜索
"""
config = cls()
config.input = InputConfig(
mode=InputMode.DIRECTORY,
directory=directory,
pattern=pattern,
recursive=recursive
)
return config
@classmethod
def for_batch(cls, image_paths: List[str]) -> "Config":
"""
创建批量图片模式的配置
Args:
image_paths: 图片路径列表
"""
config = cls()
config.input = InputConfig(
mode=InputMode.BATCH,
image_paths=image_paths
)
return config

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
"""
可视化模块
提供 OCR 结果的可视化绘制功能
"""
from .draw import OCRVisualizer
__all__ = ["OCRVisualizer"]

@ -0,0 +1,368 @@
# -*- coding: utf-8 -*-
"""
可视化模块
在图像上绘制 OCR 识别结果
"""
import os
import sys
import cv2
import numpy as np
from typing import List, Optional, Tuple
from ocr.engine import TextBlock
from ocr.pipeline import OCRResult
from utils.config import VisualizeConfig
# Windows 系统常用中文字体列表(按优先级排序)
_WINDOWS_CHINESE_FONTS = [
"msyh.ttc", # 微软雅黑
"msyhbd.ttc", # 微软雅黑粗体
"simhei.ttf", # 黑体
"simsun.ttc", # 宋体
"simkai.ttf", # 楷体
]
# Linux 系统常用中文字体路径
_LINUX_CHINESE_FONTS = [
"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
"/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
"/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
"/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
]
def _find_system_chinese_font() -> Optional[str]:
"""
自动查找系统中文字体
Returns:
字体文件路径未找到返回 None
"""
if sys.platform == "win32":
# Windows 字体目录
fonts_dir = os.path.join(os.environ.get("WINDIR", "C:\\Windows"), "Fonts")
for font_name in _WINDOWS_CHINESE_FONTS:
font_path = os.path.join(fonts_dir, font_name)
if os.path.exists(font_path):
return font_path
else:
# Linux/macOS
for font_path in _LINUX_CHINESE_FONTS:
if os.path.exists(font_path):
return font_path
return None
class OCRVisualizer:
"""
OCR 结果可视化器
在图像上绘制文本框和识别结果
"""
def __init__(self, config: VisualizeConfig):
"""
初始化可视化器
Args:
config: 可视化配置
"""
self._config = config
self._font = cv2.FONT_HERSHEY_SIMPLEX
# 尝试加载中文字体
self._pil_font = None
self._use_pil = False
# 确定字体路径:优先使用配置的路径,否则自动检测系统字体
font_path = config.font_path or _find_system_chinese_font()
if font_path:
try:
from PIL import ImageFont
self._pil_font = ImageFont.truetype(font_path, 20)
self._use_pil = True
except Exception:
# 字体加载失败,使用 OpenCV 默认字体
pass
def draw_text_blocks(
self,
frame: np.ndarray,
text_blocks: List[TextBlock],
copy: bool = True
) -> np.ndarray:
"""
在帧上绘制文本块
Args:
frame: 输入帧
text_blocks: 文本块列表
copy: 是否复制帧避免修改原帧
Returns:
绘制后的帧
"""
if copy:
frame = frame.copy()
for block in text_blocks:
self._draw_single_block(frame, block)
return frame
def draw_result(
self,
frame: np.ndarray,
result: Optional[OCRResult],
copy: bool = True
) -> np.ndarray:
"""
在帧上绘制 OCR 结果
Args:
frame: 输入帧
result: OCR 结果
copy: 是否复制帧
Returns:
绘制后的帧
"""
if copy:
frame = frame.copy()
if result is None:
return frame
# 绘制 ROI 区域(如果启用)
if result.roi_applied and result.roi_rect:
self._draw_roi(frame, result.roi_rect)
# 绘制所有文本块
for block in result.text_blocks:
self._draw_single_block(frame, block)
# 绘制状态信息
self._draw_status(frame, result)
return frame
def _draw_single_block(
self,
frame: np.ndarray,
block: TextBlock
) -> None:
"""
绘制单个文本块
Args:
frame:
block: 文本块
"""
# 获取带偏移的边界框坐标
bbox = block.bbox_with_offset
points = np.array(bbox, dtype=np.int32)
# 绘制多边形边框
cv2.polylines(
frame,
[points],
isClosed=True,
color=self._config.box_color,
thickness=self._config.box_thickness
)
# 准备显示文本
display_text = block.text
if self._config.show_confidence:
display_text = f"{block.text} ({block.confidence:.2f})"
# 计算文本位置(在边界框左上角上方)
text_x = int(min(p[0] for p in bbox))
text_y = int(min(p[1] for p in bbox)) - 5
# 确保文本不超出画面
text_y = max(text_y, 20)
# 绘制文本
if self._use_pil and self._pil_font:
self._draw_text_pil(frame, display_text, (text_x, text_y))
else:
self._draw_text_cv2(frame, display_text, (text_x, text_y))
def _draw_text_cv2(
self,
frame: np.ndarray,
text: str,
position: Tuple[int, int]
) -> None:
"""
使用 OpenCV 绘制文本不支持中文会显示方块
Args:
frame:
text: 文本
position: 位置 (x, y)
"""
# 绘制文本背景(提高可读性)
(text_width, text_height), baseline = cv2.getTextSize(
text,
self._font,
self._config.text_scale,
self._config.text_thickness
)
x, y = position
cv2.rectangle(
frame,
(x, y - text_height - 5),
(x + text_width + 5, y + 5),
(255, 255, 255),
-1
)
# 绘制文本
cv2.putText(
frame,
text,
position,
self._font,
self._config.text_scale,
self._config.text_color,
self._config.text_thickness,
cv2.LINE_AA
)
def _draw_text_pil(
self,
frame: np.ndarray,
text: str,
position: Tuple[int, int]
) -> None:
"""
使用 PIL 绘制文本支持中文
Args:
frame:
text: 文本
position: 位置 (x, y)
"""
from PIL import Image, ImageDraw
# OpenCV 图像转 PIL
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_image)
# 获取文本尺寸
bbox = draw.textbbox(position, text, font=self._pil_font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
x, y = position
# 绘制背景
draw.rectangle(
[x - 2, y - text_height - 2, x + text_width + 2, y + 2],
fill=(255, 255, 255)
)
# 绘制文本
text_color_rgb = (
self._config.text_color[2],
self._config.text_color[1],
self._config.text_color[0]
)
draw.text(position, text, font=self._pil_font, fill=text_color_rgb)
# PIL 图像转回 OpenCV
result = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
np.copyto(frame, result)
def _draw_roi(
self,
frame: np.ndarray,
roi_rect: Tuple[int, int, int, int]
) -> None:
"""
绘制 ROI 区域
Args:
frame:
roi_rect: ROI 矩形 (x, y, width, height)
"""
x, y, w, h = roi_rect
cv2.rectangle(
frame,
(x, y),
(x + w, y + h),
(255, 255, 0), # 青色
2,
cv2.LINE_AA
)
def _draw_status(
self,
frame: np.ndarray,
result: OCRResult
) -> None:
"""
绘制状态信息
Args:
frame:
result: OCR 结果
"""
h, w = frame.shape[:2]
# 状态文本
status_lines = [
f"Image: {result.image_index}",
f"Texts: {result.text_count}",
f"Time: {result.processing_time_ms:.1f}ms"
]
y_offset = 25
for line in status_lines:
cv2.putText(
frame,
line,
(10, y_offset),
self._font,
0.5,
(0, 255, 0),
1,
cv2.LINE_AA
)
y_offset += 20
def show(
self,
frame: np.ndarray,
wait_key: int = 1
) -> int:
"""
显示帧并等待按键
Args:
frame:
wait_key: 等待时间毫秒
Returns:
按下的键码
"""
if not self._config.show_window:
return -1
cv2.imshow(self._config.window_name, frame)
return cv2.waitKey(wait_key)
def close(self) -> None:
"""关闭所有窗口"""
cv2.destroyAllWindows()
@property
def config(self) -> VisualizeConfig:
"""获取配置"""
return self._config
Loading…
Cancel
Save