feat(auth): accept Firebase Auth id_token (Authorization Bearer) in addition to App Check; add Postman collection and test script; default MODEL_ID to ControlNet color
2ae242d
| """ | |
| ColorizeNet model wrapper for image colorization | |
| """ | |
| import logging | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline | |
| from diffusers.utils import load_image | |
| from transformers import pipeline | |
| from huggingface_hub import hf_hub_download | |
| from app.config import settings | |
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class ColorizeModel:
    """
    Wrapper for ColorizeNet model

    Loading strategy:
        1. Load ``model_id`` as a diffusers ``ControlNetModel`` and attach it to
           a Stable Diffusion base pipeline whose architecture matches the
           ControlNet (SDXL or SD 1.5).
        2. If that fails, fall back to a ``transformers`` "image-to-image"
           pipeline for the same ``model_id``.

    Attributes:
        model_id: Hugging Face model ID that was loaded.
        device: "cuda" when available, otherwise "cpu".
        dtype: torch.float16 on CUDA, torch.float32 on CPU.
        controlnet: The ControlNet model (set whenever the ControlNet itself
            loaded, even if attaching it to a base pipeline later failed).
        pipe: The callable pipeline used by ``colorize``.
        model_type: "controlnet" or "pipeline", depending on which path loaded.
    """

    # (width, height) of the conditioning image fed to the pipeline.
    CONTROL_IMAGE_SIZE = (512, 512)
    SDXL_BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
    # SD 1.5 base repos, tried in order. The legacy runwayml repo comes first so
    # existing local caches keep working; it may no longer be downloadable from
    # the Hub, in which case the community mirror is used instead.
    SD15_BASE_MODEL_IDS = (
        "runwayml/stable-diffusion-v1-5",
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
    )

    def __init__(self, model_id: str | None = None):
        """
        Initialize the ColorizeNet model

        Args:
            model_id: Hugging Face model ID for ColorizeNet. Defaults to
                ``settings.MODEL_ID`` when None.

        Raises:
            RuntimeError: If neither the ControlNet path nor the
                image-to-image fallback could be loaded.
        """
        if model_id is None:
            model_id = settings.MODEL_ID
        self.model_id = model_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: %s", self.device)
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        try:
            try:
                self._load_controlnet_pipeline()
            except Exception as e:
                logger.warning("Failed to load as ControlNet: %s", str(e))
                self._load_image_to_image_pipeline()
        except Exception as e:
            logger.error("Failed to load ColorizeNet model: %s", str(e))
            raise RuntimeError(f"Could not load ColorizeNet model: {str(e)}") from e

    def _load_controlnet_pipeline(self) -> None:
        """Load ``model_id`` as a ControlNet and attach it to a matching SD base pipeline."""
        logger.info("Attempting to load model as ControlNet: %s", self.model_id)
        self.controlnet = ControlNetModel.from_pretrained(
            self.model_id,
            torch_dtype=self.dtype,
        )
        # Choose the base from the ControlNet's own config. Constructing an SDXL
        # pipeline around an SD 1.5 ControlNet does not fail at load time, only
        # later at inference, so "try SDXL, fall back on error" cannot work.
        if self._is_sdxl_controlnet(self.controlnet):
            # The SDXL ControlNet pipeline has no safety-checker component, so
            # the safety_checker kwargs are not passed here.
            self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
                self.SDXL_BASE_MODEL_ID,
                controlnet=self.controlnet,
                torch_dtype=self.dtype,
            )
            logger.info("Loaded with SDXL base model")
        else:
            self.pipe = self._load_sd15_pipeline()
            logger.info("Loaded with SD 1.5 base model")

        self.pipe.to(self.device)
        self._enable_memory_efficient_attention()
        logger.info("ColorizeNet model loaded successfully as ControlNet")
        self.model_type = "controlnet"

    @staticmethod
    def _is_sdxl_controlnet(controlnet) -> bool:
        """Return True when the ControlNet config describes an SDXL-family ControlNet."""
        config = getattr(controlnet, "config", None)
        if config is None:
            return False
        # SDXL ControlNets use "text_time" added conditioning and a 2048-wide
        # text embedding; SD 1.5 ControlNets use neither.
        return (
            getattr(config, "addition_embed_type", None) == "text_time"
            or getattr(config, "cross_attention_dim", None) == 2048
        )

    def _load_sd15_pipeline(self):
        """Build the SD 1.5 ControlNet pipeline, trying each base repo in SD15_BASE_MODEL_IDS."""
        last_error: Exception | None = None
        for base_id in self.SD15_BASE_MODEL_IDS:
            try:
                return StableDiffusionControlNetPipeline.from_pretrained(
                    base_id,
                    controlnet=self.controlnet,
                    torch_dtype=self.dtype,
                    safety_checker=None,
                    requires_safety_checker=False,
                )
            except Exception as e:
                # Try the next candidate repo; the last failure is re-raised below.
                logger.warning("Could not load SD 1.5 base %s: %s", base_id, str(e))
                last_error = e
        raise RuntimeError("No SD 1.5 base model could be loaded") from last_error

    def _enable_memory_efficient_attention(self) -> None:
        """Best-effort: enable xFormers attention on CUDA; failures are only logged."""
        if self.device != "cuda":
            # xFormers attention in diffusers requires CUDA; skip the attempt on CPU.
            return
        if not hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
            return
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
            logger.info("XFormers memory efficient attention enabled")
        except Exception as e:
            logger.warning("Could not enable XFormers: %s", str(e))

    def _load_image_to_image_pipeline(self) -> None:
        """Fallback: load ``model_id`` through the transformers "image-to-image" pipeline."""
        logger.info("Trying to load as image-to-image pipeline...")
        self.pipe = pipeline(
            "image-to-image",
            model=self.model_id,
            device=0 if self.device == "cuda" else -1,
            torch_dtype=self.dtype,
        )
        logger.info("ColorizeNet model loaded using image-to-image pipeline")
        self.model_type = "pipeline"

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """
        Preprocess image for colorization

        The image is reduced to grayscale, expanded back to three identical RGB
        channels, and resized to CONTROL_IMAGE_SIZE (aspect ratio is not kept).

        Args:
            image: PIL Image

        Returns:
            Preprocessed PIL Image
        """
        if image.mode != "L":
            image = image.convert("L")
        # Grayscale content, but 3 channels as the pipelines expect RGB input.
        image = image.convert("RGB")
        return image.resize(self.CONTROL_IMAGE_SIZE, Image.Resampling.LANCZOS)

    def colorize(
        self,
        image: Image.Image,
        num_inference_steps: int | None = None,
        seed: int = 42,
    ) -> Image.Image:
        """
        Colorize a grayscale image

        Args:
            image: PIL Image (grayscale or color)
            num_inference_steps: Number of inference steps. When None, 8 steps
                are used on CPU and 20 on GPU.
            seed: Seed for the ControlNet generator, so results are reproducible.

        Returns:
            Colorized PIL Image, resized to the input image's size.

        Raises:
            ValueError: If the pipeline returns an unsupported output type.
            Exception: Any pipeline error is logged with its traceback and re-raised.
        """
        try:
            if num_inference_steps is None:
                # Fewer steps on CPU keep latency tolerable.
                num_inference_steps = 8 if self.device == "cpu" else 20

            control_image = self.preprocess_image(image)
            original_size = image.size

            prompt = "colorize this black and white image, high quality, detailed, vibrant colors, natural colors"
            negative_prompt = "black and white, grayscale, monochrome, low quality, blurry, desaturated"
            # Lower guidance on CPU (faster, matches the original tuning).
            guidance_scale = 5.0 if self.device == "cpu" else 7.5

            if self.model_type == "controlnet":
                result = self.pipe(
                    prompt=prompt,
                    image=control_image,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    controlnet_conditioning_scale=1.0,
                    generator=torch.Generator(device=self.device).manual_seed(seed),
                )
            else:
                result = self.pipe(
                    control_image,
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                )

            colorized = self._to_pil(self._first_image(result))

            # Always hand back an image matching the caller's input size.
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
            return colorized
        except Exception as e:
            logger.exception("Error during colorization: %s", str(e))
            raise

    @staticmethod
    def _first_image(result):
        """Return the first image from a pipeline result (output mapping, list, or bare image)."""
        # diffusers pipeline outputs are dict subclasses exposing "images".
        if isinstance(result, dict) and "images" in result:
            return result["images"][0]
        if isinstance(result, list) and result:
            return result[0]
        return result

    @classmethod
    def _to_pil(cls, output) -> Image.Image:
        """Convert a PIL image, numpy array or torch tensor pipeline output to a PIL image."""
        if isinstance(output, Image.Image):
            return output
        if torch.is_tensor(output):
            tensor = output.detach().cpu()
            if tensor.ndim == 4 and tensor.shape[0] == 1:
                tensor = tensor[0]
            if tensor.ndim == 3:
                # assumes channel-first (C, H, W) layout, as the original permute
                # did — confirm for any new pipeline type.
                tensor = tensor.permute(1, 2, 0)
            if tensor.is_floating_point():
                # bf16/fp16 -> fp32 so .numpy() always works.
                tensor = tensor.float()
            output = tensor.numpy()
        elif not isinstance(output, np.ndarray):
            raise ValueError(f"Unexpected output type: {type(output)}")

        if output.ndim == 4 and output.shape[0] == 1:
            output = output[0]
        array = np.ascontiguousarray(cls._array_to_uint8(output))
        # Pillow infers the mode: (H, W, 3) uint8 -> "RGB", (H, W) -> "L".
        return Image.fromarray(array)

    @staticmethod
    def _array_to_uint8(array: np.ndarray) -> np.ndarray:
        """Convert an image array to uint8, saturating out-of-range values instead of wrapping."""
        if array.dtype == np.uint8:
            return array
        if np.issubdtype(array.dtype, np.integer):
            # presumably already on the 0-255 scale — clip rather than rescale.
            return np.clip(array, 0, 255).astype(np.uint8)
        # Floats (and bools) are treated as [0, 1] intensities.
        scaled = np.clip(array.astype(np.float32), 0.0, 1.0) * 255.0
        return np.round(scaled).astype(np.uint8)