You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
369 lines
9.0 KiB
Python
369 lines
9.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
可视化模块
|
|
在图像上绘制 OCR 识别结果
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import cv2
|
|
import numpy as np
|
|
from typing import List, Optional, Tuple
|
|
|
|
from ocr.engine import TextBlock
|
|
from ocr.pipeline import OCRResult
|
|
from utils.config import VisualizeConfig
|
|
|
|
|
|
# Windows 系统常用中文字体列表(按优先级排序)
|
|
_WINDOWS_CHINESE_FONTS = [
|
|
"msyh.ttc", # 微软雅黑
|
|
"msyhbd.ttc", # 微软雅黑粗体
|
|
"simhei.ttf", # 黑体
|
|
"simsun.ttc", # 宋体
|
|
"simkai.ttf", # 楷体
|
|
]
|
|
|
|
# Linux 系统常用中文字体路径
|
|
_LINUX_CHINESE_FONTS = [
|
|
"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
|
|
"/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
|
|
"/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
|
|
"/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
|
|
]
|
|
|
|
|
|
def _find_system_chinese_font() -> Optional[str]:
    """
    Locate a Chinese-capable font file on the current system.

    On Windows the names in ``_WINDOWS_CHINESE_FONTS`` are looked up inside
    ``%WINDIR%\\Fonts`` (``C:\\Windows`` when ``WINDIR`` is unset). On every
    other platform the absolute paths in ``_LINUX_CHINESE_FONTS`` are probed.
    Candidates are checked in list order.

    Returns:
        Path of the first candidate that exists, or None if none is found.
    """
    if sys.platform == "win32":
        windows_fonts = os.path.join(
            os.environ.get("WINDIR", "C:\\Windows"), "Fonts"
        )
        candidates = (
            os.path.join(windows_fonts, name) for name in _WINDOWS_CHINESE_FONTS
        )
    else:
        # Every non-Windows platform is probed with the Linux path list.
        candidates = iter(_LINUX_CHINESE_FONTS)

    return next((path for path in candidates if os.path.exists(path)), None)
