""" Colorize model wrapper using FastAI GAN Colorization Model Hammad712/GAN-Colorization-Model """ from __future__ import annotations import logging import os from typing import Tuple # Ensure cache directory is set before any HF imports # (main.py should have set these, but ensure they're set here too) cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache") os.environ["HF_HOME"] = cache_dir os.environ["TRANSFORMERS_CACHE"] = cache_dir os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir os.environ["HF_HUB_CACHE"] = cache_dir os.environ["XDG_CACHE_HOME"] = cache_dir import torch from PIL import Image from fastai.vision.all import * from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files from app.config import settings logger = logging.getLogger(__name__) def _ensure_cache_dir() -> str: cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache") try: os.makedirs(cache_dir, exist_ok=True) except Exception as exc: logger.warning("Could not create cache directory %s: %s", cache_dir, exc) # Ensure all cache env vars point to this directory os.environ["HF_HOME"] = cache_dir os.environ["TRANSFORMERS_CACHE"] = cache_dir os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir os.environ["HF_HUB_CACHE"] = cache_dir os.environ["XDG_CACHE_HOME"] = cache_dir return cache_dir class ColorizeModel: """Colorization model using FastAI GAN model.""" def __init__(self, model_id: str | None = None) -> None: self.cache_dir = _ensure_cache_dir() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.environ.setdefault("OMP_NUM_THREADS", "1") # Use FastAI model ID from config or default self.model_id = model_id or settings.MODEL_ID self.output_caption = getattr(settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model") logger.info("Loading FastAI GAN Colorization model: %s", self.model_id) try: # Try using from_pretrained_fastai first try: self.learn = from_pretrained_fastai(self.model_id) logger.info("FastAI GAN Colorization model loaded successfully via from_pretrained_fastai") except Exception as e1: logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1)) # Fallback: manually download and load the model file # First, list files in the repository to find the actual model file hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") try: repo_files = list_repo_files(repo_id=self.model_id, token=hf_token) logger.info("Repository files: %s", repo_files) # Look for .pkl files (FastAI) or .pt files (PyTorch) pkl_files = [f for f in repo_files if f.endswith('.pkl')] pt_files = [f for f in repo_files if f.endswith('.pt')] if pkl_files: model_filenames = pkl_files logger.info("Found .pkl files in repository: %s", pkl_files) model_type = "fastai" elif pt_files: model_filenames = pt_files logger.info("Found .pt files in repository: %s", pt_files) model_type = "pytorch" else: # Fallback to common filenames model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl", "generator.pt"] model_type = "fastai" # Default assumption except Exception as list_err: logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err)) # Fallback to common filenames model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl", "generator.pt"] model_type = "fastai" model_path = None for filename in model_filenames: try: model_path = hf_hub_download( repo_id=self.model_id, filename=filename, cache_dir=self.cache_dir, token=hf_token ) logger.info("Found model file: %s", filename) # Determine model type from extension if filename.endswith('.pt'): model_type = "pytorch" elif filename.endswith('.pkl'): model_type = "fastai" break except Exception as dl_err: logger.debug("Failed to download %s: %s", filename, str(dl_err)) continue if model_path and os.path.exists(model_path): if model_type == "pytorch": # Load PyTorch model - this is a GAN generator logger.info("Loading PyTorch model from: %s", model_path) # Note: This requires knowing the model architecture # For now, we'll try to load it and see if it works logger.warning("PyTorch model loading not fully implemented. This model may not work correctly.") raise RuntimeError( f"Repository '{self.model_id}' contains a PyTorch model (generator.pt), " f"not a FastAI model. FastAI models must be .pkl files created with FastAI's export. " f"Please use a FastAI-compatible colorization model, or switch to a different model backend." ) else: # Load the model using FastAI's load_learner logger.info("Loading FastAI model from: %s", model_path) self.learn = load_learner(model_path) logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path) else: # If no model file found, raise error with more details raise RuntimeError( f"Could not find model file in repository '{self.model_id}'. " f"Tried: {', '.join(model_filenames)}. " f"Original error: {str(e1)}" ) except Exception as e: error_msg = ( f"Failed to load FastAI model '{self.model_id}'. " f"Error: {str(e)}\n" f"Please check the MODEL_ID environment variable. " f"Default model: 'Hammad712/GAN-Colorization-Model'" ) logger.error(error_msg) raise RuntimeError(error_msg) from e def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]: """ Colorize a grayscale or color image using FastAI GAN model. Args: image: PIL Image (grayscale or color) num_inference_steps: Ignored for FastAI model (kept for API compatibility) Returns: Tuple of (colorized PIL Image, caption string) """ try: original_size = image.size # Ensure image is RGB if image.mode != "RGB": image = image.convert("RGB") # FastAI predict expects a PIL Image logger.info("Running FastAI GAN colorization...") # Use the model's predict method # FastAI predict for image models typically returns the output image directly # or as the first element of a tuple prediction = self.learn.predict(image) # Extract the colorized image from prediction # Handle different return types from FastAI if isinstance(prediction, (list, tuple)): # If tuple/list, first element is usually the prediction colorized = prediction[0] if len(prediction) > 0 else image else: # Direct return colorized = prediction # Ensure we have a PIL Image if not isinstance(colorized, Image.Image): # If it's a tensor, convert to PIL if isinstance(colorized, torch.Tensor): # Handle tensor conversion if colorized.dim() == 4: colorized = colorized[0] # Remove batch dimension if colorized.dim() == 3: # Convert CHW to HWC and denormalize if needed colorized = colorized.permute(1, 2, 0).cpu() # Clamp values to [0, 1] if float, or [0, 255] if uint8 if colorized.dtype == torch.float32 or colorized.dtype == torch.float16: colorized = torch.clamp(colorized, 0, 1) colorized = (colorized * 255).byte() colorized = Image.fromarray(colorized.numpy(), 'RGB') else: raise ValueError(f"Unexpected tensor shape: {colorized.shape}") else: raise ValueError(f"Unexpected prediction type: {type(colorized)}") # Ensure RGB mode if colorized.mode != "RGB": colorized = colorized.convert("RGB") # Resize back to original size if needed if colorized.size != original_size: colorized = colorized.resize(original_size, Image.Resampling.LANCZOS) logger.info("Colorization completed successfully") return colorized, self.output_caption except Exception as e: logger.error("Error during colorization: %s", str(e)) raise RuntimeError(f"Colorization failed: {str(e)}") from e