# LogicGoInfotechSpaces — app/main.py
# Commit 19bed74 ("Update app/main.py", verified), 2.03 kB
import torch
from torch import nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import io
import base64
import requests
# Download model from Hugging Face Hub
MODEL_PATH = "generator.pt"  # file name inside the Hub repo, not a local path
HF_MODEL_REPO = "Hammad712/GAN-Colorization-Model"
# Auto-download model if not present (hf_hub_download reuses its local cache)
from huggingface_hub import hf_hub_download
model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=MODEL_PATH)
# Load the model (generator)
device = "cuda" if torch.cuda.is_available() else "cpu"
# The checkpoint is a pickled nn.Module (we call .eval() on the loaded object),
# not a state_dict, so it must be unpickled with weights_only=False; torch >= 2.6
# defaults to weights_only=True and would refuse to load it.
# SECURITY: unpickling can execute arbitrary code. Only load checkpoints from a
# repo you trust, and consider pinning `revision=` in hf_hub_download above.
generator = torch.load(model_path, map_location=device, weights_only=False)
generator.eval()  # inference mode for dropout / batch-norm layers
# ---- Image transforms ----
# Model input pipeline: resize to a fixed 256x256, collapse to one grayscale
# channel, then convert the PIL image to a float tensor.
transform_gray = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)
# Same resize + tensor conversion but keeping the color channels.
# NOTE(review): not referenced by the colorization path in this file — confirm
# nothing else imports it before removing.
transform_color = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)
# Tensor -> PIL image, used to turn the generator output back into an image.
to_pil = transforms.ToPILImage()
# Colorization function
def colorize_image(input_image):
    """Colorize a PIL image with the loaded GAN generator.

    The image is resized to 256x256 and reduced to one grayscale channel by
    ``transform_gray``, run through ``generator`` without gradient tracking,
    clamped to [0, 1] and converted back to a PIL image. The result is at the
    generator's working resolution (presumably 256x256, assuming the generator
    preserves spatial size), not at the input's original size.

    Args:
        input_image: A PIL image, or ``None`` (e.g. when the Gradio form is
            submitted without an uploaded image).

    Returns:
        The colorized PIL image, or ``None`` when ``input_image`` is ``None``.
    """
    if input_image is None:
        # Nothing to colorize; an empty Gradio output is the sensible result.
        return None
    # (C, H, W) -> (1, C, H, W): the generator expects a batch dimension.
    batch = transform_gray(input_image).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only, skip autograd bookkeeping
        output = generator(batch)
    # NOTE(review): clamping to [0, 1] assumes the generator emits image values
    # in [0, 1]. If it ends in tanh ([-1, 1]) or predicts Lab "ab" channels,
    # this post-processing is wrong — confirm against the model card.
    return to_pil(output.squeeze().cpu().clamp(0, 1))
# ---- Gradio Interface ----
# Browser UI: one PIL image in, one PIL image out, both handled by
# colorize_image. The input component is created before the output component,
# matching the order Gradio lays them out.
iface = gr.Interface(
    title="GAN Image Colorization (Hammad712)",
    description="Colorizes black and white images using GAN model",
    fn=colorize_image,
    inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
)
# ---- API Endpoint ----
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
app = FastAPI(title="GAN Colorization API")


def _colorize_png_bytes(data):
    """Decode uploaded image bytes, colorize them, and encode the result as PNG.

    This is blocking work (image decoding plus model inference); the endpoint
    below runs it in a worker thread so it does not stall the event loop.

    Args:
        data: Raw bytes of the uploaded file.

    Returns:
        An ``io.BytesIO`` rewound to position 0 holding the PNG-encoded result.

    Raises:
        HTTPException: 400 if Pillow cannot decode ``data`` as an image
            (unknown format, truncated data, or an oversized image Pillow
            rejects as a decompression bomb).
    """
    try:
        # Image.open is lazy; .convert() forces the full decode, so both
        # calls belong inside the try.
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL.UnidentifiedImageError (unrecognised format) subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err
    colorized = colorize_image(image)
    buf = io.BytesIO()
    colorized.save(buf, format="PNG")
    buf.seek(0)  # rewind so the response streams from the start
    return buf


@app.post("/colorize")
async def colorize_api(file: UploadFile = File(...)):
    """Colorize an uploaded image and return it as ``image/png``.

    Responds with HTTP 400 when the upload is not a decodable image.
    """
    data = await file.read()
    # Running decode + inference inline in this coroutine would block every
    # other request served by the event loop; offload it to a thread instead.
    buf = await run_in_threadpool(_colorize_png_bytes, data)
    return StreamingResponse(buf, media_type="image/png")
# Mount the Gradio UI at the site root ("/"). Routes already registered on
# `app` (such as POST /colorize above) are matched first, since Starlette
# checks routes in registration order.
# NOTE: `gradio` is already imported at the top of the file; this re-import is
# redundant but harmless.
import gradio as gr
app = gr.mount_gradio_app(app=app, blocks=iface, path="/")