You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

83 lines
2.2 KiB
Python

# -*- coding: utf-8 -*-
"""
模型下载脚本
将 PaddleOCR 模型下载到项目目录,避免中文路径问题
"""
import os
import tarfile
import urllib.request
from pathlib import Path
# Download locations of the PaddleOCR inference models, keyed by model kind.
# Each entry holds the archive URL ("url") and the name of the directory the
# archive is expected to unpack into ("dir").
MODELS = {
    kind: {"url": url, "dir": folder}
    for kind, url, folder in (
        (
            "det",
            "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar",
            "ch_PP-OCRv4_det_infer",
        ),
        (
            "rec",
            "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar",
            "ch_PP-OCRv4_rec_infer",
        ),
        (
            "cls",
            "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar",
            "ch_ppocr_mobile_v2.0_cls_infer",
        ),
    )
}
def _is_within(root: Path, candidate: Path) -> bool:
    """Return True when *candidate* is *root* itself or located beneath it."""
    return candidate == root or root in candidate.parents


def _extract_archive(tar_path: Path, dest: Path) -> None:
    """Extract *tar_path* into *dest*, refusing members that would escape it."""
    with tarfile.open(tar_path, "r") as tar:
        # Interpreters that ship extraction filters can do the safety checks
        # themselves; "data" rejects absolute paths, "..", and escaping links.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest, filter="data")
            return

        # Fallback for interpreters without filter support: validate every
        # member (and link target) before anything is written to disk.
        root = dest.resolve()
        for member in tar.getmembers():
            target = (root / member.name).resolve()
            if not _is_within(root, target):
                raise tarfile.TarError(
                    f"unsafe path in archive: {member.name!r}"
                )
            if member.issym():
                # Symlink targets are relative to the link's own directory.
                link_target = (target.parent / member.linkname).resolve()
            elif member.islnk():
                # Hard-link targets are relative to the archive root.
                link_target = (root / member.linkname).resolve()
            else:
                continue
            if not _is_within(root, link_target):
                raise tarfile.TarError(
                    f"unsafe link in archive: {member.name!r}"
                )
        tar.extractall(dest)


def download_and_extract(url: str, save_dir: Path, model_name: str) -> Path:
    """
    Download a model archive and unpack it.

    The archive is stored as ``<save_dir>/<model_name>.tar`` and extracted
    into ``save_dir``. Both steps are skipped when their result is already
    present (the ``.tar`` file, respectively the ``<model_name>`` directory).

    Args:
        url: Location of the ``.tar`` archive.
        save_dir: Directory that receives the archive and its contents;
            created if missing.
        model_name: Archive base name and name of the directory the archive
            is expected to unpack into.

    Returns:
        ``save_dir / model_name``, the expected path of the extracted model.

    Raises:
        tarfile.TarError: If the archive contains a member that would be
            written outside ``save_dir``.
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    tar_path = save_dir / f"{model_name}.tar"

    if not tar_path.exists():
        print(f"[INFO] Downloading {model_name}...")
        # Download to a side file and rename only on success, so that an
        # interrupted transfer is never mistaken for a complete archive on
        # the next run.
        part_path = save_dir / f"{model_name}.tar.part"
        try:
            urllib.request.urlretrieve(url, part_path)
            part_path.replace(tar_path)
        finally:
            # Still present only when the download or the rename failed.
            if part_path.exists():
                part_path.unlink()
        print(f"[INFO] Downloaded to {tar_path}")
    else:
        print(f"[INFO] {model_name} already exists, skipping download")

    extract_dir = save_dir / model_name
    if not extract_dir.exists():
        print(f"[INFO] Extracting {model_name}...")
        _extract_archive(tar_path, save_dir)
        print(f"[INFO] Extracted to {extract_dir}")
    return extract_dir
def main():
    """Fetch every model listed in ``MODELS`` into a ``models`` directory
    located next to this script, reporting progress on stdout."""
    models_dir = Path(__file__).parent / "models"
    print(f"[INFO] Models will be saved to: {models_dir}")

    for model_type in MODELS:
        spec = MODELS[model_type]
        # Arguments are (url, save_dir, model_name).
        model_dir = download_and_extract(spec["url"], models_dir, spec["dir"])
        print(f"[INFO] {model_type} model ready at: {model_dir}")

    print("\n[INFO] All models downloaded successfully!")
    print("[INFO] You can now run: python main.py")


# Run the download only when executed as a script, not when imported.
if __name__ == "__main__":
    main()