You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

279 lines
7.4 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
图片加载模块
提供图片加载功能,支持单张图片、多张图片和目录批量加载
"""
import cv2
import numpy as np
from pathlib import Path
from typing import List, Optional, Generator, Union, Tuple
from dataclasses import dataclass
# 支持的图片格式
SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
@dataclass
class ImageInfo:
"""
图片信息数据结构
Attributes:
path: 图片文件路径
image: 图片数据 (numpy array, BGR 格式)
width: 图片宽度
height: 图片高度
filename: 文件名(不含路径)
"""
path: str
image: np.ndarray
width: int
height: int
filename: str
@classmethod
def from_image(cls, path: str, image: np.ndarray) -> "ImageInfo":
"""
从图片数据创建 ImageInfo
Args:
path: 图片路径
image: 图片数据
Returns:
ImageInfo 实例
"""
h, w = image.shape[:2]
return cls(
path=path,
image=image,
width=w,
height=h,
filename=Path(path).name
)
class ImageLoader:
"""
图片加载器
支持单张图片、多张图片列表和目录批量加载
"""
def __init__(self, supported_extensions: Optional[set] = None):
"""
初始化图片加载器
Args:
supported_extensions: 支持的图片扩展名集合,默认使用 SUPPORTED_EXTENSIONS
"""
self._extensions = supported_extensions or SUPPORTED_EXTENSIONS
def load(self, path: Union[str, Path]) -> Optional[ImageInfo]:
"""
加载单张图片
Args:
path: 图片文件路径
Returns:
ImageInfo 对象,加载失败返回 None
"""
path = Path(path)
if not path.exists():
print(f"[ERROR] 文件不存在: {path}")
return None
if not path.is_file():
print(f"[ERROR] 路径不是文件: {path}")
return None
if path.suffix.lower() not in self._extensions:
print(f"[ERROR] 不支持的图片格式: {path.suffix}")
return None
# 使用 cv2.imdecode 处理中文路径
image = self._read_image(str(path))
if image is None:
print(f"[ERROR] 无法读取图片: {path}")
return None
return ImageInfo.from_image(str(path), image)
def load_batch(self, paths: List[Union[str, Path]]) -> Generator[ImageInfo, None, None]:
"""
批量加载多张图片
Args:
paths: 图片路径列表
Yields:
ImageInfo 对象
"""
for path in paths:
info = self.load(path)
if info is not None:
yield info
def load_directory(
self,
dir_path: Union[str, Path],
pattern: Optional[str] = None,
recursive: bool = False
) -> Generator[ImageInfo, None, None]:
"""
加载目录中的所有图片
Args:
dir_path: 目录路径
pattern: 文件匹配模式(如 "*.jpg"None 表示加载所有支持的格式
recursive: 是否递归搜索子目录
Yields:
ImageInfo 对象
"""
dir_path = Path(dir_path)
if not dir_path.exists():
print(f"[ERROR] 目录不存在: {dir_path}")
return
if not dir_path.is_dir():
print(f"[ERROR] 路径不是目录: {dir_path}")
return
# 获取文件列表
if pattern:
if recursive:
files = list(dir_path.rglob(pattern))
else:
files = list(dir_path.glob(pattern))
else:
# 加载所有支持的格式
files = []
for ext in self._extensions:
if recursive:
files.extend(dir_path.rglob(f"*{ext}"))
files.extend(dir_path.rglob(f"*{ext.upper()}"))
else:
files.extend(dir_path.glob(f"*{ext}"))
files.extend(dir_path.glob(f"*{ext.upper()}"))
# 按文件名排序
files = sorted(set(files), key=lambda p: p.name)
print(f"[INFO] 在目录 {dir_path} 中找到 {len(files)} 张图片")
for file_path in files:
info = self.load(file_path)
if info is not None:
yield info
def get_image_paths(
self,
dir_path: Union[str, Path],
pattern: Optional[str] = None,
recursive: bool = False
) -> List[str]:
"""
获取目录中所有图片的路径列表(不加载图片)
Args:
dir_path: 目录路径
pattern: 文件匹配模式
recursive: 是否递归搜索
Returns:
图片路径列表
"""
dir_path = Path(dir_path)
if not dir_path.exists() or not dir_path.is_dir():
return []
if pattern:
if recursive:
files = list(dir_path.rglob(pattern))
else:
files = list(dir_path.glob(pattern))
else:
files = []
for ext in self._extensions:
if recursive:
files.extend(dir_path.rglob(f"*{ext}"))
files.extend(dir_path.rglob(f"*{ext.upper()}"))
else:
files.extend(dir_path.glob(f"*{ext}"))
files.extend(dir_path.glob(f"*{ext.upper()}"))
return sorted([str(f) for f in set(files)], key=lambda p: Path(p).name)
def _read_image(self, path: str) -> Optional[np.ndarray]:
"""
读取图片,支持中文路径
Args:
path: 图片路径
Returns:
图片数据,读取失败返回 None
"""
# 使用 numpy 和 imdecode 处理中文路径
try:
with open(path, 'rb') as f:
data = np.frombuffer(f.read(), dtype=np.uint8)
image = cv2.imdecode(data, cv2.IMREAD_COLOR)
return image
except Exception as e:
print(f"[ERROR] 读取图片失败: {path}, 错误: {e}")
return None
@property
def supported_extensions(self) -> set:
"""获取支持的图片扩展名"""
return self._extensions.copy()
def load_image(path: Union[str, Path]) -> Optional[np.ndarray]:
"""
便捷函数:加载单张图片
Args:
path: 图片路径
Returns:
图片数据 (numpy array, BGR 格式),加载失败返回 None
"""
loader = ImageLoader()
info = loader.load(path)
return info.image if info else None
def load_images(
paths: Optional[List[Union[str, Path]]] = None,
directory: Optional[Union[str, Path]] = None,
pattern: Optional[str] = None,
recursive: bool = False
) -> Generator[ImageInfo, None, None]:
"""
便捷函数:批量加载图片
Args:
paths: 图片路径列表
directory: 图片目录
pattern: 文件匹配模式
recursive: 是否递归搜索
Yields:
ImageInfo 对象
"""
loader = ImageLoader()
if paths:
yield from loader.load_batch(paths)
elif directory:
yield from loader.load_directory(directory, pattern, recursive)