Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| # Import the globally loaded models instance | |
| from model_loader import models | |
class Interferencer:
    """
    Performs inference using the FFT CNN model.

    Turns the model's logits for a single input into a human-readable label
    and the softmax probability of that label.
    """

    # Predicted class index -> human-readable label.
    # Ensure this mapping matches the labels used during training
    # (this code assumes 0 -> fake, 1 -> real). Indices not listed here are
    # reported as "unknown".
    LABEL_MAP = {0: "fake", 1: "real"}

    def __init__(self, model=None):
        """
        Initializes the interferencer with the loaded model.

        Args:
            model: Optional callable mapping an input tensor to logits. When
                omitted, the globally loaded ``models.fft_model`` is used, so
                existing ``Interferencer()`` calls behave exactly as before.
        """
        if model is None:
            model = models.fft_model
        # NOTE(review): dropout / batch-norm layers behave differently in
        # train mode; confirm the loader puts fft_model in .eval() mode.
        self.fft_model = model

    def predict(self, image_tensor: torch.Tensor) -> dict:
        """
        Takes a preprocessed image tensor and returns the classification result.

        Args:
            image_tensor (torch.Tensor): The preprocessed image tensor. The
                model output is treated as (batch, num_classes) logits
                (softmax is taken over dim 1), and ``.item()`` below requires
                the batch to hold exactly one sample; a larger batch raises
                RuntimeError.

        Returns:
            dict: ``{"classification": str, "confidence": float}``, where the
            confidence is the softmax probability of the predicted class.
        """
        # Pure inference: disabling autograd avoids building a computation
        # graph (and keeping its intermediate activations alive) per request.
        with torch.no_grad():
            # 1. Get model outputs (logits)
            outputs = self.fft_model(image_tensor)
            # 2. Convert logits to class probabilities
            probabilities = F.softmax(outputs, dim=1)
            # 3. Highest-probability class and its probability
            confidence, predicted_idx = torch.max(probabilities, 1)

        prediction = predicted_idx.item()
        # 4. Map the prediction to a human-readable label
        classification_label = self.LABEL_MAP.get(prediction, "unknown")
        return {
            "classification": classification_label,
            "confidence": confidence.item(),
        }
# Shared module-level instance, built once at import time and wrapping the
# globally loaded model; importers are meant to reuse it.
interferencer: Interferencer = Interferencer()