# Web page header residue (author avatar: LogicGoInfotechSpaces)
# Commit message: Detect PyTorch models and provide clear error message - repository contains generator.pt (PyTorch), not a FastAI model
# Commit: 0454a91
"""
Colorize model wrapper using FastAI GAN Colorization Model
Hammad712/GAN-Colorization-Model
"""
from __future__ import annotations
import logging
import os
from typing import Tuple
# Point every Hugging Face / XDG cache variable at one directory before the
# HF-related imports below run (main.py should already have done this; it is
# repeated here defensively).
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HUGGINGFACE_HUB_CACHE",
            "HF_HUB_CACHE",
            "XDG_CACHE_HOME",
        ),
        cache_dir,
    )
)
import torch
from PIL import Image
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files
from app.config import settings
logger = logging.getLogger(__name__)
def _ensure_cache_dir() -> str:
    """Create the shared cache directory and re-export the cache variables.

    The directory is taken from ``HF_HOME`` (default ``/tmp/hf_cache``).
    Failure to create it is logged as a warning rather than raised.

    Returns:
        The cache directory path that all cache variables now point to.
    """
    target = os.environ.get("HF_HOME", "/tmp/hf_cache")
    try:
        os.makedirs(target, exist_ok=True)
    except Exception as err:
        logger.warning("Could not create cache directory %s: %s", target, err)
    # Keep every cache-related variable in agreement with the chosen path.
    for env_name in (
        "HF_HOME",
        "TRANSFORMERS_CACHE",
        "HUGGINGFACE_HUB_CACHE",
        "HF_HUB_CACHE",
        "XDG_CACHE_HOME",
    ):
        os.environ[env_name] = target
    return target
