File size: 3,717 Bytes
6ab520d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from app.auth.jwt import get_current_user
from fastapi import UploadFile
from httpx import ASGITransport, AsyncClient
from main import app

# NOTE: Run these tests with: pytest tests/test_router_model.py -v


@pytest.mark.asyncio
async def test_predict():
    """A supported image upload to POST /model/predict returns the mocked prediction.

    Authentication is bypassed with a FastAPI dependency override. The
    router's file hashing, model prediction, `os.path.exists` check and
    `open` are patched, so no real model or disk I/O runs.
    """
    image_bytes = b"fake-image-data"

    # Stand-in for the user that get_current_user would normally resolve.
    mock_user = MagicMock()
    mock_user.id = 1
    app.dependency_overrides[get_current_user] = lambda: mock_user

    try:
        with patch(
            "app.model.router.utils.get_file_hash", return_value="fakehash123"
        ), patch(
            "app.model.router.model_predict",
            new_callable=AsyncMock,
            return_value=("cat", 0.95),
        ) as mock_model_predict, patch(
            # Report the upload as not yet stored, so the router takes its
            # write path (which hits the patched `open` below).
            "app.model.router.os.path.exists",
            return_value=False,
        ), patch(
            "builtins.open", new_callable=MagicMock
        ):
            # httpx >= 0.28 removed AsyncClient(app=...); an explicit ASGI
            # transport works on both old and new httpx releases.
            async with AsyncClient(
                transport=ASGITransport(app=app), base_url="http://test"
            ) as ac:
                response = await ac.post(
                    "/model/predict",
                    files={"file": ("test_image.png", image_bytes, "image/png")},
                    headers={"Authorization": "Bearer testtoken"},
                )
            mock_model_predict.assert_awaited_once()
    finally:
        # Remove the override so it does not leak into other tests sharing `app`.
        app.dependency_overrides.pop(get_current_user, None)

    # The response body is fully read by ac.post(), so it is safe to inspect
    # after the client and patches have been torn down.
    assert response.status_code == 200

    response_data = response.json()
    assert response_data["success"] is True
    assert response_data["prediction"] == "cat"
    assert response_data["score"] == 0.95
    assert response_data["image_file_name"] == "fakehash123"


@pytest.mark.asyncio
async def test_predict_fails_bad_extension():
    """POST /model/predict rejects a file whose extension is not supported.

    The upload is named `test_image.pdf` but still claims an `image/png`
    content type. The test expects a 400 with the router's
    "File type is not supported." detail, and expects the model never to be
    invoked. Auth and the router's side-effecting helpers are mocked exactly
    as in the happy-path test, so a regression cannot reach a real model or
    touch disk.
    """
    image_bytes = b"fake-image-data"

    # Stand-in for the user that get_current_user would normally resolve.
    mock_user = MagicMock()
    mock_user.id = 1
    app.dependency_overrides[get_current_user] = lambda: mock_user

    try:
        with patch(
            "app.model.router.utils.get_file_hash", return_value="fakehash123"
        ), patch(
            "app.model.router.model_predict",
            new_callable=AsyncMock,
            return_value=("cat", 0.95),
        ) as mock_model_predict, patch(
            "app.model.router.os.path.exists", return_value=False
        ), patch(
            "builtins.open", new_callable=MagicMock
        ):
            # httpx >= 0.28 removed AsyncClient(app=...); an explicit ASGI
            # transport works on both old and new httpx releases.
            async with AsyncClient(
                transport=ASGITransport(app=app), base_url="http://test"
            ) as ac:
                response = await ac.post(
                    "/model/predict",
                    files={"file": ("test_image.pdf", image_bytes, "image/png")},
                    headers={"Authorization": "Bearer testtoken"},
                )
            # A rejected upload must never reach the model.
            mock_model_predict.assert_not_awaited()
    finally:
        # Remove the override so it does not leak into other tests sharing `app`.
        app.dependency_overrides.pop(get_current_user, None)

    assert response.status_code == 400
    assert response.json() == {"detail": "File type is not supported."}