You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
3.1 KiB
Python
135 lines
3.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
pytest 配置和共享 fixtures
|
|
"""
|
|
|
|
import base64
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Generator
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
# Make the repository root importable so the test modules can import the
# project packages directly. This file sits one directory below the root,
# hence the two ``.parent`` hops.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path[:0] = [str(PROJECT_ROOT)]

# Point PaddleOCR at the repository-local ``models`` directory. This runs at
# module import time, before any fixture body executes.
os.environ.update(PADDLEOCR_HOME=str(PROJECT_ROOT.joinpath("models")))
|
|
|
|
|
|
@pytest.fixture(scope="session")
def mock_ocr_pipeline():
    """
    Provide a stand-in OCR pipeline shared by the whole test session.

    Avoids loading the real OCR models in tests: ``process`` always returns
    the same canned ``OCRResult`` holding two text blocks, and ``initialize``
    returns ``None``.

    NOTE(review): the mock is session-scoped, so call records and any
    ``return_value``/``side_effect`` a test assigns carry over into later
    tests. Reset it explicitly where a test depends on a clean mock.
    """
    # Imported inside the fixture, i.e. after the module-level sys.path /
    # environment setup has already run.
    from ocr.engine import TextBlock
    from ocr.pipeline import OCRResult

    def make_block(text, confidence, left, top, right, bottom):
        # Axis-aligned box as four corner points: TL, TR, BR, BL.
        return TextBlock(
            text=text,
            confidence=confidence,
            bbox=[[left, top], [right, top], [right, bottom], [left, bottom]],
            bbox_offset=(0, 0),
        )

    canned_result = OCRResult(
        image_index=1,
        image_path=None,
        timestamp=1704672000.0,
        processing_time_ms=45.6,
        text_blocks=[
            make_block("测试文本1", 0.95, 10, 10, 100, 30),
            make_block("测试文本2", 0.88, 10, 40, 150, 60),
        ],
        roi_applied=False,
        roi_rect=None,
    )

    pipeline = MagicMock()
    pipeline.process.return_value = canned_result
    pipeline.initialize.return_value = None
    # Explicit placeholder for the pipeline's private config attribute.
    pipeline._pipeline_config = MagicMock()
    return pipeline
|
|
|
|
|
|
@pytest.fixture(scope="session")
def test_client(mock_ocr_pipeline) -> Generator[TestClient, None, None]:
    """
    Yield a FastAPI ``TestClient`` wired to the mocked OCR pipeline.

    Before the client starts, the app's ``state`` is primed with the mock
    pipeline and flagged as loaded, so the real OCR models are not loaded.

    NOTE(review): this assumes the app's startup/lifespan does not replace
    ``app.state.ocr_pipeline``. ``api.main`` is not visible here, so confirm.
    """
    # Deferred import so the module-level environment setup is already in
    # place when the application module is first imported.
    from api.main import app

    state = app.state
    state.ocr_pipeline = mock_ocr_pipeline
    state.model_loaded = True

    with TestClient(app) as started_client:
        yield started_client
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_image_bytes() -> bytes:
|
|
"""
|
|
创建测试用的图片字节数据
|
|
|
|
生成一个简单的 100x100 白色 JPEG 图片
|
|
"""
|
|
import cv2
|
|
|
|
# 创建白色图片
|
|
image = np.ones((100, 100, 3), dtype=np.uint8) * 255
|
|
|
|
# 编码为 JPEG
|
|
success, encoded = cv2.imencode(".jpg", image)
|
|
assert success, "图片编码失败"
|
|
|
|
return encoded.tobytes()
|
|
|
|
|
|
@pytest.fixture
def sample_image_base64(sample_image_bytes) -> str:
    """
    Return the ``sample_image_bytes`` payload as a Base64 text string.
    """
    encoded = base64.b64encode(sample_image_bytes)
    return str(encoded, "utf-8")
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_png_bytes() -> bytes:
|
|
"""
|
|
创建测试用的 PNG 图片字节数据
|
|
"""
|
|
import cv2
|
|
|
|
image = np.ones((100, 100, 3), dtype=np.uint8) * 255
|
|
success, encoded = cv2.imencode(".png", image)
|
|
assert success, "PNG 编码失败"
|
|
|
|
return encoded.tobytes()
|
|
|
|
|
|
@pytest.fixture
def invalid_file_bytes() -> bytes:
    """
    Return a short plain-text payload that is not image data.
    """
    payload = b"This is not an image file content"
    return payload
|