LogicGoInfotechSpaces commited on
Commit
b475327
·
verified ·
1 Parent(s): fb13003

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +92 -52
app/main.py CHANGED
@@ -1,72 +1,112 @@
 
 
 
1
  import torch
2
- from torch import nn
3
- from torchvision import transforms
 
4
  from PIL import Image
 
5
  import gradio as gr
6
- import io
7
- import base64
8
- import requests
9
 
10
- # Download model from Hugging Face Hub
 
 
11
  MODEL_PATH = "generator.pt"
12
- HF_MODEL_REPO = "Hammad712/GAN-Colorization-Model"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Auto-download model if not present
15
- from huggingface_hub import hf_hub_download
16
- model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=MODEL_PATH)
17
 
18
- # Load the model (generator)
 
 
 
 
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
- generator = torch.load(model_path, map_location=device)
21
- generator.eval()
22
 
23
- # Define transforms
24
- transform_gray = transforms.Compose([
25
- transforms.Resize((256, 256)),
26
- transforms.Grayscale(),
27
- transforms.ToTensor()
28
- ])
29
 
30
- transform_color = transforms.Compose([
31
- transforms.Resize((256, 256)),
32
- transforms.ToTensor()
33
- ])
34
 
35
- to_pil = transforms.ToPILImage()
 
 
 
 
 
 
 
 
36
 
37
- # Colorization function
38
- def colorize_image(input_image):
39
- img = transform_gray(input_image).unsqueeze(0).to(device)
40
  with torch.no_grad():
41
- output = generator(img)
42
- output_image = to_pil(output.squeeze().cpu().clamp(0, 1))
43
- return output_image
44
-
45
- # ---- Gradio Interface ----
46
- iface = gr.Interface(
47
- fn=colorize_image,
48
- inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
49
- outputs=gr.Image(type="pil", label="Colorized Image"),
50
- title="GAN Image Colorization (Hammad712)",
51
- description="Colorizes black and white images using GAN model"
52
- )
53
 
54
- # ---- API Endpoint ----
55
- from fastapi import FastAPI, File, UploadFile
56
- from fastapi.responses import StreamingResponse
57
 
58
- app = FastAPI(title="GAN Colorization API")
 
 
 
59
 
60
  @app.post("/colorize")
61
- async def colorize_api(file: UploadFile = File(...)):
62
- image = Image.open(io.BytesIO(await file.read())).convert("RGB")
 
63
  colorized = colorize_image(image)
64
- buf = io.BytesIO()
65
- colorized.save(buf, format="PNG")
66
- buf.seek(0)
67
- return StreamingResponse(buf, media_type="image/png")
68
 
69
- # Mount Gradio UI on /
70
- import gradio as gr
71
- app = gr.mount_gradio_app(app, iface, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
 
 
 
1
+ import io
2
+ import os
3
+ import uuid
4
  import torch
5
+ import torch.nn as nn
6
+ from fastapi import FastAPI, UploadFile, File
7
+ from fastapi.responses import FileResponse
8
  from PIL import Image
9
+ import torchvision.transforms as T
10
  import gradio as gr
11
+ import uvicorn
 
 
12
 
13
+ # ==========================================================
14
+ # 🔧 PATHS
15
+ # ==========================================================
16
  MODEL_PATH = "generator.pt"
17
+ UPLOAD_DIR = "/tmp/uploads"
18
+ RESULT_DIR = "/tmp/results"
19
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
20
+ os.makedirs(RESULT_DIR, exist_ok=True)
21
+
22
+ # ==========================================================
23
+ # 🧩 Define Generator Architecture (from repo style)
24
+ # ==========================================================
25
+ class Generator(nn.Module):
26
+ def __init__(self):
27
+ super(Generator, self).__init__()
28
+ self.main = nn.Sequential(
29
+ nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
30
+ nn.ReLU(True),
31
+
32
+ nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
33
+ nn.BatchNorm2d(128),
34
+ nn.ReLU(True),
35
+
36
+ nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
37
+ nn.BatchNorm2d(64),
38
+ nn.ReLU(True),
39
 
40
+ nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
41
+ nn.Tanh()
42
+ )
43
 
44
+ def forward(self, x):
45
+ return self.main(x)
46
+
47
+ # ==========================================================
48
+ # 🚀 Load Model
49
+ # ==========================================================
50
  device = "cuda" if torch.cuda.is_available() else "cpu"
51
+ generator = Generator().to(device)
 
52
 
53
+ # Load weights
54
+ state_dict = torch.load(MODEL_PATH, map_location=device)
55
+ generator.load_state_dict(state_dict)
56
+ generator.eval()
 
 
57
 
58
+ print("✅ Model loaded successfully!")
 
 
 
59
 
60
+ # ==========================================================
61
+ # 🎨 Colorization Function
62
+ # ==========================================================
63
+ def colorize_image(image: Image.Image):
64
+ transform = T.Compose([
65
+ T.Resize((256, 256)),
66
+ T.Grayscale(num_output_channels=1),
67
+ T.ToTensor()
68
+ ])
69
 
70
+ img_tensor = transform(image).unsqueeze(0).to(device)
 
 
71
  with torch.no_grad():
72
+ output = generator(img_tensor)
73
+ output = (output.squeeze(0).permute(1, 2, 0).cpu().numpy() + 1) / 2.0 # Scale 0-1
 
 
 
 
 
 
 
 
 
 
74
 
75
+ output_img = Image.fromarray((output * 255).astype("uint8"))
76
+ return output_img
 
77
 
78
+ # ==========================================================
79
+ # 🌐 FASTAPI APP
80
+ # ==========================================================
81
+ app = FastAPI(title="GAN Image Colorization API")
82
 
83
  @app.post("/colorize")
84
+ async def colorize_endpoint(file: UploadFile = File(...)):
85
+ img_bytes = await file.read()
86
+ image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
87
  colorized = colorize_image(image)
 
 
 
 
88
 
89
+ output_filename = f"{uuid.uuid4()}.png"
90
+ output_path = os.path.join(RESULT_DIR, output_filename)
91
+ colorized.save(output_path)
92
+
93
+ return FileResponse(output_path, media_type="image/png")
94
+
95
+ # ==========================================================
96
+ # 💠 GRADIO UI
97
+ # ==========================================================
98
+ def gradio_ui(image):
99
+ return colorize_image(image)
100
+
101
+ iface = gr.Interface(
102
+ fn=gradio_ui,
103
+ inputs=gr.Image(type="pil", label="Upload B&W Image"),
104
+ outputs=gr.Image(type="pil", label="Colorized Image"),
105
+ title="🎨 GAN Image Colorization",
106
+ description="Upload a black-and-white photo to get it colorized using a GAN model."
107
+ )
108
+
109
+ gradio_app = gr.mount_gradio_app(app, iface, path="/")
110
 
111
+ if __name__ == "__main__":
112
+ uvicorn.run(app, host="0.0.0.0", port=7860)