# -*- coding: utf-8 -*- """ 图片加载模块 提供图片加载功能,支持单张图片、多张图片和目录批量加载 """ import cv2 import numpy as np from pathlib import Path from typing import List, Optional, Generator, Union, Tuple from dataclasses import dataclass # 支持的图片格式 SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'} @dataclass class ImageInfo: """ 图片信息数据结构 Attributes: path: 图片文件路径 image: 图片数据 (numpy array, BGR 格式) width: 图片宽度 height: 图片高度 filename: 文件名(不含路径) """ path: str image: np.ndarray width: int height: int filename: str @classmethod def from_image(cls, path: str, image: np.ndarray) -> "ImageInfo": """ 从图片数据创建 ImageInfo Args: path: 图片路径 image: 图片数据 Returns: ImageInfo 实例 """ h, w = image.shape[:2] return cls( path=path, image=image, width=w, height=h, filename=Path(path).name ) class ImageLoader: """ 图片加载器 支持单张图片、多张图片列表和目录批量加载 """ def __init__(self, supported_extensions: Optional[set] = None): """ 初始化图片加载器 Args: supported_extensions: 支持的图片扩展名集合,默认使用 SUPPORTED_EXTENSIONS """ self._extensions = supported_extensions or SUPPORTED_EXTENSIONS def load(self, path: Union[str, Path]) -> Optional[ImageInfo]: """ 加载单张图片 Args: path: 图片文件路径 Returns: ImageInfo 对象,加载失败返回 None """ path = Path(path) if not path.exists(): print(f"[ERROR] 文件不存在: {path}") return None if not path.is_file(): print(f"[ERROR] 路径不是文件: {path}") return None if path.suffix.lower() not in self._extensions: print(f"[ERROR] 不支持的图片格式: {path.suffix}") return None # 使用 cv2.imdecode 处理中文路径 image = self._read_image(str(path)) if image is None: print(f"[ERROR] 无法读取图片: {path}") return None return ImageInfo.from_image(str(path), image) def load_batch(self, paths: List[Union[str, Path]]) -> Generator[ImageInfo, None, None]: """ 批量加载多张图片 Args: paths: 图片路径列表 Yields: ImageInfo 对象 """ for path in paths: info = self.load(path) if info is not None: yield info def load_directory( self, dir_path: Union[str, Path], pattern: Optional[str] = None, recursive: bool = False ) -> Generator[ImageInfo, None, None]: """ 加载目录中的所有图片 Args: dir_path: 目录路径 pattern: 文件匹配模式(如 "*.jpg"),None 表示加载所有支持的格式 recursive: 是否递归搜索子目录 Yields: ImageInfo 对象 """ dir_path = Path(dir_path) if not dir_path.exists(): print(f"[ERROR] 目录不存在: {dir_path}") return if not dir_path.is_dir(): print(f"[ERROR] 路径不是目录: {dir_path}") return # 获取文件列表 if pattern: if recursive: files = list(dir_path.rglob(pattern)) else: files = list(dir_path.glob(pattern)) else: # 加载所有支持的格式 files = [] for ext in self._extensions: if recursive: files.extend(dir_path.rglob(f"*{ext}")) files.extend(dir_path.rglob(f"*{ext.upper()}")) else: files.extend(dir_path.glob(f"*{ext}")) files.extend(dir_path.glob(f"*{ext.upper()}")) # 按文件名排序 files = sorted(set(files), key=lambda p: p.name) print(f"[INFO] 在目录 {dir_path} 中找到 {len(files)} 张图片") for file_path in files: info = self.load(file_path) if info is not None: yield info def get_image_paths( self, dir_path: Union[str, Path], pattern: Optional[str] = None, recursive: bool = False ) -> List[str]: """ 获取目录中所有图片的路径列表(不加载图片) Args: dir_path: 目录路径 pattern: 文件匹配模式 recursive: 是否递归搜索 Returns: 图片路径列表 """ dir_path = Path(dir_path) if not dir_path.exists() or not dir_path.is_dir(): return [] if pattern: if recursive: files = list(dir_path.rglob(pattern)) else: files = list(dir_path.glob(pattern)) else: files = [] for ext in self._extensions: if recursive: files.extend(dir_path.rglob(f"*{ext}")) files.extend(dir_path.rglob(f"*{ext.upper()}")) else: files.extend(dir_path.glob(f"*{ext}")) files.extend(dir_path.glob(f"*{ext.upper()}")) return sorted([str(f) for f in set(files)], key=lambda p: Path(p).name) def _read_image(self, path: str) -> Optional[np.ndarray]: """ 读取图片,支持中文路径 Args: path: 图片路径 Returns: 图片数据,读取失败返回 None """ # 使用 numpy 和 imdecode 处理中文路径 try: with open(path, 'rb') as f: data = np.frombuffer(f.read(), dtype=np.uint8) image = cv2.imdecode(data, cv2.IMREAD_COLOR) return image except Exception as e: print(f"[ERROR] 读取图片失败: {path}, 错误: {e}") return None @property def supported_extensions(self) -> set: """获取支持的图片扩展名""" return self._extensions.copy() def load_image(path: Union[str, Path]) -> Optional[np.ndarray]: """ 便捷函数:加载单张图片 Args: path: 图片路径 Returns: 图片数据 (numpy array, BGR 格式),加载失败返回 None """ loader = ImageLoader() info = loader.load(path) return info.image if info else None def load_images( paths: Optional[List[Union[str, Path]]] = None, directory: Optional[Union[str, Path]] = None, pattern: Optional[str] = None, recursive: bool = False ) -> Generator[ImageInfo, None, None]: """ 便捷函数:批量加载图片 Args: paths: 图片路径列表 directory: 图片目录 pattern: 文件匹配模式 recursive: 是否递归搜索 Yields: ImageInfo 对象 """ loader = ImageLoader() if paths: yield from loader.load_batch(paths) elif directory: yield from loader.load_directory(directory, pattern, recursive)