File size: 3,717 Bytes
6ab520d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.auth.jwt import get_current_user
from fastapi import UploadFile
from httpx import AsyncClient
from main import app
# 💡 NOTE Run tests with: pytest tests/test_router_model.py -v
@pytest.mark.asyncio
async def test_predict():
    """POST a PNG upload to ``/model/predict`` and check the success payload.

    The router's collaborators are replaced for the duration of the request:

    * ``get_current_user`` is overridden so the route receives a ``MagicMock``.
    * ``utils.get_file_hash`` returns ``"fakehash123"``.
    * ``model_predict`` is an ``AsyncMock`` returning ``("cat", 0.95)``.
    * ``os.path.exists`` returns ``False`` and ``builtins.open`` is a
      ``MagicMock`` (presumably so the route "saves" the upload without
      touching the real filesystem -- confirm against the router).
    """
    image_bytes = b"fake-image-data"

    mock_current_user = MagicMock()
    mock_current_user.return_value = "testtoken"

    # patch.dict restores app.dependency_overrides when the block exits, even
    # if the request or an assertion raises, so the auth override cannot leak
    # into other tests that share the same ``app`` object.
    with patch.dict(
        app.dependency_overrides, {get_current_user: lambda: mock_current_user}
    ), patch(
        "app.model.router.utils.get_file_hash", return_value="fakehash123"
    ), patch(
        "app.model.router.model_predict", new_callable=AsyncMock
    ) as mock_model_predict, patch(
        "app.model.router.os.path.exists", return_value=False
    ), patch("builtins.open", new_callable=MagicMock):
        mock_model_predict.return_value = ("cat", 0.95)

        async with AsyncClient(app=app, base_url="http://test") as ac:
            response = await ac.post(
                "/model/predict",
                files={"file": ("test_image.png", image_bytes, "image/png")},
                headers={"Authorization": "Bearer testtoken"},
            )

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["success"] is True
    assert response_data["prediction"] == "cat"
    assert response_data["score"] == 0.95
    assert response_data["image_file_name"] == "fakehash123"
@pytest.mark.asyncio
async def test_predict_fails_bad_extension():
    """POST a ``.pdf``-named upload to ``/model/predict`` and expect a 400.

    The upload is named ``test_image.pdf`` but still declares the
    ``image/png`` content type. The test asserts the route answers with
    status 400 and the detail ``"File type is not supported."``.

    The same collaborators as in the happy-path test are patched
    (``get_file_hash``, ``model_predict``, ``os.path.exists`` and
    ``builtins.open``) so the request cannot reach the real model or disk.
    """
    image_bytes = b"fake-image-data"

    mock_current_user = MagicMock()
    mock_current_user.return_value = "testtoken"

    # patch.dict restores app.dependency_overrides when the block exits, even
    # if the request or an assertion raises, so the auth override cannot leak
    # into other tests that share the same ``app`` object.
    with patch.dict(
        app.dependency_overrides, {get_current_user: lambda: mock_current_user}
    ), patch(
        "app.model.router.utils.get_file_hash", return_value="fakehash123"
    ), patch(
        "app.model.router.model_predict", new_callable=AsyncMock
    ) as mock_model_predict, patch(
        "app.model.router.os.path.exists", return_value=False
    ), patch("builtins.open", new_callable=MagicMock):
        mock_model_predict.return_value = ("cat", 0.95)

        async with AsyncClient(app=app, base_url="http://test") as ac:
            response = await ac.post(
                "/model/predict",
                files={"file": ("test_image.pdf", image_bytes, "image/png")},
                headers={"Authorization": "Bearer testtoken"},
            )

    assert response.status_code == 400
    assert response.json() == {"detail": "File type is not supported."}
|