|
|
|
|
|
|
class OCRVisualizer:
    """
    Visualizer for OCR results.

    Draws text-block outlines, the recognized text, the ROI rectangle and a
    short status overlay onto frames. Frames are handled as OpenCV BGR
    images (the PIL text path converts BGR <-> RGB around each draw).
    """

    # Point size used when loading the TrueType font for PIL rendering.
    _PIL_FONT_SIZE = 20

    def __init__(self, config: VisualizeConfig):
        """
        Initialize the visualizer.

        Tries to load a TrueType font through PIL so that non-ASCII text can
        be rendered; on any failure the OpenCV Hershey font is used instead.

        Args:
            config: Visualization configuration.
        """
        self._config = config
        self._font = cv2.FONT_HERSHEY_SIMPLEX

        # PIL font state; stays disabled unless the font loads successfully.
        self._pil_font = None
        self._use_pil = False

        # Prefer the configured font path, otherwise auto-detect a system font.
        font_path = config.font_path or _find_system_chinese_font()
        if not font_path:
            return

        try:
            from PIL import ImageFont

            self._pil_font = ImageFont.truetype(font_path, self._PIL_FONT_SIZE)
            self._use_pil = True
        except Exception as exc:
            # Best-effort: keep running with the OpenCV font, but report why
            # the TrueType font could not be used instead of hiding it.
            import logging

            self._pil_font = None
            self._use_pil = False
            logging.getLogger(__name__).warning(
                "Failed to load font %r, falling back to OpenCV font: %s",
                font_path,
                exc,
            )

    def draw_text_blocks(
        self,
        frame: np.ndarray,
        text_blocks: List[TextBlock],
        copy: bool = True
    ) -> np.ndarray:
        """
        Draw a list of text blocks onto a frame.

        Args:
            frame: Input frame.
            text_blocks: Text blocks to draw.
            copy: Draw on a copy of the frame so the input is not modified.

        Returns:
            The frame with the blocks drawn on it.
        """
        if copy:
            frame = frame.copy()

        for block in text_blocks:
            self._draw_single_block(frame, block)

        return frame

    def draw_result(
        self,
        frame: np.ndarray,
        result: Optional[OCRResult],
        copy: bool = True
    ) -> np.ndarray:
        """
        Draw a complete OCR result onto a frame.

        Args:
            frame: Input frame.
            result: OCR result; when None the frame is returned without
                any drawing.
            copy: Draw on a copy of the frame so the input is not modified.

        Returns:
            The frame with the result drawn on it.
        """
        if copy:
            frame = frame.copy()

        if result is None:
            return frame

        # ROI rectangle, only when an ROI was applied and is known.
        if result.roi_applied and result.roi_rect:
            self._draw_roi(frame, result.roi_rect)

        # All recognized text blocks.
        for block in result.text_blocks:
            self._draw_single_block(frame, block)

        # Status overlay.
        self._draw_status(frame, result)

        return frame

    def _draw_single_block(
        self,
        frame: np.ndarray,
        block: TextBlock
    ) -> None:
        """
        Draw one text block: its polygon outline and its label.

        Args:
            frame: Frame to draw on (modified in place).
            block: Text block to draw.
        """
        # Bounding polygon, taken from the block's offset-adjusted bbox.
        bbox = block.bbox_with_offset
        points = np.array(bbox, dtype=np.int32)

        cv2.polylines(
            frame,
            [points],
            isClosed=True,
            color=self._config.box_color,
            thickness=self._config.box_thickness
        )

        # Label text, optionally followed by the confidence.
        display_text = block.text
        if self._config.show_confidence:
            display_text = f"{block.text} ({block.confidence:.2f})"

        # Anchor the label just above the top-left corner of the bbox.
        text_x = int(min(p[0] for p in bbox))
        text_y = int(min(p[1] for p in bbox)) - 5

        # Keep the label from running off the top of the frame.
        text_y = max(text_y, 20)

        if self._use_pil and self._pil_font:
            self._draw_text_pil(frame, display_text, (text_x, text_y))
        else:
            self._draw_text_cv2(frame, display_text, (text_x, text_y))

    def _draw_text_cv2(
        self,
        frame: np.ndarray,
        text: str,
        position: Tuple[int, int]
    ) -> None:
        """
        Draw text with OpenCV's built-in font on a white background.

        Used as the fallback when no TrueType font is available; this path
        does not render Chinese characters correctly.

        Args:
            frame: Frame to draw on (modified in place).
            text: Text to draw.
            position: (x, y) of the text's bottom-left corner.
        """
        # Measure the text so a background can be painted behind it.
        (text_width, text_height), baseline = cv2.getTextSize(
            text,
            self._font,
            self._config.text_scale,
            self._config.text_thickness
        )

        x, y = position
        # Filled white rectangle behind the text to improve readability.
        cv2.rectangle(
            frame,
            (x, y - text_height - 5),
            (x + text_width + 5, y + 5),
            (255, 255, 255),
            -1
        )

        cv2.putText(
            frame,
            text,
            position,
            self._font,
            self._config.text_scale,
            self._config.text_color,
            self._config.text_thickness,
            cv2.LINE_AA
        )

    def _draw_text_pil(
        self,
        frame: np.ndarray,
        text: str,
        position: Tuple[int, int]
    ) -> None:
        """
        Draw text with PIL using the loaded TrueType font (supports Chinese).

        The whole frame is converted to a PIL image, drawn on, converted
        back and copied into ``frame``.

        Args:
            frame: Frame to draw on (modified in place).
            text: Text to draw.
            position: (x, y) of the text's bottom-left corner, the same
                convention as the OpenCV path.
        """
        from PIL import Image, ImageDraw

        # OpenCV (BGR) image -> PIL (RGB) image.
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)

        # Measure the rendered text relative to a (0, 0) drawing origin.
        left, top, right, bottom = draw.textbbox(
            (0, 0), text, font=self._pil_font
        )
        text_width = right - left
        text_height = bottom - top

        x, y = position

        # White background spanning from above the anchor down to it.
        draw.rectangle(
            [x - 2, y - text_height - 2, x + text_width + 2, y + 2],
            fill=(255, 255, 255)
        )

        # PIL anchors text at its top-left, so shift the drawing origin up
        # to make the rendered text's bottom-left land on (x, y), i.e.
        # inside the background rectangle painted above.
        text_origin = (x - left, y - bottom)

        # Configured color is BGR; PIL expects RGB.
        text_color_rgb = (
            self._config.text_color[2],
            self._config.text_color[1],
            self._config.text_color[0]
        )
        draw.text(text_origin, text, font=self._pil_font, fill=text_color_rgb)

        # PIL (RGB) image -> OpenCV (BGR), written back into the caller's frame.
        result = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        np.copyto(frame, result)

    def _draw_roi(
        self,
        frame: np.ndarray,
        roi_rect: Tuple[int, int, int, int]
    ) -> None:
        """
        Draw the ROI rectangle.

        Args:
            frame: Frame to draw on (modified in place).
            roi_rect: ROI rectangle as (x, y, width, height).
        """
        x, y, w, h = roi_rect
        cv2.rectangle(
            frame,
            (x, y),
            (x + w, y + h),
            (255, 255, 0),  # cyan in BGR
            2,
            cv2.LINE_AA
        )

    def _draw_status(
        self,
        frame: np.ndarray,
        result: OCRResult
    ) -> None:
        """
        Draw the status overlay (image index, text count, processing time).

        Args:
            frame: Frame to draw on (modified in place).
            result: OCR result providing the status values.
        """
        status_lines = [
            f"Image: {result.image_index}",
            f"Texts: {result.text_count}",
            f"Time: {result.processing_time_ms:.1f}ms"
        ]

        # One line every 20 px, starting 25 px from the top, 10 px from the left.
        y_offset = 25
        for line in status_lines:
            cv2.putText(
                frame,
                line,
                (10, y_offset),
                self._font,
                0.5,
                (0, 255, 0),
                1,
                cv2.LINE_AA
            )
            y_offset += 20

    def show(
        self,
        frame: np.ndarray,
        wait_key: int = 1
    ) -> int:
        """
        Show the frame in a window and wait for a key press.

        Args:
            frame: Frame to display.
            wait_key: Wait time in milliseconds passed to ``cv2.waitKey``.

        Returns:
            The value returned by ``cv2.waitKey``, or -1 when the window is
            disabled in the configuration.
        """
        if not self._config.show_window:
            return -1

        cv2.imshow(self._config.window_name, frame)
        return cv2.waitKey(wait_key)

    def close(self) -> None:
        """Close all OpenCV windows."""
        cv2.destroyAllWindows()

    @property
    def config(self) -> VisualizeConfig:
        """Return the visualization configuration."""
        return self._config
|