|
|
# -*- coding: utf-8 -*-
|
|
|
"""
|
|
|
图片加载模块
|
|
|
提供图片加载功能,支持单张图片、多张图片和目录批量加载
|
|
|
"""
|
|
|
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
from pathlib import Path
|
|
|
from typing import List, Optional, Generator, Union, Tuple
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
|
# 支持的图片格式
|
|
|
SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class ImageInfo:
|
|
|
"""
|
|
|
图片信息数据结构
|
|
|
|
|
|
Attributes:
|
|
|
path: 图片文件路径
|
|
|
image: 图片数据 (numpy array, BGR 格式)
|
|
|
width: 图片宽度
|
|
|
height: 图片高度
|
|
|
filename: 文件名(不含路径)
|
|
|
"""
|
|
|
path: str
|
|
|
image: np.ndarray
|
|
|
width: int
|
|
|
height: int
|
|
|
filename: str
|
|
|
|
|
|
@classmethod
|
|
|
def from_image(cls, path: str, image: np.ndarray) -> "ImageInfo":
|
|
|
"""
|
|
|
从图片数据创建 ImageInfo
|
|
|
|
|
|
Args:
|
|
|
path: 图片路径
|
|
|
image: 图片数据
|
|
|
|
|
|
Returns:
|
|
|
ImageInfo 实例
|
|
|
"""
|
|
|
h, w = image.shape[:2]
|
|
|
return cls(
|
|
|
path=path,
|
|
|
image=image,
|
|
|
width=w,
|
|
|
height=h,
|
|
|
filename=Path(path).name
|
|
|
)
|
|
|
|
|
|
|
|
|
class ImageLoader:
|
|
|
"""
|
|
|
图片加载器
|
|
|
支持单张图片、多张图片列表和目录批量加载
|
|
|
"""
|
|
|
|
|
|
def __init__(self, supported_extensions: Optional[set] = None):
|
|
|
"""
|
|
|
初始化图片加载器
|
|
|
|
|
|
Args:
|
|
|
supported_extensions: 支持的图片扩展名集合,默认使用 SUPPORTED_EXTENSIONS
|
|
|
"""
|
|
|
self._extensions = supported_extensions or SUPPORTED_EXTENSIONS
|
|
|
|
|
|
def load(self, path: Union[str, Path]) -> Optional[ImageInfo]:
|
|
|
"""
|
|
|
加载单张图片
|
|
|
|
|
|
Args:
|
|
|
path: 图片文件路径
|
|
|
|
|
|
Returns:
|
|
|
ImageInfo 对象,加载失败返回 None
|
|
|
"""
|
|
|
path = Path(path)
|
|
|
|
|
|
if not path.exists():
|
|
|
print(f"[ERROR] 文件不存在: {path}")
|
|
|
return None
|
|
|
|
|
|
if not path.is_file():
|
|
|
print(f"[ERROR] 路径不是文件: {path}")
|
|
|
return None
|
|
|
|
|
|
if path.suffix.lower() not in self._extensions:
|
|
|
print(f"[ERROR] 不支持的图片格式: {path.suffix}")
|
|
|
return None
|
|
|
|
|
|
# 使用 cv2.imdecode 处理中文路径
|
|
|
image = self._read_image(str(path))
|
|
|
|
|
|
if image is None:
|
|
|
print(f"[ERROR] 无法读取图片: {path}")
|
|
|
return None
|
|
|
|
|
|
return ImageInfo.from_image(str(path), image)
|
|
|
|
|
|
def load_batch(self, paths: List[Union[str, Path]]) -> Generator[ImageInfo, None, None]:
|
|
|
"""
|
|
|
批量加载多张图片
|
|
|
|
|
|
Args:
|
|
|
paths: 图片路径列表
|
|
|
|
|
|
Yields:
|
|
|
ImageInfo 对象
|
|
|
"""
|
|
|
for path in paths:
|
|
|
info = self.load(path)
|
|
|
if info is not None:
|
|
|
yield info
|
|
|
|
|
|
def load_directory(
|
|
|
self,
|
|
|
dir_path: Union[str, Path],
|
|
|
pattern: Optional[str] = None,
|
|
|
recursive: bool = False
|
|
|
) -> Generator[ImageInfo, None, None]:
|
|
|
"""
|
|
|
加载目录中的所有图片
|
|
|
|
|
|
Args:
|
|
|
dir_path: 目录路径
|
|
|
pattern: 文件匹配模式(如 "*.jpg"),None 表示加载所有支持的格式
|
|
|
recursive: 是否递归搜索子目录
|
|
|
|
|
|
Yields:
|
|
|
ImageInfo 对象
|
|
|
"""
|
|
|
dir_path = Path(dir_path)
|
|
|
|
|
|
if not dir_path.exists():
|
|
|
print(f"[ERROR] 目录不存在: {dir_path}")
|
|
|
return
|
|
|
|
|
|
if not dir_path.is_dir():
|
|
|
print(f"[ERROR] 路径不是目录: {dir_path}")
|
|
|
return
|
|
|
|
|
|
# 获取文件列表
|
|
|
if pattern:
|
|
|
if recursive:
|
|
|
files = list(dir_path.rglob(pattern))
|
|
|
else:
|
|
|
files = list(dir_path.glob(pattern))
|
|
|
else:
|
|
|
# 加载所有支持的格式
|
|
|
files = []
|
|
|
for ext in self._extensions:
|
|
|
if recursive:
|
|
|
files.extend(dir_path.rglob(f"*{ext}"))
|
|
|
files.extend(dir_path.rglob(f"*{ext.upper()}"))
|
|
|
else:
|
|
|
files.extend(dir_path.glob(f"*{ext}"))
|
|
|
files.extend(dir_path.glob(f"*{ext.upper()}"))
|
|
|
|
|
|
# 按文件名排序
|
|
|
files = sorted(set(files), key=lambda p: p.name)
|
|
|
|
|
|
print(f"[INFO] 在目录 {dir_path} 中找到 {len(files)} 张图片")
|
|
|
|
|
|
for file_path in files:
|
|
|
info = self.load(file_path)
|
|
|
if info is not None:
|
|
|
yield info
|
|
|
|
|
|
def get_image_paths(
|
|
|
self,
|
|
|
dir_path: Union[str, Path],
|
|
|
pattern: Optional[str] = None,
|
|
|
recursive: bool = False
|
|
|
) -> List[str]:
|
|
|
"""
|
|
|
获取目录中所有图片的路径列表(不加载图片)
|
|
|
|
|
|
Args:
|
|
|
dir_path: 目录路径
|
|
|
pattern: 文件匹配模式
|
|
|
recursive: 是否递归搜索
|
|
|
|
|
|
Returns:
|
|
|
图片路径列表
|
|
|
"""
|
|
|
dir_path = Path(dir_path)
|
|
|
|
|
|
if not dir_path.exists() or not dir_path.is_dir():
|
|
|
return []
|
|
|
|
|
|
if pattern:
|
|
|
if recursive:
|
|
|
files = list(dir_path.rglob(pattern))
|
|
|
else:
|
|
|
files = list(dir_path.glob(pattern))
|
|
|
else:
|
|
|
files = []
|
|
|
for ext in self._extensions:
|
|
|
if recursive:
|
|
|
files.extend(dir_path.rglob(f"*{ext}"))
|
|
|
files.extend(dir_path.rglob(f"*{ext.upper()}"))
|
|
|
else:
|
|
|
files.extend(dir_path.glob(f"*{ext}"))
|
|
|
files.extend(dir_path.glob(f"*{ext.upper()}"))
|
|
|
|
|
|
return sorted([str(f) for f in set(files)], key=lambda p: Path(p).name)
|
|
|
|
|
|
def _read_image(self, path: str) -> Optional[np.ndarray]:
|
|
|
"""
|
|
|
读取图片,支持中文路径
|
|
|
|
|
|
Args:
|
|
|
path: 图片路径
|
|
|
|
|
|
Returns:
|
|
|
图片数据,读取失败返回 None
|
|
|
"""
|
|
|
# 使用 numpy 和 imdecode 处理中文路径
|
|
|
try:
|
|
|
with open(path, 'rb') as f:
|
|
|
data = np.frombuffer(f.read(), dtype=np.uint8)
|
|
|
image = cv2.imdecode(data, cv2.IMREAD_COLOR)
|
|
|
return image
|
|
|
except Exception as e:
|
|
|
print(f"[ERROR] 读取图片失败: {path}, 错误: {e}")
|
|
|
return None
|
|
|
|
|
|
@property
|
|
|
def supported_extensions(self) -> set:
|
|
|
"""获取支持的图片扩展名"""
|
|
|
return self._extensions.copy()
|
|
|
|
|
|
|
|
|
def load_image(path: Union[str, Path]) -> Optional[np.ndarray]:
|
|
|
"""
|
|
|
便捷函数:加载单张图片
|
|
|
|
|
|
Args:
|
|
|
path: 图片路径
|
|
|
|
|
|
Returns:
|
|
|
图片数据 (numpy array, BGR 格式),加载失败返回 None
|
|
|
"""
|
|
|
loader = ImageLoader()
|
|
|
info = loader.load(path)
|
|
|
return info.image if info else None
|
|
|
|
|
|
|
|
|
def load_images(
|
|
|
paths: Optional[List[Union[str, Path]]] = None,
|
|
|
directory: Optional[Union[str, Path]] = None,
|
|
|
pattern: Optional[str] = None,
|
|
|
recursive: bool = False
|
|
|
) -> Generator[ImageInfo, None, None]:
|
|
|
"""
|
|
|
便捷函数:批量加载图片
|
|
|
|
|
|
Args:
|
|
|
paths: 图片路径列表
|
|
|
directory: 图片目录
|
|
|
pattern: 文件匹配模式
|
|
|
recursive: 是否递归搜索
|
|
|
|
|
|
Yields:
|
|
|
ImageInfo 对象
|
|
|
"""
|
|
|
loader = ImageLoader()
|
|
|
|
|
|
if paths:
|
|
|
yield from loader.load_batch(paths)
|
|
|
elif directory:
|
|
|
yield from loader.load_directory(directory, pattern, recursive)
|