LogicGoInfotechSpaces commited on
Commit
0a1a3e1
·
1 Parent(s): 293fc40

Fix image colorization: Add PyTorch GAN colorizer fallback, update Dockerfile to use main_fastai, and add missing dependencies

Browse files
Files changed (4) hide show
  1. Dockerfile +1 -1
  2. app/main_fastai.py +68 -39
  3. app/pytorch_colorizer.py +247 -0
  4. requirements.txt +4 -1
Dockerfile CHANGED
@@ -63,4 +63,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
63
  ENTRYPOINT ["/entrypoint.sh"]
64
 
65
  # Run the application (port will be set via environment variable)
66
- CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860}"]
 
63
  ENTRYPOINT ["/entrypoint.sh"]
64
 
65
  # Run the application (port will be set via environment variable)
66
+ CMD ["sh", "-c", "uvicorn app.main_fastai:app --host 0.0.0.0 --port ${PORT:-7860}"]
app/main_fastai.py CHANGED
@@ -34,6 +34,7 @@ from fastai.vision.all import *
34
  from huggingface_hub import from_pretrained_fastai
35
 
36
  from app.config import settings
 
37
 
38
  # Configure logging
39
  logging.basicConfig(
@@ -94,30 +95,50 @@ app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
94
 
95
  # Initialize FastAI model
96
  learn = None
 
97
  model_load_error: Optional[str] = None
 
98
 
99
  @app.on_event("startup")
100
  async def startup_event():
101
- """Load FastAI model on startup"""
102
- global learn, model_load_error
 
 
 
103
  try:
104
- model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
105
- logger.info("🔄 Loading FastAI GAN Colorization Model: %s", model_id)
106
  learn = from_pretrained_fastai(model_id)
107
- logger.info("✅ Model loaded successfully!")
 
108
  model_load_error = None
 
109
  except Exception as e:
110
  error_msg = str(e)
111
- logger.error(" Failed to load model: %s", error_msg)
 
 
 
 
 
 
 
 
 
 
 
112
  model_load_error = error_msg
 
113
  # Don't raise - allow health check to work
114
 
115
  @app.on_event("shutdown")
116
  async def shutdown_event():
117
  """Cleanup on shutdown"""
118
- global learn
119
  if learn:
120
  del learn
 
 
121
  logger.info("Application shutdown")
122
 
123
  def _extract_bearer_token(authorization_header: str | None) -> str | None:
@@ -182,7 +203,8 @@ async def health_check():
182
  """Health check endpoint"""
183
  response = {
184
  "status": "healthy",
185
- "model_loaded": learn is not None,
 
186
  "model_id": os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
187
  }
188
  if model_load_error:
@@ -191,38 +213,45 @@ async def health_check():
191
 
192
  def colorize_pil(image: Image.Image) -> Image.Image:
193
  """Run model prediction and return colorized image"""
194
- if learn is None:
195
- raise RuntimeError("Model not loaded")
196
- if image.mode != "RGB":
197
- image = image.convert("RGB")
198
- pred = learn.predict(image)
199
- # Handle different return types from FastAI
200
- if isinstance(pred, (list, tuple)):
201
- colorized = pred[0] if len(pred) > 0 else image
202
- else:
203
- colorized = pred
204
-
205
- # Ensure we have a PIL Image
206
- if not isinstance(colorized, Image.Image):
207
- if isinstance(colorized, torch.Tensor):
208
- # Convert tensor to PIL
209
- if colorized.dim() == 4:
210
- colorized = colorized[0]
211
- if colorized.dim() == 3:
212
- colorized = colorized.permute(1, 2, 0).cpu()
213
- if colorized.dtype in (torch.float32, torch.float16):
214
- colorized = torch.clamp(colorized, 0, 1)
215
- colorized = (colorized * 255).byte()
216
- colorized = Image.fromarray(colorized.numpy(), 'RGB')
217
- else:
218
- raise ValueError(f"Unexpected tensor shape: {colorized.shape}")
219
  else:
220
- raise ValueError(f"Unexpected prediction type: {type(colorized)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
- if colorized.mode != "RGB":
223
- colorized = colorized.convert("RGB")
 
224
 
225
- return colorized
 
226
 
227
  @app.post("/colorize")
228
  async def colorize_api(
@@ -233,7 +262,7 @@ async def colorize_api(
233
  Upload a black & white image -> returns colorized image.
234
  Requires Firebase authentication unless DISABLE_AUTH=true
235
  """
236
- if learn is None:
237
  raise HTTPException(status_code=503, detail="Colorization model not loaded")
238
 
239
  if not file.content_type or not file.content_type.startswith("image/"):
@@ -270,7 +299,7 @@ def gradio_colorize(image):
270
  if image is None:
271
  return None
272
  try:
273
- if learn is None:
274
  return None
275
  return colorize_pil(image)
276
  except Exception as e:
 
34
  from huggingface_hub import from_pretrained_fastai
35
 
36
  from app.config import settings
37
+ from app.pytorch_colorizer import PyTorchColorizer
38
 
39
  # Configure logging
40
  logging.basicConfig(
 
95
 
96
  # Initialize FastAI model
97
  learn = None
98
+ pytorch_colorizer = None
99
  model_load_error: Optional[str] = None
100
+ model_type: str = "none" # "fastai", "pytorch", or "none"
101
 
102
  @app.on_event("startup")
103
  async def startup_event():
104
+ """Load FastAI or PyTorch model on startup"""
105
+ global learn, pytorch_colorizer, model_load_error, model_type
106
+ model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
107
+
108
+ # Try FastAI first
109
  try:
110
+ logger.info("🔄 Attempting to load FastAI GAN Colorization Model: %s", model_id)
 
111
  learn = from_pretrained_fastai(model_id)
112
+ logger.info("✅ FastAI model loaded successfully!")
113
+ model_type = "fastai"
114
  model_load_error = None
115
+ return
116
  except Exception as e:
117
  error_msg = str(e)
118
+ logger.warning("⚠️ FastAI model loading failed: %s. Trying PyTorch fallback...", error_msg)
119
+
120
+ # Fallback to PyTorch
121
+ try:
122
+ logger.info("🔄 Attempting to load PyTorch GAN Colorization Model: %s", model_id)
123
+ pytorch_colorizer = PyTorchColorizer(model_id=model_id, model_filename="generator.pt")
124
+ logger.info("✅ PyTorch model loaded successfully!")
125
+ model_type = "pytorch"
126
+ model_load_error = None
127
+ except Exception as e:
128
+ error_msg = str(e)
129
+ logger.error("❌ Failed to load both FastAI and PyTorch models: %s", error_msg)
130
  model_load_error = error_msg
131
+ model_type = "none"
132
  # Don't raise - allow health check to work
133
 
134
  @app.on_event("shutdown")
135
  async def shutdown_event():
136
  """Cleanup on shutdown"""
137
+ global learn, pytorch_colorizer
138
  if learn:
139
  del learn
140
+ if pytorch_colorizer:
141
+ del pytorch_colorizer
142
  logger.info("Application shutdown")
143
 
144
  def _extract_bearer_token(authorization_header: str | None) -> str | None:
 
203
  """Health check endpoint"""
204
  response = {
205
  "status": "healthy",
206
+ "model_loaded": (learn is not None) or (pytorch_colorizer is not None),
207
+ "model_type": model_type,
208
  "model_id": os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
209
  }
210
  if model_load_error:
 
213
 
214
  def colorize_pil(image: Image.Image) -> Image.Image:
215
  """Run model prediction and return colorized image"""
216
+ # Try FastAI first
217
+ if learn is not None:
218
+ if image.mode != "RGB":
219
+ image = image.convert("RGB")
220
+ pred = learn.predict(image)
221
+ # Handle different return types from FastAI
222
+ if isinstance(pred, (list, tuple)):
223
+ colorized = pred[0] if len(pred) > 0 else image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  else:
225
+ colorized = pred
226
+
227
+ # Ensure we have a PIL Image
228
+ if not isinstance(colorized, Image.Image):
229
+ if isinstance(colorized, torch.Tensor):
230
+ # Convert tensor to PIL
231
+ if colorized.dim() == 4:
232
+ colorized = colorized[0]
233
+ if colorized.dim() == 3:
234
+ colorized = colorized.permute(1, 2, 0).cpu()
235
+ if colorized.dtype in (torch.float32, torch.float16):
236
+ colorized = torch.clamp(colorized, 0, 1)
237
+ colorized = (colorized * 255).byte()
238
+ colorized = Image.fromarray(colorized.numpy(), 'RGB')
239
+ else:
240
+ raise ValueError(f"Unexpected tensor shape: {colorized.shape}")
241
+ else:
242
+ raise ValueError(f"Unexpected prediction type: {type(colorized)}")
243
+
244
+ if colorized.mode != "RGB":
245
+ colorized = colorized.convert("RGB")
246
+
247
+ return colorized
248
 
249
+ # Fallback to PyTorch
250
+ elif pytorch_colorizer is not None:
251
+ return pytorch_colorizer.colorize(image)
252
 
253
+ else:
254
+ raise RuntimeError("No colorization model loaded")
255
 
256
  @app.post("/colorize")
257
  async def colorize_api(
 
262
  Upload a black & white image -> returns colorized image.
263
  Requires Firebase authentication unless DISABLE_AUTH=true
264
  """
265
+ if learn is None and pytorch_colorizer is None:
266
  raise HTTPException(status_code=503, detail="Colorization model not loaded")
267
 
268
  if not file.content_type or not file.content_type.startswith("image/"):
 
299
  if image is None:
300
  return None
301
  try:
302
+ if learn is None and pytorch_colorizer is None:
303
  return None
304
  return colorize_pil(image)
305
  except Exception as e:
app/pytorch_colorizer.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch GAN Colorization Model Loader
3
+ Handles loading and inference for PyTorch GAN colorization models
4
+ """
5
+ import functools
6
+ import logging
7
+ import os
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from PIL import Image
13
+ from torchvision import transforms
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class UNetGenerator(nn.Module):
20
+ """
21
+ U-Net Generator for Image Colorization
22
+ Common architecture for GAN-based colorization models
23
+ """
24
+ def __init__(self, input_nc=1, output_nc=3, num_downs=8, ngf=64, use_dropout=False):
25
+ super(UNetGenerator, self).__init__()
26
+
27
+ # Build U-Net
28
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
29
+ norm_layer=nn.BatchNorm2d, innermost=True)
30
+ for i in range(num_downs - 5):
31
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None,
32
+ submodule=unet_block, norm_layer=nn.BatchNorm2d,
33
+ use_dropout=use_dropout)
34
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None,
35
+ submodule=unet_block, norm_layer=nn.BatchNorm2d)
36
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None,
37
+ submodule=unet_block, norm_layer=nn.BatchNorm2d)
38
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None,
39
+ submodule=unet_block, norm_layer=nn.BatchNorm2d)
40
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc,
41
+ submodule=unet_block, outermost=True,
42
+ norm_layer=nn.BatchNorm2d)
43
+
44
+ def forward(self, input):
45
+ return self.model(input)
46
+
47
+
48
+ class UnetSkipConnectionBlock(nn.Module):
49
+ """Defines the Unet submodule with skip connection"""
50
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
51
+ submodule=None, outermost=False, innermost=False,
52
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
53
+ super(UnetSkipConnectionBlock, self).__init__()
54
+ self.outermost = outermost
55
+ if type(norm_layer) == functools.partial:
56
+ use_bias = norm_layer.func == nn.InstanceNorm2d
57
+ else:
58
+ use_bias = norm_layer == nn.InstanceNorm2d
59
+ if input_nc is None:
60
+ input_nc = outer_nc
61
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
62
+ stride=2, padding=1, bias=use_bias)
63
+ downrelu = nn.LeakyReLU(0.2, True)
64
+ downnorm = norm_layer(inner_nc)
65
+ uprelu = nn.ReLU(True)
66
+ upnorm = norm_layer(outer_nc)
67
+
68
+ if outermost:
69
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
70
+ kernel_size=4, stride=2,
71
+ padding=1)
72
+ down = [downconv]
73
+ up = [uprelu, upconv, nn.Tanh()]
74
+ model = down + [submodule] + up
75
+ elif innermost:
76
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
77
+ kernel_size=4, stride=2,
78
+ padding=1, bias=use_bias)
79
+ down = [downrelu, downconv]
80
+ up = [uprelu, upconv, upnorm]
81
+ model = down + up
82
+ else:
83
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
84
+ kernel_size=4, stride=2,
85
+ padding=1, bias=use_bias)
86
+ down = [downrelu, downconv, downnorm]
87
+ up = [uprelu, upconv, upnorm]
88
+
89
+ if use_dropout:
90
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
91
+ else:
92
+ model = down + [submodule] + up
93
+
94
+ self.model = nn.Sequential(*model)
95
+
96
+ def forward(self, x):
97
+ if self.outermost:
98
+ return self.model(x)
99
+ else:
100
+ return torch.cat([x, self.model(x)], 1)
101
+
102
+
103
+ class PyTorchColorizer:
104
+ """PyTorch GAN Colorization Model"""
105
+
106
+ def __init__(self, model_id: str = "Hammad712/GAN-Colorization-Model", model_filename: str = "generator.pt"):
107
+ self.model_id = model_id
108
+ self.model_filename = model_filename
109
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
110
+ self.model = None
111
+ self.cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
112
+
113
+ logger.info(f"Loading PyTorch GAN colorization model: {model_id}/{model_filename}")
114
+ self._load_model()
115
+
116
+ def _load_model(self):
117
+ """Load the PyTorch model"""
118
+ try:
119
+ # Download model file
120
+ model_path = hf_hub_download(
121
+ repo_id=self.model_id,
122
+ filename=self.model_filename,
123
+ cache_dir=self.cache_dir
124
+ )
125
+
126
+ logger.info(f"Model downloaded to: {model_path}")
127
+
128
+ # Try loading the model file
129
+ # First, try loading as a complete model (if saved with torch.save(model, path))
130
+ try:
131
+ loaded_obj = torch.load(model_path, map_location=self.device)
132
+
133
+ # Check if it's already a model instance
134
+ if isinstance(loaded_obj, nn.Module):
135
+ self.model = loaded_obj
136
+ self.model.eval()
137
+ self.model.to(self.device)
138
+ logger.info("✅ Loaded complete model object")
139
+ return
140
+
141
+ # Otherwise, it's likely a state_dict
142
+ state_dict = loaded_obj
143
+
144
+ except Exception as e:
145
+ logger.error(f"Failed to load model file: {e}")
146
+ raise
147
+
148
+ # Try different model architectures with state_dict
149
+ model_configs = [
150
+ {"input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 64},
151
+ {"input_nc": 1, "output_nc": 3, "num_downs": 7, "ngf": 64},
152
+ {"input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 32},
153
+ {"input_nc": 1, "output_nc": 3, "num_downs": 6, "ngf": 64},
154
+ ]
155
+
156
+ loaded = False
157
+ for config in model_configs:
158
+ try:
159
+ model = UNetGenerator(**config)
160
+ # Try strict loading first
161
+ try:
162
+ model.load_state_dict(state_dict, strict=True)
163
+ logger.info(f"✅ Successfully loaded model with strict matching: {config}")
164
+ except:
165
+ # If strict fails, try non-strict
166
+ model.load_state_dict(state_dict, strict=False)
167
+ logger.info(f"✅ Successfully loaded model with non-strict matching: {config}")
168
+
169
+ model.eval()
170
+ model.to(self.device)
171
+ self.model = model
172
+ loaded = True
173
+ break
174
+ except Exception as e:
175
+ logger.debug(f"Failed to load with config {config}: {e}")
176
+ continue
177
+
178
+ if not loaded:
179
+ # Last resort: try with default config and non-strict loading
180
+ try:
181
+ logger.warning("Attempting to load model with default config and non-strict matching")
182
+ model = UNetGenerator(input_nc=1, output_nc=3, num_downs=8, ngf=64)
183
+ model.load_state_dict(state_dict, strict=False)
184
+ model.eval()
185
+ model.to(self.device)
186
+ self.model = model
187
+ logger.info("✅ Model loaded with fallback method")
188
+ except Exception as e:
189
+ logger.error(f"Failed to load model: {e}")
190
+ raise RuntimeError(
191
+ f"Could not load PyTorch model. Tried multiple architectures. "
192
+ f"Last error: {e}. "
193
+ f"The model architecture may not match the expected U-Net structure."
194
+ )
195
+
196
+ except Exception as e:
197
+ logger.error(f"Error loading PyTorch model: {e}")
198
+ raise RuntimeError(f"Failed to load PyTorch colorization model: {e}")
199
+
200
+ def colorize(self, image: Image.Image) -> Image.Image:
201
+ """
202
+ Colorize a grayscale or color image
203
+
204
+ Args:
205
+ image: PIL Image (will be converted to grayscale if color)
206
+
207
+ Returns:
208
+ Colorized PIL Image
209
+ """
210
+ if self.model is None:
211
+ raise RuntimeError("Model not loaded")
212
+
213
+ original_size = image.size
214
+
215
+ # Convert to grayscale if needed
216
+ if image.mode != "L":
217
+ image = image.convert("L")
218
+
219
+ # Transform to tensor
220
+ transform = transforms.Compose([
221
+ transforms.Resize((256, 256)), # Common size for GAN models
222
+ transforms.ToTensor(),
223
+ transforms.Normalize(mean=[0.5], std=[0.5]) # Normalize to [-1, 1]
224
+ ])
225
+
226
+ input_tensor = transform(image).unsqueeze(0).to(self.device)
227
+
228
+ # Run inference
229
+ with torch.no_grad():
230
+ output_tensor = self.model(input_tensor)
231
+
232
+ # Convert output back to PIL Image
233
+ # Output is typically in range [-1, 1] from Tanh activation
234
+ output_tensor = output_tensor.squeeze(0).cpu()
235
+ output_tensor = (output_tensor + 1) / 2.0 # Denormalize from [-1, 1] to [0, 1]
236
+ output_tensor = torch.clamp(output_tensor, 0, 1)
237
+
238
+ # Convert to numpy and then PIL
239
+ output_array = (output_tensor.permute(1, 2, 0).numpy() * 255).astype('uint8')
240
+ output_image = Image.fromarray(output_array, 'RGB')
241
+
242
+ # Resize back to original size
243
+ if output_image.size != original_size:
244
+ output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
245
+
246
+ return output_image
247
+
requirements.txt CHANGED
@@ -4,4 +4,7 @@ fastapi
4
  uvicorn
5
  gradio
6
  pillow
7
- firebase-admin
 
 
 
 
4
  uvicorn
5
  gradio
6
  pillow
7
+ firebase-admin
8
+ fastai
9
+ huggingface_hub
10
+ pydantic-settings