from pathlib import Path

import torch
from huggingface_hub import hf_hub_download

from model import FFTCNN  # Import the model architecture


class ModelLoader:
    """
    A class to load and hold the PyTorch CNN model.

    On construction the loader selects a device ("cuda" when available,
    otherwise "cpu"), downloads the checkpoint from the Hugging Face Hub and
    stores the ready-to-use model on ``self.fft_model``.
    """

    def __init__(self, model_repo_id: str, model_filename: str) -> None:
        """
        Initializes the ModelLoader and loads the model.

        Args:
            model_repo_id (str): The repository ID on Hugging Face.
            model_filename (str): The name of the model file (.pth) in the
                repository.

        Raises:
            Exception: Any error raised while downloading or loading the
                checkpoint is propagated unchanged.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.fft_model = self._load_fft_model(
            repo_id=model_repo_id, filename=model_filename
        )
        print("FFT CNN model loaded successfully.")

    def _load_fft_model(self, repo_id: str, filename: str) -> FFTCNN:
        """
        Downloads and loads the FFT CNN model from a Hugging Face Hub repository.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file (.pth) in the repository.

        Returns:
            The loaded PyTorch model object, moved to ``self.device`` and
            switched to evaluation mode.

        Raises:
            Exception: Any download or loading failure is reported on stdout
                and then re-raised unchanged.
        """
        print(f"Downloading FFT CNN model from Hugging Face repo: {repo_id}")
        try:
            # Download the model file from the Hub. It returns the cached path.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Model downloaded to: {model_path}")

            # The checkpoint comes from a remote repository and torch.load
            # unpickles it. Only a state_dict (tensors) is needed, so restrict
            # the unpickler with weights_only=True to avoid executing arbitrary
            # pickled code from the downloaded file.
            state_dict = torch.load(
                model_path,
                map_location=torch.device(self.device),
                weights_only=True,
            )

            # Initialize the model architecture and load the saved weights.
            model = FFTCNN()
            model.load_state_dict(state_dict)

            # Move the model to the selected device.
            model.to(self.device)
            # Set the model to evaluation mode.
            model.eval()
            return model
        except Exception as e:
            # Report the failure, then let the original exception propagate so
            # callers see the real error type and traceback.
            print(f"Error downloading or loading model from Hugging Face: {e}")
            raise


# --- Global Model Instance ---
# NOTE: the model is downloaded and loaded at import time; importing this
# module therefore performs network and disk I/O.
MODEL_REPO_ID = 'rhnsa/real_forged_classifier'
MODEL_FILENAME = 'fft_cnn_model_78.pth'
models = ModelLoader(model_repo_id=MODEL_REPO_ID, model_filename=MODEL_FILENAME)