You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

173 lines
4.5 KiB
Python

# -*- coding: utf-8 -*-
"""
Vision-OCR REST API 主入口
基于 FastAPI 的 OCR 服务
"""
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
# Locate the project root (two directory levels above this file) and the
# directory that holds locally stored OCR models, creating it if needed.
_PROJECT_ROOT = Path(__file__).parent.parent
_MODELS_DIR = _PROJECT_ROOT.joinpath("models")
_MODELS_DIR.mkdir(exist_ok=True)

# Point PADDLEOCR_HOME at the project-local model directory.
# NOTE(review): presumably this has to happen before the OCR packages are
# imported below, which is why it sits between the import groups; confirm.
os.environ["PADDLEOCR_HOME"] = os.fspath(_MODELS_DIR)

# Put the project root first on the import path so the imports that follow
# (api, ocr, utils) are looked up there.
sys.path.insert(0, os.fspath(_PROJECT_ROOT))
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from api.exceptions import OCRAPIException
from api.routes import health_router, ocr_router
from ocr.pipeline import OCRPipeline
from utils.config import OCRConfig, PipelineConfig
def _get_ocr_config() -> OCRConfig:
    """
    Build the OCR configuration.

    Looks for the detection, recognition and angle-classification model
    folders inside the project's model directory. The local folders are
    handed to ``OCRConfig`` only when all three exist; otherwise all three
    model directories are passed as ``None``.
    """
    local_models = {
        "det_model_dir": _MODELS_DIR / "ch_PP-OCRv4_det_infer",
        "rec_model_dir": _MODELS_DIR / "ch_PP-OCRv4_rec_infer",
        "cls_model_dir": _MODELS_DIR / "ch_ppocr_mobile_v2.0_cls_infer",
    }

    # All-or-nothing: a single missing folder disables every local path.
    use_local = all(folder.exists() for folder in local_models.values())
    model_kwargs = {
        option: (str(folder) if use_local else None)
        for option, folder in local_models.items()
    }

    return OCRConfig(
        lang="ch",
        use_angle_cls=True,
        use_gpu=False,
        drop_score=0.5,
        **model_kwargs,
    )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage application startup and shutdown.

    Startup: builds the OCR configuration, creates and initializes the
    ``OCRPipeline``, stores it on ``app.state.ocr_pipeline`` and sets
    ``app.state.model_loaded`` to ``True``.

    Shutdown: sets ``app.state.model_loaded`` back to ``False``. This runs in
    a ``finally`` clause so the flag is reset even when an exception or a
    cancellation is raised at the yield point.

    Args:
        app: The FastAPI application whose ``state`` receives the pipeline.

    Yields:
        None; control returns to the framework while the app is serving.
    """
    # Startup: load the OCR models.
    print("[INFO] 正在初始化 OCR API 服务...")
    ocr_config = _get_ocr_config()

    # A missing detection model directory means no local models were found.
    if ocr_config.det_model_dir is None:
        print("[WARN] 模型未在项目目录中找到")
        print("[WARN] 对于 Windows 中文用户名用户,请先运行:")
        print("[WARN] python download_models.py")
        print("[INFO] 回退到默认 PaddleOCR 模型路径...")

    pipeline_config = PipelineConfig()
    print("[INFO] 正在加载 OCR 模型...")
    pipeline = OCRPipeline(ocr_config, pipeline_config)
    pipeline.initialize()

    app.state.ocr_pipeline = pipeline
    app.state.model_loaded = True
    print("[INFO] OCR 模型加载完成,服务已就绪")

    try:
        yield
    finally:
        # Shutdown: always mark the model as unloaded, even on abnormal exit.
        print("[INFO] 正在关闭 OCR API 服务...")
        app.state.model_loaded = False
        print("[INFO] 服务已关闭")
# Create the FastAPI application.
app = FastAPI(
    # Startup/shutdown hook defined in this module.
    lifespan=lifespan,
    # Metadata shown in the generated API documentation.
    title="Vision-OCR API",
    version="1.0.0",
    description="基于 PaddleOCR 的图片 OCR 识别服务,支持通用文字识别和快递单解析",
    # Documentation and schema endpoints.
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Configure CORS.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable (e.g. "https://a.example.com,https://b.example.com"). When it is
# unset or blank, every origin is allowed ("*"), which matches the previous
# hard-coded behavior.
_cors_allow_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# NOTE(review): FastAPI's CORS documentation advises against combining a
# wildcard allow_origins with allow_credentials=True. Set CORS_ALLOW_ORIGINS
# to explicit origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global exception handlers
@app.exception_handler(OCRAPIException)
async def ocr_exception_handler(request: Request, exc: OCRAPIException):
    """
    Turn an ``OCRAPIException`` into a JSON error response.

    The HTTP status is ``exc.status_code``. The error code is the exception's
    class name and the error message is ``exc.message``.
    """
    status = exc.status_code
    error_info = {
        "code": type(exc).__name__,
        "message": exc.message,
    }
    payload = {"success": False, "data": None, "error": error_info}
    return JSONResponse(status_code=status, content=payload)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """
    Turn any otherwise unhandled exception into a JSON 500 response.

    The error code is the fixed string ``"InternalServerError"`` and the
    message embeds the text of the exception.
    """
    # NOTE(review): the exception text is returned to the client verbatim,
    # which may expose internal details; consider a generic message.
    error_info = {
        "code": "InternalServerError",
        "message": f"服务器内部错误: {exc!s}",
    }
    payload = {"success": False, "data": None, "error": error_info}
    return JSONResponse(status_code=500, content=payload)
# Register the routers; both are mounted under the same versioned prefix.
_API_V1_PREFIX = "/api/v1"
app.include_router(health_router, prefix=_API_V1_PREFIX)
app.include_router(ocr_router, prefix=_API_V1_PREFIX)
@app.get("/", tags=["根路径"])
async def root():
    """Return basic service information: name, version, docs path and health path."""
    service_info = {}
    service_info["name"] = "Vision-OCR API"
    service_info["version"] = "1.0.0"
    service_info["docs"] = "/docs"
    service_info["health"] = "/api/v1/health"
    return service_info
if __name__ == "__main__":
    # Start a uvicorn server when this file is executed directly.
    import uvicorn

    # Listen on all interfaces, port 8000, with reload switched on.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
    }
    # The app is given as an import string rather than an object.
    uvicorn.run("api.main:app", **server_options)