LogicGoInfotechSpaces committed on
Commit
ec7bfd1
·
1 Parent(s): 0a1a3e1

Improve PyTorch colorizer: Better preprocessing, output handling, and debugging

Browse files
Files changed (1) hide show
  1. app/pytorch_colorizer.py +48 -7
app/pytorch_colorizer.py CHANGED
@@ -141,6 +141,12 @@ class PyTorchColorizer:
141
  # Otherwise, it's likely a state_dict
142
  state_dict = loaded_obj
143
 
 
 
 
 
 
 
144
  except Exception as e:
145
  logger.error(f"Failed to load model file: {e}")
146
  raise
@@ -216,30 +222,65 @@ class PyTorchColorizer:
216
  if image.mode != "L":
217
  image = image.convert("L")
218
 
 
 
 
 
 
 
 
 
 
 
219
  # Transform to tensor
 
220
  transform = transforms.Compose([
221
- transforms.Resize((256, 256)), # Common size for GAN models
222
- transforms.ToTensor(),
223
- transforms.Normalize(mean=[0.5], std=[0.5]) # Normalize to [-1, 1]
224
  ])
225
 
226
  input_tensor = transform(image).unsqueeze(0).to(self.device)
227
 
 
 
 
228
  # Run inference
229
  with torch.no_grad():
230
- output_tensor = self.model(input_tensor)
 
 
 
 
 
 
 
231
 
232
  # Convert output back to PIL Image
233
- # Output is typically in range [-1, 1] from Tanh activation
234
  output_tensor = output_tensor.squeeze(0).cpu()
235
- output_tensor = (output_tensor + 1) / 2.0 # Denormalize from [-1, 1] to [0, 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  output_tensor = torch.clamp(output_tensor, 0, 1)
237
 
238
  # Convert to numpy and then PIL
239
  output_array = (output_tensor.permute(1, 2, 0).numpy() * 255).astype('uint8')
240
  output_image = Image.fromarray(output_array, 'RGB')
241
 
242
- # Resize back to original size
243
  if output_image.size != original_size:
244
  output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
245
 
 
141
  # Otherwise, it's likely a state_dict
142
  state_dict = loaded_obj
143
 
144
+ # Log state dict keys to understand model structure
145
+ if isinstance(state_dict, dict):
146
+ keys = list(state_dict.keys())[:10] # First 10 keys
147
+ logger.info(f"Model state_dict keys (sample): {keys}")
148
+ logger.info(f"Total state_dict keys: {len(state_dict.keys())}")
149
+
150
  except Exception as e:
151
  logger.error(f"Failed to load model file: {e}")
152
  raise
 
222
  if image.mode != "L":
223
  image = image.convert("L")
224
 
225
+ # Try to maintain aspect ratio and use a better resize
226
+ # Many GAN models work better with 256x256 or 512x512
227
+ target_size = 256
228
+ if max(original_size) > 512:
229
+ # Scale down proportionally but keep max dimension reasonable
230
+ scale = target_size / max(original_size)
231
+ new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
232
+ else:
233
+ new_size = original_size
234
+
235
  # Transform to tensor
236
+ # GAN colorization models typically expect normalized input
237
  transform = transforms.Compose([
238
+ transforms.Resize(new_size, Image.Resampling.LANCZOS),
239
+ transforms.ToTensor(), # Converts to [0, 1]
 
240
  ])
241
 
242
  input_tensor = transform(image).unsqueeze(0).to(self.device)
243
 
244
+ # Normalize to [-1, 1] for GAN models (common for Pix2Pix style models)
245
+ input_tensor = (input_tensor - 0.5) / 0.5
246
+
247
  # Run inference
248
  with torch.no_grad():
249
+ try:
250
+ output_tensor = self.model(input_tensor)
251
+ logger.debug(f"Model output shape: {output_tensor.shape}, min: {output_tensor.min():.3f}, max: {output_tensor.max():.3f}, mean: {output_tensor.mean():.3f}")
252
+ except Exception as e:
253
+ logger.error(f"Model inference error: {e}")
254
+ # If model fails, try with different input format (without normalization)
255
+ input_tensor_alt = transform(image).unsqueeze(0).to(self.device)
256
+ output_tensor = self.model(input_tensor_alt)
257
 
258
  # Convert output back to PIL Image
 
259
  output_tensor = output_tensor.squeeze(0).cpu()
260
+
261
+ # Handle different output ranges
262
+ # Check if output is in [-1, 1] range (from Tanh) or [0, 1] range
263
+ output_min = output_tensor.min().item()
264
+ output_max = output_tensor.max().item()
265
+ logger.debug(f"Output tensor range: [{output_min:.3f}, {output_max:.3f}]")
266
+
267
+ if output_min < -0.5:
268
+ # Likely [-1, 1] range, denormalize
269
+ output_tensor = (output_tensor + 1) / 2.0
270
+ logger.debug("Applied [-1, 1] denormalization")
271
+ elif output_max > 1.5:
272
+ # Might be in [0, 255] range
273
+ output_tensor = output_tensor / 255.0
274
+ logger.debug("Applied [0, 255] normalization")
275
+ # If already in [0, 1], use as-is
276
+
277
  output_tensor = torch.clamp(output_tensor, 0, 1)
278
 
279
  # Convert to numpy and then PIL
280
  output_array = (output_tensor.permute(1, 2, 0).numpy() * 255).astype('uint8')
281
  output_image = Image.fromarray(output_array, 'RGB')
282
 
283
+ # Resize back to original size with high-quality resampling
284
  if output_image.size != original_size:
285
  output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
286