"""Gradio UI and FastAPI endpoint for a GAN-based image colorization model.

At import time this module downloads the generator checkpoint from the
Hugging Face Hub, loads it onto the available device, builds a Gradio
interface around :func:`colorize_image`, and mounts that interface on a
FastAPI application which also exposes a ``POST /colorize`` endpoint.
"""
import base64
import io
from typing import Optional

import gradio as gr
import requests
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import transforms

# Checkpoint file name inside the Hugging Face Hub repository.
MODEL_PATH = "generator.pt"
HF_MODEL_REPO = "Hammad712/GAN-Colorization-Model"

# hf_hub_download returns the local path of the file, fetching it from the
# Hub when it is not already in the local cache.
model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=MODEL_PATH)

# Load the model (generator).
device = "cuda" if torch.cuda.is_available() else "cpu"
# SECURITY NOTE: the checkpoint is loaded as a pickled object (the code below
# calls it as a module), which executes arbitrary code embedded in the file.
# Only point HF_MODEL_REPO at a repository you trust. ``weights_only=False``
# is passed explicitly because PyTorch 2.6 changed the default to True, which
# refuses to unpickle a full module.
generator = torch.load(model_path, map_location=device, weights_only=False)
generator.eval()

# Preprocessing applied to the model input: fixed 256x256 size, one
# grayscale channel, converted to a tensor.
transform_gray = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(),
    transforms.ToTensor(),
])

# Same resize without the grayscale step; not used by the code in this file.
transform_color = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

to_pil = transforms.ToPILImage()


def colorize_image(input_image: Optional[Image.Image]) -> Optional[Image.Image]:
    """Run the generator on an image and return the result as a PIL image.

    The input is resized to 256x256 and converted to grayscale before being
    passed to the generator, so the returned image is also 256x256.

    Args:
        input_image: PIL image to colorize, or ``None`` (Gradio passes
            ``None`` when the form is submitted without an image).

    Returns:
        The generator output converted to a PIL image, or ``None`` when no
        input image was supplied.
    """
    if input_image is None:
        return None

    batch = transform_gray(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = generator(batch)

    # Drop only the batch dimension; a bare squeeze() would also collapse any
    # other size-1 dimension of the output.
    # NOTE(review): values are clipped to [0, 1] before conversion. If the
    # generator emits values in another range (e.g. [-1, 1]) they would need
    # rescaling instead of clipping - confirm against the training code.
    return to_pil(output.squeeze(0).cpu().clamp(0, 1))


# ---- Gradio Interface ----
iface = gr.Interface(
    fn=colorize_image,
    inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
    title="GAN Image Colorization (Hammad712)",
    description="Colorizes black and white images using GAN model",
)

# ---- API Endpoint ----
app = FastAPI(title="GAN Colorization API")


@app.post("/colorize")
async def colorize_api(file: UploadFile = File(...)):
    """Colorize an uploaded image and stream the result back as a PNG.

    Args:
        file: Uploaded image file (multipart form field ``file``).

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.

    Raises:
        HTTPException: Status 400 when the upload cannot be decoded as an
            image.
    """
    payload = await file.read()
    try:
        image = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as err:
        # Pillow signals unreadable or truncated image data with OSError
        # (UnidentifiedImageError is a subclass); report it as a client error
        # rather than letting it surface as a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image",
        ) from err

    colorized = colorize_image(image)

    buf = io.BytesIO()
    colorized.save(buf, format="PNG")
    buf.seek(0)  # rewind so the response streams from the start of the PNG
    return StreamingResponse(buf, media_type="image/png")


# Mount Gradio UI on /
app = gr.mount_gradio_app(app, iface, path="/")