You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
195 lines
5.0 KiB
Python
195 lines
5.0 KiB
Python
# -*- coding: utf-8 -*-
"""
Vision-OCR REST API main entry point.

FastAPI-based OCR service.
"""
|
|
|
|
import logging
|
|
import os
|
|
import sys
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
# Locate the project root (the parent of the directory containing this file)
# and keep OCR model files in a "models" directory beneath it.
_PROJECT_ROOT = Path(__file__).parent.parent
_MODELS_DIR = _PROJECT_ROOT.joinpath("models")
_MODELS_DIR.mkdir(exist_ok=True)

# Export the model directory through PADDLEOCR_HOME.
# NOTE(review): presumably read by PaddleOCR when it is imported/initialized
# later on - confirm against the PaddleOCR version in use.
os.environ.update(PADDLEOCR_HOME=str(_MODELS_DIR))

# Put the project root first on the import path so the project packages
# imported below (api, ocr, utils) can be resolved.
sys.path.insert(0, str(_PROJECT_ROOT))
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from api.exceptions import OCRAPIException
|
|
from api.routes import health_router, ocr_router
|
|
from api.version import API_VERSION
|
|
from ocr.pipeline import OCRPipeline
|
|
from utils.config import OCRConfig, PipelineConfig
|
|
|
|
# Configure logging: INFO level, timestamped records, and a dedicated
# "vision-ocr" logger used throughout this module.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
    level=logging.INFO,
)
logger = logging.getLogger("vision-ocr")
|
|
|
|
|
|
def _get_ocr_config() -> OCRConfig:
    """
    Build the OCR configuration used by the service.

    Three model directories (detection, recognition and angle
    classification) are looked up under ``_MODELS_DIR``. Their paths are
    passed to ``OCRConfig`` only when all three exist; otherwise all three
    model-directory arguments are ``None``.

    Returns:
        OCRConfig: configuration with ``lang="ch"``, angle classification
        enabled, GPU disabled and ``drop_score=0.5``.
    """
    model_dirs = {
        "det_model_dir": _MODELS_DIR / "ch_PP-OCRv4_det_infer",
        "rec_model_dir": _MODELS_DIR / "ch_PP-OCRv4_rec_infer",
        "cls_model_dir": _MODELS_DIR / "ch_ppocr_mobile_v2.0_cls_infer",
    }

    # Local models are used on an all-or-nothing basis: a single missing
    # directory means none of the local paths are passed on.
    if all(path.exists() for path in model_dirs.values()):
        model_kwargs = {key: str(path) for key, path in model_dirs.items()}
    else:
        model_kwargs = dict.fromkeys(model_dirs)

    return OCRConfig(
        lang="ch",
        use_angle_cls=True,
        use_gpu=False,
        drop_score=0.5,
        **model_kwargs,
    )
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage the application lifecycle.

    On startup the OCR pipeline is created and initialized, stored on
    ``app.state.ocr_pipeline``, and ``app.state.model_loaded`` is set to
    ``True``. On shutdown ``app.state.model_loaded`` is reset to ``False``.

    The shutdown steps run in a ``finally`` clause so they are executed even
    when an exception or cancellation is thrown into the generator at the
    ``yield`` point.

    Args:
        app: The FastAPI application whose ``state`` receives the pipeline.

    Yields:
        None: control returns to the framework while the app is running.
    """
    # Startup: load the OCR models
    logger.info("正在初始化 OCR API 服务...")

    ocr_config = _get_ocr_config()
    if ocr_config.det_model_dir is None:
        # No local model directories were found; the configuration carries
        # None for the model paths.
        logger.warning("模型未在项目目录中找到")
        logger.warning("对于 Windows 中文用户名用户,请先运行: python download_models.py")
        logger.info("回退到默认 PaddleOCR 模型路径...")

    pipeline_config = PipelineConfig()

    logger.info("正在加载 OCR 模型...")
    pipeline = OCRPipeline(ocr_config, pipeline_config)
    pipeline.initialize()

    app.state.ocr_pipeline = pipeline
    app.state.model_loaded = True
    logger.info("OCR 模型加载完成,服务已就绪")

    try:
        yield
    finally:
        # Shutdown: mark the model as unavailable
        logger.info("正在关闭 OCR API 服务...")
        app.state.model_loaded = False
        logger.info("服务已关闭")
|
|
|
|
|
|
# Create the FastAPI application with its metadata, documentation endpoints
# and the lifespan handler that loads the OCR pipeline.
_APP_DESCRIPTION = "基于 PaddleOCR 的图片 OCR 识别服务,支持通用文字识别和快递单解析"

app = FastAPI(
    title="Vision-OCR API",
    version=API_VERSION,
    description=_APP_DESCRIPTION,
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_url="/openapi.json",
    lifespan=lifespan,
)
|
|
|
|
# Configure CORS.
#
# Allowed origins are read from the comma-separated OCR_CORS_ORIGINS
# environment variable and default to "*" (any origin), e.g.
#   OCR_CORS_ORIGINS="https://a.example.com,https://b.example.com"
#
# Credentialed requests (cookies, Authorization headers) are enabled only
# when an explicit origin list is configured. Combining a wildcard origin
# with allow_credentials=True would let any website make credentialed
# cross-origin requests, and FastAPI's CORS documentation states that
# wildcards cannot be used together with allow_credentials=True.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("OCR_CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# Global exception handlers
@app.exception_handler(OCRAPIException)
async def ocr_exception_handler(request: Request, exc: OCRAPIException):
    """
    Convert an ``OCRAPIException`` into a JSON error response.

    The exception class name and message are logged at WARNING level together
    with the request path, then returned in the response body under
    ``error``. The HTTP status code is taken from ``exc.status_code``.
    """
    error_code = type(exc).__name__
    error_message = exc.message

    logger.warning(
        "OCR API 异常: %s - %s (路径: %s)",
        error_code,
        error_message,
        request.url.path,
    )

    payload = {
        "success": False,
        "data": None,
        "error": {"code": error_code, "message": error_message},
    }
    return JSONResponse(status_code=exc.status_code, content=payload)
|
|
|
|
|
|
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """
    Convert any otherwise unhandled exception into a 500 JSON response.

    The exception type, its text and the request path are logged at ERROR
    level with traceback information; the response body carries the code
    ``InternalServerError`` and a message that includes the exception text.
    """
    detail = str(exc)

    logger.error(
        "未捕获的异常: %s - %s (路径: %s)",
        type(exc).__name__,
        detail,
        request.url.path,
        exc_info=True,
    )

    # NOTE(review): the exception text is echoed back to the client; confirm
    # that exposing internal error details is acceptable for this service.
    error = {
        "code": "InternalServerError",
        "message": f"服务器内部错误: {detail}",
    }
    body = {"success": False, "data": None, "error": error}
    return JSONResponse(status_code=500, content=body)
|
|
|
|
|
|
# Register the health and OCR routers under the versioned API prefix.
_API_V1_PREFIX = "/api/v1"
for _router in (health_router, ocr_router):
    app.include_router(_router, prefix=_API_V1_PREFIX)
|
|
|
|
|
|
@app.get("/", tags=["根路径"])
async def root():
    """Return basic API information: name, version, docs path and health path."""
    info = dict(
        name="Vision-OCR API",
        version=API_VERSION,
        docs="/docs",
        health="/api/v1/health",
    )
    return info
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 8000, with auto-reload enabled.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
    }
    uvicorn.run("api.main:app", **server_options)
|