import os
import cv2
import numpy as np
import torch
import yaml
from typing import Optional, Tuple, Union
from io import BytesIO
from PIL import Image
import logging
import traceback

from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
from basicsr.utils.options import parse


# Configure logging
def setup_logger(name: str, log_level: int = logging.INFO) -> logging.Logger:
    """Return the logger *name* with a single console handler attached.

    Calling this again for the same name (e.g. on module reload) reuses the
    existing handler instead of stacking duplicates, which would otherwise
    emit every record several times.
    """
    logger = logging.getLogger(name)
    logger.setLevel(log_level)
    if not logger.handlers:
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger


logger = setup_logger(__name__)

# Images whose longer side exceeds this many pixels are downscaled before
# inference unless the caller passes a different ``max_dim`` to deblur_image.
DEFAULT_MAX_DIM = 2000


class NAFNetDeblur:
    """Wrapper around a basicsr NAFNet model for single-image deblurring.

    Images accepted by and returned from the public methods use RGB channel
    order; conversion to/from OpenCV's BGR happens only at file I/O.
    """

    def __init__(self, config_path: str = 'options/test/REDS/NAFNet-width64.yml'):
        """
        Initialize the NAFNet deblurring model.

        Args:
            config_path: Path to the model configuration YAML file. Relative
                paths are resolved against this module's directory, not the
                current working directory.

        Raises:
            FileNotFoundError: If the config file does not exist.
            Exception: Anything raised by basicsr while parsing the config or
                building the model is logged and re-raised.
        """
        try:
            logger.info(f"Initializing NAFNet with config: {config_path}")

            module_dir = os.path.dirname(os.path.abspath(__file__))
            if not os.path.isabs(config_path):
                config_path = os.path.join(module_dir, config_path)

            if not os.path.exists(config_path):
                error_msg = f"Config file not found: {config_path}"
                logger.error(error_msg)
                raise FileNotFoundError(error_msg)

            opt = parse(config_path, is_train=False)
            opt["dist"] = False  # single-process inference; no distributed setup

            logger.info("Creating model")
            self.model = create_model(opt)

            # NOTE(review): self.device is exposed to callers but is not passed
            # to the model; basicsr appears to pick its own device from the
            # config -- confirm before relying on this attribute.
            try:
                self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                logger.info(f"Using device: {self.device}")
            except Exception as e:
                logger.warning(f"Failed to set device. Error: {str(e)}")
                logger.warning("Using CPU mode")
                self.device = torch.device('cpu')

            # Working directories used by main() for batch processing.
            self.inputs_dir = os.path.join(module_dir, 'inputs')
            self.outputs_dir = os.path.join(module_dir, 'outputs')
            os.makedirs(self.inputs_dir, exist_ok=True)
            os.makedirs(self.outputs_dir, exist_ok=True)

            logger.info("Model initialized successfully")
        except Exception as e:
            logger.exception(f"Failed to initialize model: {str(e)}")
            raise

    def imread(self, img_path: str) -> Optional[np.ndarray]:
        """Read *img_path* from disk and return it as an RGB array.

        Returns ``None`` when OpenCV cannot read the file (missing path,
        unsupported format). The conversion step is skipped in that case
        instead of raising an opaque ``cv2.error`` on a ``None`` image.
        """
        img = cv2.imread(img_path)
        if img is None:
            return None
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def img2tensor(self, img: np.ndarray, bgr2rgb: bool = False, float32: bool = True) -> torch.Tensor:
        """Scale a 0-255 HWC image to [0, 1] and convert it to a CHW tensor.

        ``bgr2rgb`` defaults to False because images in this class are
        already RGB by the time they get here.
        """
        img = img.astype(np.float32) / 255.0
        return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)

    def deblur_image(self, image: Union[str, np.ndarray, bytes],
                     max_dim: Optional[int] = DEFAULT_MAX_DIM) -> np.ndarray:
        """
        Deblur an image.

        Args:
            image: A file path, encoded image bytes (e.g. a file upload), or
                an RGB array of shape (H, W, 3). Arrays are taken as RGB;
                convert BGR arrays (e.g. straight from ``cv2.imread``) first.
                Pixel values are assumed to be in 0-255.
            max_dim: If the image's longer side exceeds this, it is
                downscaled (aspect ratio preserved) before inference.
                ``None`` disables the limit.

        Returns:
            The deblurred image in RGB channel order, as produced by basicsr's
            ``tensor2img`` (uint8 by default). Pass it directly to
            :meth:`save_image`.

        Raises:
            ValueError: Unsupported input type, undecodable or empty image,
                or an image without exactly 3 channels.
        """
        try:
            img = self._load_image(image)
            if max_dim is not None:
                img = self._limit_size(img, max_dim)
            result = self._run_inference(img)
            logger.info("Image deblurred successfully")
            return result
        except Exception as e:
            logger.exception(f"Error in deblur_image: {str(e)}")
            raise

    def _load_image(self, image: Union[str, np.ndarray, bytes]) -> np.ndarray:
        """Decode a path, bytes or array into a validated (H, W, 3) RGB array."""
        if isinstance(image, str):
            logger.info(f"Loading image from path: {image}")
            img = self.imread(image)
            if img is None:
                raise ValueError(f"Failed to read image from {image}")
        elif isinstance(image, bytes):
            logger.info("Loading image from bytes")
            if not image:
                # cv2.imdecode asserts on an empty buffer; fail with a clear error.
                raise ValueError("Image is empty or invalid")
            nparr = np.frombuffer(image, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                # OpenCV could not decode it; let PIL try (it raises if it can't either).
                with Image.open(BytesIO(image)) as pil_img:
                    img = np.array(pil_img.convert('RGB'))
            else:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        elif isinstance(image, np.ndarray):
            logger.info("Processing image from numpy array")
            # No BGR auto-detection: channel order cannot be inferred from
            # pixel values, so arrays are taken as RGB as documented.
            img = image.copy()
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")

        if img is None or img.size == 0:
            raise ValueError("Image is empty or invalid")

        logger.info(f"Image shape: {img.shape}, dtype: {img.dtype}")

        if img.ndim != 3 or img.shape[2] != 3:
            raise ValueError(f"Image must have 3 channels, got shape {img.shape}")
        return img

    @staticmethod
    def _limit_size(img: np.ndarray, max_dim: int) -> np.ndarray:
        """Downscale *img* so its longer side is at most *max_dim* pixels."""
        longest = max(img.shape[0], img.shape[1])
        if longest <= max_dim:
            return img
        scale_factor = max_dim / longest
        # Clamp to 1 so very thin images never collapse to a zero-sized side.
        new_h = max(1, int(img.shape[0] * scale_factor))
        new_w = max(1, int(img.shape[1] * scale_factor))
        logger.warning(f"Image too large, resizing from {img.shape[:2]} to {(new_h, new_w)}")
        # cv2.resize takes (width, height).
        return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    def _run_inference(self, img: np.ndarray) -> np.ndarray:
        """Run the model on an RGB image and return the RGB result."""
        logger.info("Converting image to tensor")
        img_tensor = self.img2tensor(img)

        logger.info("Running inference with model")
        # A config may request grid-based inference via val.grids; a missing
        # 'val' section is treated as grids disabled.
        use_grids = self.model.opt.get('val', {}).get('grids', False)
        with torch.no_grad():
            self.model.feed_data(data={'lq': img_tensor.unsqueeze(dim=0)})
            if use_grids:
                self.model.grids()
            self.model.test()
            if use_grids:
                self.model.grids_inverse()
            visuals = self.model.get_current_visuals()
        # tensor2img converts to BGR by default (for cv2.imwrite). Keep RGB so
        # the result matches the input convention and save_image(), which
        # performs the RGB->BGR conversion itself.
        return tensor2img([visuals['result']], rgb2bgr=False)

    def save_image(self, image: np.ndarray, output_path: str) -> str:
        """Write an RGB image to *output_path* and return the path written.

        Relative paths are placed under ``self.outputs_dir``; missing parent
        directories are created.

        Raises:
            OSError: If ``cv2.imwrite`` reports that the file was not written.
        """
        try:
            # OpenCV expects BGR on disk, while images in this class are RGB.
            save_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            if not os.path.isabs(output_path):
                output_path = os.path.join(self.outputs_dir, output_path)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # cv2.imwrite signals some failures only through a False return.
            if not cv2.imwrite(output_path, save_img):
                raise OSError(f"cv2.imwrite failed to write {output_path}")
            logger.info(f"Image saved to {output_path}")
            return output_path
        except Exception as e:
            logger.exception(f"Error saving image: {str(e)}")
            raise


def main() -> int:
    """
    Deblur every image in the model's inputs directory.

    Results are written to the outputs directory as ``deblurred_<name>``.
    A failure on one image is logged and does not stop the others.

    Returns:
        0 if all images were processed (or there was nothing to process),
        1 if model setup or any individual image failed.
    """
    try:
        deblur_model = NAFNetDeblur()
        inputs_dir = deblur_model.inputs_dir
        outputs_dir = deblur_model.outputs_dir

        extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
        # Sorted so batch runs process files in a reproducible order.
        input_files = sorted(
            f for f in os.listdir(inputs_dir)
            if f.lower().endswith(extensions) and os.path.isfile(os.path.join(inputs_dir, f))
        )

        if not input_files:
            logger.warning(f"No image files found in {inputs_dir}")
            print(f"No image files found in {inputs_dir}. Please add some images and try again.")
            return 0

        logger.info(f"Found {len(input_files)} images to process")

        failures = 0
        for input_file in input_files:
            input_path = os.path.join(inputs_dir, input_file)
            output_path = os.path.join(outputs_dir, f"deblurred_{input_file}")
            logger.info(f"Processing {input_file}...")
            try:
                deblurred_img = deblur_model.deblur_image(input_path)
                deblur_model.save_image(deblurred_img, output_path)
            except Exception as e:
                failures += 1
                logger.exception(f"Error processing {input_file}: {str(e)}")
                continue
            logger.info(f"Saved result to {output_path}")

        logger.info("Processing complete!")
        return 1 if failures else 0
    except Exception as e:
        logger.exception(f"Error in main function: {str(e)}")
        return 1


if __name__ == "__main__":
    raise SystemExit(main())