import torch
from torch import nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import io
import base64
import requests

# Checkpoint filename and Hugging Face Hub repo for the pretrained generator
MODEL_PATH = "generator.pt"
HF_MODEL_REPO = "Hammad712/GAN-Colorization-Model"

from huggingface_hub import hf_hub_download

# Download (and cache) the generator checkpoint from the Hub
model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=MODEL_PATH)

# Load the generator on GPU if available, otherwise CPU. The checkpoint is
# loaded as a pickled nn.Module (not a plain state dict), so weights_only=False
# is required on PyTorch 2.6+, where weights_only defaults to True.
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.load(model_path, map_location=device, weights_only=False)
generator.eval()

# Preprocessing: resize to the 256x256 input size expected by the generator
# and convert to a single-channel grayscale tensor
transform_gray = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(),
    transforms.ToTensor()
])

# Color preprocessing (kept for reference; not used in the inference path below)
transform_color = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

to_pil = transforms.ToPILImage()


def colorize_image(input_image):
    """Colorize a PIL image with the GAN generator and return a PIL image."""
    img = transform_gray(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = generator(img)
    # Clamp the output to [0, 1] before converting the tensor back to PIL
    output_image = to_pil(output.squeeze().cpu().clamp(0, 1))
    return output_image


# Gradio UI for interactive use in the browser
iface = gr.Interface(
    fn=colorize_image,
    inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
    title="GAN Image Colorization (Hammad712)",
    description="Colorizes black-and-white images using a GAN model"
)
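
# Note: iface.launch() would serve the Gradio UI on its own; here the interface
# is instead mounted onto the FastAPI app defined below.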

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse

app = FastAPI(title="GAN Colorization API")


@app.post("/colorize")
async def colorize_api(file: UploadFile = File(...)):
    """Accept an uploaded image and stream back the colorized PNG."""
    image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    colorized = colorize_image(image)
    buf = io.BytesIO()
    colorized.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
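
# Example client call (illustrative only; assumes the server is running locally
# on port 8000 and that "photo.jpg" / "colorized.png" are placeholder names):
#
#   with open("photo.jpg", "rb") as f:
#       r = requests.post("http://localhost:8000/colorize", files={"file": f})
#   with open("colorized.png", "wb") as f:
#       f.write(r.content)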

# Mount the Gradio UI at the root path so a single server exposes both the
# web UI ("/") and the REST endpoint ("/colorize")
app = gr.mount_gradio_app(app, iface, path="/")
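

# A minimal way to run the combined app locally (a sketch: assumes the uvicorn
# package is installed; the host and port values here are illustrative)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)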