import logging
import os
import traceback
from io import BytesIO
from typing import Union

import cv2
import numpy as np
import torch
from PIL import Image

from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img
from basicsr.utils.options import parse


def setup_logger(name, log_level=logging.INFO):
    """Set up logger."""
    logger = logging.getLogger(name)
    logger.setLevel(log_level)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    return logger


logger = setup_logger(__name__)


class NAFNetDeblur:
    def __init__(self, config_path: str = 'options/test/REDS/NAFNet-width64.yml'):
        """
        Initialize the NAFNet deblurring model.

        Args:
            config_path: Path to the model configuration YAML file.
        """
        try:
            logger.info(f"Initializing NAFNet with config: {config_path}")

            # Resolve relative config paths against this module's directory.
            module_dir = os.path.dirname(os.path.abspath(__file__))
            if not os.path.isabs(config_path):
                config_path = os.path.join(module_dir, config_path)

            if not os.path.exists(config_path):
                error_msg = f"Config file not found: {config_path}"
                logger.error(error_msg)
                raise FileNotFoundError(error_msg)

            # Parse the BasicSR options file and disable distributed inference.
            opt = parse(config_path, is_train=False)
            opt['dist'] = False

            logger.info("Creating model")
            self.model = create_model(opt)

            try:
                self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                logger.info(f"Using device: {self.device}")
            except Exception as e:
                logger.warning(f"Failed to set device: {str(e)}")
                logger.warning("Falling back to CPU mode")
                self.device = torch.device('cpu')

            # Default input/output directories next to this module.
            self.inputs_dir = os.path.join(module_dir, 'inputs')
            self.outputs_dir = os.path.join(module_dir, 'outputs')
            os.makedirs(self.inputs_dir, exist_ok=True)
            os.makedirs(self.outputs_dir, exist_ok=True)

            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def imread(self, img_path):
        """Read an image from file and convert it to RGB."""
        img = cv2.imread(img_path)
        if img is None:
            return None
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def img2tensor(self, img, bgr2rgb=False, float32=True):
        """Normalize an image to [0, 1] and convert it to a tensor."""
        img = img.astype(np.float32) / 255.0
        return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)

    def deblur_image(self, image: Union[str, np.ndarray, bytes]) -> np.ndarray:
        """
        Deblur an image.

        Args:
            image: Input image as a file path, numpy array, or raw bytes.

        Returns:
            Deblurred image as an RGB numpy array.
        """
        try:
            # Load the input into an RGB numpy array.
            if isinstance(image, str):
                logger.info(f"Loading image from path: {image}")
                img = self.imread(image)
                if img is None:
                    raise ValueError(f"Failed to read image from {image}")
            elif isinstance(image, bytes):
                logger.info("Loading image from bytes")
                nparr = np.frombuffer(image, np.uint8)
                img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                if img is None:
                    # Fall back to PIL for formats OpenCV cannot decode.
                    pil_img = Image.open(BytesIO(image))
                    img = np.array(pil_img.convert('RGB'))
                else:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            elif isinstance(image, np.ndarray):
                logger.info("Processing image from numpy array")
                img = image.copy()
                # Rough heuristic: if the first pixel looks more blue than red,
                # assume the array is BGR and convert it to RGB.
                if img.ndim == 3 and img.shape[2] == 3 and img.dtype == np.uint8:
                    if img[0, 0, 0] > img[0, 0, 2]:
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            if img is None or img.size == 0:
                raise ValueError("Image is empty or invalid")

            logger.info(f"Image shape: {img.shape}, dtype: {img.dtype}")

            if img.ndim != 3 or img.shape[2] != 3:
                raise ValueError(f"Image must have 3 channels, got shape {img.shape}")

            # Downscale very large images to keep memory usage manageable.
            max_dim = max(img.shape[0], img.shape[1])
            if max_dim > 2000:
                scale_factor = 2000 / max_dim
                new_h = int(img.shape[0] * scale_factor)
                new_w = int(img.shape[1] * scale_factor)
                logger.warning(f"Image too large, resizing from {img.shape[:2]} to {(new_h, new_w)}")
                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

            logger.info("Converting image to tensor")
            img_tensor = self.img2tensor(img)

            logger.info("Running inference with model")
            with torch.no_grad():
                try:
                    self.model.feed_data(data={'lq': img_tensor.unsqueeze(dim=0)})

                    # Optionally split the image into grids for tiled inference.
                    use_grids = self.model.opt.get('val', {}).get('grids', False)
                    if use_grids:
                        self.model.grids()

                    self.model.test()

                    if use_grids:
                        self.model.grids_inverse()

                    visuals = self.model.get_current_visuals()
                    result = tensor2img([visuals['result']])
                except Exception as e:
                    logger.error(f"Error during model inference: {str(e)}")
                    logger.error(traceback.format_exc())
                    raise

            logger.info("Image deblurred successfully")
            return result
        except Exception as e:
            logger.error(f"Error in deblur_image: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def save_image(self, image: np.ndarray, output_path: str) -> str:
        """Save an RGB image to the given path and return the resolved path."""
        try:
            # OpenCV expects BGR channel order when writing.
            save_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            # Resolve relative paths against the outputs directory.
            if not os.path.isabs(output_path):
                output_path = os.path.join(self.outputs_dir, output_path)

            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            if not cv2.imwrite(output_path, save_img):
                raise IOError(f"cv2.imwrite failed for {output_path}")
            logger.info(f"Image saved to {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Error saving image: {str(e)}")
            logger.error(traceback.format_exc())
            raise
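

# Example usage (a minimal sketch, assuming a valid NAFNet-REDS config and its
# pretrained weights are available at the default config path; the file name
# 'inputs/blurry.png' below is hypothetical):
#
#   deblurrer = NAFNetDeblur()
#   restored = deblurrer.deblur_image('inputs/blurry.png')
#   deblurrer.save_image(restored, 'deblurred_blurry.png')  # lands in outputs/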


def main():
    """
    Main function to test the NAFNet deblurring model.

    Processes all images in the inputs directory and saves the results to the
    outputs directory.
    """
    try:
        deblur_model = NAFNetDeblur()

        inputs_dir = deblur_model.inputs_dir
        outputs_dir = deblur_model.outputs_dir

        input_files = [f for f in os.listdir(inputs_dir)
                       if os.path.isfile(os.path.join(inputs_dir, f))
                       and f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))]

        if not input_files:
            logger.warning(f"No image files found in {inputs_dir}")
            print(f"No image files found in {inputs_dir}. Please add some images and try again.")
            return

        logger.info(f"Found {len(input_files)} images to process")

        for input_file in input_files:
            try:
                input_path = os.path.join(inputs_dir, input_file)
                output_file = f"deblurred_{input_file}"
                output_path = os.path.join(outputs_dir, output_file)

                logger.info(f"Processing {input_file}...")
                deblurred_img = deblur_model.deblur_image(input_path)
                deblur_model.save_image(deblurred_img, output_path)
                logger.info(f"Saved result to {output_path}")
            except Exception as e:
                logger.error(f"Error processing {input_file}: {str(e)}")
                logger.error(traceback.format_exc())

        logger.info("Processing complete!")
    except Exception as e:
        logger.error(f"Error in main function: {str(e)}")
        logger.error(traceback.format_exc())


if __name__ == "__main__":
    main()