class ColorizeModel:
    """Colorization model using FastAI GAN model.

    Construction loads a FastAI learner for the configured repository, first
    through ``from_pretrained_fastai`` and, if that fails, by downloading an
    exported model file directly. :meth:`colorize` then runs that learner on
    a PIL image.
    """

    # Filenames probed when the repository listing fails or shows neither
    # ``.pkl`` nor ``.pt`` files.
    _FALLBACK_FILENAMES = (
        "model.pkl",
        "export.pkl",
        "learner.pkl",
        "model_export.pkl",
        "generator.pt",
    )

    def __init__(self, model_id: str | None = None) -> None:
        """Load the colorization learner.

        Args:
            model_id: Hugging Face repository id. Falls back to
                ``settings.MODEL_ID`` when omitted or empty.

        Raises:
            RuntimeError: If the model cannot be loaded by either strategy,
                including when the file found in the repository is a PyTorch
                ``.pt`` checkpoint rather than a FastAI export.
        """
        self.cache_dir = _ensure_cache_dir()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        # Use FastAI model ID from config or default
        self.model_id = model_id or settings.MODEL_ID
        self.output_caption = getattr(settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model")
        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
        try:
            try:
                self.learn = from_pretrained_fastai(self.model_id)
                logger.info("FastAI GAN Colorization model loaded successfully via from_pretrained_fastai")
            except Exception as e1:
                logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
                self.learn = self._load_from_repo_files(e1)
        except Exception as e:
            error_msg = (
                f"Failed to load FastAI model '{self.model_id}'. "
                f"Error: {str(e)}\n"
                f"Please check the MODEL_ID environment variable. "
                f"Default model: 'Hammad712/GAN-Colorization-Model'"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _candidate_model_files(self, hf_token: str | None) -> list[str]:
        """Return the filenames worth downloading, preferring ``.pkl`` over ``.pt``."""
        try:
            repo_files = list_repo_files(repo_id=self.model_id, token=hf_token)
            logger.info("Repository files: %s", repo_files)
            # Look for .pkl files (FastAI) first, then .pt files (PyTorch)
            for extension in (".pkl", ".pt"):
                matches = [f for f in repo_files if f.endswith(extension)]
                if matches:
                    logger.info("Found %s files in repository: %s", extension, matches)
                    return matches
        except Exception as list_err:
            logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err))
        # Listing failed or showed nothing usable: probe the common filenames.
        return list(self._FALLBACK_FILENAMES)

    def _load_from_repo_files(self, original_error: Exception):
        """Download the first available model file and load it as a learner.

        Args:
            original_error: The failure from ``from_pretrained_fastai``; it is
                quoted in the error raised when no file can be downloaded.

        Returns:
            The learner returned by ``load_learner``.

        Raises:
            RuntimeError: If no candidate file could be downloaded, or the
                downloaded file is a ``.pt`` checkpoint.
        """
        hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
        model_filenames = self._candidate_model_files(hf_token)

        model_path = None
        model_filename = None
        for filename in model_filenames:
            try:
                model_path = hf_hub_download(
                    repo_id=self.model_id,
                    filename=filename,
                    cache_dir=self.cache_dir,
                    token=hf_token,
                )
            except Exception as dl_err:
                logger.debug("Failed to download %s: %s", filename, str(dl_err))
                continue
            logger.info("Found model file: %s", filename)
            model_filename = filename
            break

        if not (model_path and os.path.exists(model_path)):
            # If no model file found, raise error with more details
            raise RuntimeError(
                f"Could not find model file in repository '{self.model_id}'. "
                f"Tried: {', '.join(model_filenames)}. "
                f"Original error: {str(original_error)}"
            )

        # Every candidate ends in .pkl or .pt, so the extension of the file
        # actually downloaded decides how it must be treated.
        if model_filename.endswith(".pt"):
            logger.info("Loading PyTorch model from: %s", model_path)
            # Loading a raw generator requires knowing its architecture, which
            # this wrapper does not have, so fail with an actionable message.
            logger.warning("PyTorch model loading not fully implemented. This model may not work correctly.")
            raise RuntimeError(
                f"Repository '{self.model_id}' contains a PyTorch model ({model_filename}), "
                f"not a FastAI model. FastAI models must be .pkl files created with FastAI's export. "
                f"Please use a FastAI-compatible colorization model, or switch to a different model backend."
            )

        # NOTE(review): exported learners are pickle files, so loading one can
        # execute arbitrary code; only point MODEL_ID at trusted repositories.
        logger.info("Loading FastAI model from: %s", model_path)
        learner = load_learner(model_path)
        logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
        return learner

    @staticmethod
    def _to_uint8_hwc(tensor: torch.Tensor) -> torch.Tensor:
        """Convert a CHW (or 1xCHW) image tensor to a uint8 HWC tensor on CPU.

        Floating-point tensors of any precision are clamped to [0, 1] and
        scaled to [0, 255]. Integer tensors are clamped to [0, 255]. The
        result is always ``torch.uint8`` so the buffer handed to Pillow is one
        byte per channel. A single-channel input is returned as a 2-D (H, W)
        tensor.

        Raises:
            ValueError: If the tensor is not 3-D after dropping a batch axis.
        """
        if tensor.dim() == 4:
            tensor = tensor[0]  # Remove batch dimension
        if tensor.dim() != 3:
            raise ValueError(f"Unexpected tensor shape: {tensor.shape}")
        # detach() so tensors that still track gradients can be exported.
        tensor = tensor.detach().permute(1, 2, 0).cpu()
        if tensor.is_floating_point():
            # Widen to float32 first so half/bfloat16 inputs behave the same.
            tensor = tensor.to(torch.float32).clamp(0.0, 1.0) * 255.0
        elif tensor.dtype != torch.uint8:
            tensor = tensor.to(torch.int64).clamp(0, 255)
        tensor = tensor.to(torch.uint8)
        if tensor.shape[-1] == 1:
            tensor = tensor[..., 0]
        return tensor.contiguous()

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """
        Colorize a grayscale or color image using FastAI GAN model.

        Args:
            image: PIL Image (grayscale or color)
            num_inference_steps: Ignored for FastAI model (kept for API compatibility)

        Returns:
            Tuple of (colorized PIL Image, caption string). The image is RGB
            and has the same size as the input.

        Raises:
            RuntimeError: If prediction or conversion of its result fails.
        """
        try:
            original_size = image.size
            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")
            logger.info("Running FastAI GAN colorization...")
            prediction = self.learn.predict(image)
            # predict may return the output directly or as the first element
            # of a tuple/list; an empty sequence falls back to the input.
            if isinstance(prediction, (list, tuple)):
                colorized = prediction[0] if len(prediction) > 0 else image
            else:
                colorized = prediction
            # Ensure we have a PIL Image
            if not isinstance(colorized, Image.Image):
                if not isinstance(colorized, torch.Tensor):
                    raise ValueError(f"Unexpected prediction type: {type(colorized)}")
                # dtype is guaranteed uint8 here, so Pillow can infer the mode
                # from the array shape.
                colorized = Image.fromarray(self._to_uint8_hwc(colorized).numpy())
            # Ensure RGB mode
            if colorized.mode != "RGB":
                colorized = colorized.convert("RGB")
            # Resize back to original size if needed
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
            logger.info("Colorization completed successfully")
            return colorized, self.output_caption
        except Exception as e:
            logger.error("Error during colorization: %s", str(e))
            raise RuntimeError(f"Colorization failed: {str(e)}") from e