# Image-Debluring/NAFNET/deblur_module.py
# Last commit 61d360d by sayed99: "initialized both deblurers"
import os
import cv2
import numpy as np
import torch
import yaml
from typing import Optional, Tuple, Union
from io import BytesIO
from PIL import Image
import logging
import traceback
from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
from basicsr.utils.options import parse
# Configure logging
def setup_logger(name, log_level=logging.INFO):
    """Return the logger called ``name`` with a formatted console handler attached.

    Safe to call repeatedly. The level is (re)applied on every call, but the
    console handler is added only once per logger. Repeated calls (or a module
    re-import) therefore do not produce duplicated log lines.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_level: Level applied to the logger (default ``logging.INFO``).

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(log_level)
    # Tag our handler by name so we can recognise it on later calls without
    # disturbing any handlers other code may have attached to this logger.
    handler_name = f"{name}.console"
    if not any(h.get_name() == handler_name for h in logger.handlers):
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler = logging.StreamHandler()
        console_handler.set_name(handler_name)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger


# Module-level logger shared by the code in this module.
logger = setup_logger(__name__)
class NAFNetDeblur:
    """Wrapper around a BasicSR/NAFNet model for single-image deblurring.

    Images are handled in RGB channel order:

    * File and bytes inputs are converted from OpenCV's BGR to RGB.
    * ``deblur_image`` returns an RGB ``uint8`` array.
    * ``save_image`` converts RGB back to BGR before writing with OpenCV.
    """

    def __init__(self, config_path: str = 'options/test/REDS/NAFNet-width64.yml'):
        """
        Initialize the NAFNet deblurring model.

        Args:
            config_path: Path to the model configuration YAML file. Relative
                paths are resolved against this module's directory.

        Raises:
            FileNotFoundError: If the config file does not exist.
            Exception: Anything raised while parsing the config or creating
                the model is logged and re-raised.
        """
        try:
            logger.info(f"Initializing NAFNet with config: {config_path}")
            # Make paths relative to the module directory
            module_dir = os.path.dirname(os.path.abspath(__file__))
            if not os.path.isabs(config_path):
                config_path = os.path.join(module_dir, config_path)
            if not os.path.exists(config_path):
                error_msg = f"Config file not found: {config_path}"
                logger.error(error_msg)
                raise FileNotFoundError(error_msg)

            opt = parse(config_path, is_train=False)
            # Inference runs in a single process: turn off distributed mode.
            opt["dist"] = False
            logger.info("Creating model")
            self.model = create_model(opt)

            # Public attribute only; nothing in this class moves tensors to it.
            # NOTE(review): the model's own device placement presumably happens
            # inside create_model (from opt) — confirm.
            try:
                self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                logger.info(f"Using device: {self.device}")
            except Exception as e:
                logger.warning(f"Failed to set device. Error: {str(e)}")
                logger.warning("Using CPU mode")
                self.device = torch.device('cpu')

            # Default locations used by main() and by save_image() for relative paths
            self.inputs_dir = os.path.join(module_dir, 'inputs')
            self.outputs_dir = os.path.join(module_dir, 'outputs')
            os.makedirs(self.inputs_dir, exist_ok=True)
            os.makedirs(self.outputs_dir, exist_ok=True)
            logger.info("Model initialized successfully")
        except Exception as e:
            logger.exception(f"Failed to initialize model: {str(e)}")
            raise

    def imread(self, img_path):
        """Read an image file and return it as an RGB array.

        Returns ``None`` when OpenCV cannot read the file (missing path,
        unsupported format), matching ``cv2.imread``. Previously the failure
        surfaced as an error from ``cv2.cvtColor(None)`` instead.
        """
        img = cv2.imread(img_path)
        if img is None:
            return None
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def img2tensor(self, img, bgr2rgb=False, float32=True):
        """Scale pixel values by 1/255 and convert the image with basicsr's img2tensor.

        NOTE(review): the /255 scaling assumes 8-bit pixel values. Non-uint8
        ndarray inputs (e.g. floats already in [0, 1]) are scaled too. Confirm
        that callers only pass uint8 images.
        """
        img = img.astype(np.float32) / 255.0
        return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)

    def deblur_image(self, image: Union[str, np.ndarray, bytes],
                     max_dim: Optional[int] = 2000) -> np.ndarray:
        """
        Deblur an image.

        Args:
            image: Input image as a file path, a numpy array (H x W x 3), or
                encoded image bytes (e.g. from a file upload).
            max_dim: Images whose longer side exceeds this many pixels are
                downscaled (aspect ratio preserved) before inference. Pass
                ``None`` to disable the limit.

        Returns:
            Deblurred image as an RGB ``uint8`` numpy array.

        Raises:
            ValueError: If the input cannot be decoded, is empty, is of an
                unsupported type, or does not have exactly 3 channels.
        """
        try:
            img = self._load_image(image)
            if img is None or img.size == 0:
                raise ValueError("Image is empty or invalid")
            logger.info(f"Image shape: {img.shape}, dtype: {img.dtype}")
            if img.ndim != 3 or img.shape[2] != 3:
                raise ValueError(f"Image must have 3 channels, got shape {img.shape}")

            img = self._limit_size(img, max_dim)

            logger.info("Converting image to tensor")
            img_tensor = self.img2tensor(img)
            logger.info("Running inference with model")
            result = self._run_model(img_tensor)
            logger.info("Image deblurred successfully")
            return result
        except Exception as e:
            # Single log record with traceback; the exception propagates unchanged.
            logger.exception(f"Error in deblur_image: {str(e)}")
            raise

    def _load_image(self, image):
        """Decode a path, bytes, or ndarray input into an RGB numpy array."""
        if isinstance(image, str):
            logger.info(f"Loading image from path: {image}")
            img = self.imread(image)
            if img is None:
                raise ValueError(f"Failed to read image from {image}")
            return img

        if isinstance(image, bytes):
            logger.info("Loading image from bytes")
            if not image:
                # cv2.imdecode rejects an empty buffer with an opaque error
                raise ValueError("Image is empty or invalid")
            nparr = np.frombuffer(image, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is not None:
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # OpenCV could not decode the bytes; try PIL as a fallback
            with Image.open(BytesIO(image)) as pil_img:
                return np.array(pil_img.convert('RGB'))

        if isinstance(image, np.ndarray):
            logger.info("Processing image from numpy array")
            img = image.copy()
            # Guard on ndim/size first so 2-D or empty arrays reach the
            # shape validation in deblur_image instead of raising IndexError.
            if img.ndim == 3 and img.shape[2] == 3 and img.size > 0 and img.dtype == np.uint8:
                # NOTE(review): kept for compatibility. This single-pixel test
                # is not a reliable BGR detector: any image whose top-left
                # pixel has channel 0 > channel 2 gets its channels swapped.
                # Callers should ideally pass RGB arrays.
                if img[0, 0, 0] > img[0, 0, 2]:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            return img

        raise ValueError(f"Unsupported image type: {type(image)}")

    @staticmethod
    def _limit_size(img, max_dim):
        """Downscale img so its longer side is at most max_dim pixels (None disables)."""
        if max_dim is None:
            return img
        longest = max(img.shape[0], img.shape[1])
        if longest <= max_dim:
            return img
        scale_factor = max_dim / longest
        # max(1, ...) stops extreme aspect ratios from collapsing an axis to 0,
        # which cv2.resize rejects.
        new_h = max(1, int(img.shape[0] * scale_factor))
        new_w = max(1, int(img.shape[1] * scale_factor))
        logger.warning(f"Image too large, resizing from {img.shape[:2]} to {(new_h, new_w)}")
        return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    def _run_model(self, img_tensor):
        """Run the model on a single image tensor and return an RGB uint8 array."""
        # A missing/None 'val' section means "no grid inference" instead of KeyError
        use_grids = (self.model.opt.get('val') or {}).get('grids', False)
        with torch.no_grad():
            # unsqueeze adds the batch dimension expected by feed_data
            self.model.feed_data(data={'lq': img_tensor.unsqueeze(dim=0)})
            if use_grids:
                self.model.grids()
            self.model.test()
            if use_grids:
                self.model.grids_inverse()
            visuals = self.model.get_current_visuals()
        # The input tensor is RGB (imread converts, img2tensor uses bgr2rgb=False).
        # tensor2img defaults to rgb2bgr=True, which returned BGR here. save_image
        # then swapped red and blue, so request RGB explicitly.
        return tensor2img([visuals['result']], rgb2bgr=False)

    def save_image(self, image: np.ndarray, output_path: str) -> str:
        """Save an RGB image and return the path that was written.

        Relative paths are placed under ``self.outputs_dir``. Missing parent
        directories are created.

        Raises:
            OSError: If OpenCV reports that the file could not be written.
        """
        try:
            # OpenCV writes BGR
            save_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            if not os.path.isabs(output_path):
                output_path = os.path.join(self.outputs_dir, output_path)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            # cv2.imwrite reports many failures via a False return, not an exception
            if not cv2.imwrite(output_path, save_img):
                raise OSError(f"Failed to write image to {output_path}")
            logger.info(f"Image saved to {output_path}")
            return output_path
        except Exception as e:
            logger.exception(f"Error saving image: {str(e)}")
            raise
def main():
    """
    Deblur every image in the model's inputs directory.

    Each result is written to the outputs directory as ``deblurred_<name>``.
    A failure on one image is logged and does not stop the others. Any other
    error (e.g. model initialisation) is logged and main returns normally.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')
    try:
        deblur_model = NAFNetDeblur()
        inputs_dir = deblur_model.inputs_dir
        outputs_dir = deblur_model.outputs_dir

        # Sorted so that runs process files in a reproducible order
        # (os.listdir order is arbitrary).
        input_files = sorted(
            f for f in os.listdir(inputs_dir)
            if os.path.isfile(os.path.join(inputs_dir, f))
            and f.lower().endswith(image_exts)
        )
        if not input_files:
            logger.warning(f"No image files found in {inputs_dir}")
            print(f"No image files found in {inputs_dir}. Please add some images and try again.")
            return

        logger.info(f"Found {len(input_files)} images to process")
        failed = []
        for input_file in input_files:
            input_path = os.path.join(inputs_dir, input_file)
            output_path = os.path.join(outputs_dir, f"deblurred_{input_file}")
            logger.info(f"Processing {input_file}...")
            try:
                deblurred_img = deblur_model.deblur_image(input_path)
                deblur_model.save_image(deblurred_img, output_path)
            except Exception as e:
                # Keep going: one bad image should not abort the whole batch
                logger.exception(f"Error processing {input_file}: {str(e)}")
                failed.append(input_file)
                continue
            logger.info(f"Saved result to {output_path}")

        if failed:
            logger.warning(f"{len(failed)} of {len(input_files)} images failed: {', '.join(failed)}")
        logger.info("Processing complete!")
    except Exception as e:
        logger.exception(f"Error in main function: {str(e)}")


if __name__ == "__main__":
    main()