from typing import IO

import cv2
import numpy as np
import torch
from torchvision import transforms

# Import the globally loaded models instance
from model_loader import models

class ImagePreprocessor:
    """
    Converts uploaded images into FFT magnitude-spectrum tensors for the FFT CNN model.
    """
    def __init__(self):
        """
        Stores the model's device and builds the resize-and-tensor transform pipeline.
        """
        self.device = models.device
        # Define the image transformations, matching the training process
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

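    def process(self, image_file: IO[bytes]) -> torch.Tensor:
        """
        Decodes an image file as grayscale, computes its centered FFT log-magnitude
        spectrum, and converts the spectrum into a model-ready tensor.

        Args:
            image_file (IO[bytes]): A binary file-like object containing the
                encoded image (e.g., from a file upload).

        Returns:
            torch.Tensor: A float tensor of shape (1, 1, 224, 224) with values
                in [0, 1], on the model's device.

        Raises:
            ValueError: If the file cannot be read or decoded as an image.
        """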
        try:
            # Read the raw bytes into a 1-D uint8 buffer
            image_np = np.frombuffer(image_file.read(), np.uint8)
            # Decode the buffer as a single-channel grayscale image
            img = cv2.imdecode(image_np, cv2.IMREAD_GRAYSCALE)
        except Exception as e:
            print(f"Error reading or decoding image: {e}")
            raise ValueError("Invalid or corrupted image file.") from e

        if img is None:
            raise ValueError("Could not decode image. File may be empty or corrupted.")

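        # 1. Apply the 2-D Fast Fourier Transform (FFT) and shift the
        #    zero-frequency component to the center of the spectrum
        f = np.fft.fft2(img)
        fshift = np.fft.fftshift(f)
        # Log-scale the magnitude to compress its dynamic range (add 1 to avoid log(0))
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)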

        # Min-max normalize the magnitude spectrum to [0, 255] and convert to 8-bit
        magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX)
        magnitude_spectrum = magnitude_spectrum.astype(np.uint8)

        # 2. Apply torchvision transforms (PIL conversion, resize to 224x224, scale to [0, 1])
        image_tensor = self.transform(magnitude_spectrum)

        # Add a batch dimension and move to the correct device
        image_tensor = image_tensor.unsqueeze(0).to(self.device)

        return image_tensor
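

# A minimal, NumPy-only sketch of the FFT step in `process` above (fft2 ->
# fftshift -> log magnitude -> min-max scaling to uint8), which may be handy
# for unit-testing the preprocessing without OpenCV or the model loader.
# The helper name `reference_magnitude_spectrum` is hypothetical and is not
# used by ImagePreprocessor. It assumes cv2.NORM_MINMAX is equivalent to
# (x - min) / (max - min) * 255; results may differ from the OpenCV path by
# one gray level because of floating-point rounding before the uint8 cast.
import numpy as np


def reference_magnitude_spectrum(img: np.ndarray) -> np.ndarray:
    """Hypothetical helper: centered log-magnitude spectrum of a 2-D image as uint8."""
    # Shift the zero-frequency (DC) term to the center, as in `process`
    spectrum = np.fft.fftshift(np.fft.fft2(img))
    # Same log scaling as `process`; the +1 keeps log() finite at zero magnitude
    log_mag = 20 * np.log(np.abs(spectrum) + 1)
    lo, hi = log_mag.min(), log_mag.max()
    if hi == lo:
        # Flat spectrum (e.g. an all-zero image): return zeros to avoid dividing by zero
        return np.zeros(img.shape, dtype=np.uint8)
    # Min-max scale to [0, 255], then truncate to 8-bit like the uint8 cast in `process`
    scaled = (log_mag - lo) / (hi - lo) * 255
    return scaled.astype(np.uint8)
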

# Create a single instance of the preprocessor
preprocessor = ImagePreprocessor()