LogicGoInfotechSpaces commited on
Commit
0454a91
·
1 Parent(s): a2d6cd7

Detect PyTorch models and provide clear error message - Repository contains generator.pt (PyTorch) not FastAI model

Browse files
Files changed (1) hide show
  1. app/colorize_model.py +36 -10
app/colorize_model.py CHANGED
@@ -69,18 +69,27 @@ class ColorizeModel:
69
  try:
70
  repo_files = list_repo_files(repo_id=self.model_id, token=hf_token)
71
  logger.info("Repository files: %s", repo_files)
72
- # Look for .pkl files
73
  pkl_files = [f for f in repo_files if f.endswith('.pkl')]
74
- if not pkl_files:
75
- # Also try common FastAI model file names
76
- model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl"]
77
- else:
78
  model_filenames = pkl_files
79
  logger.info("Found .pkl files in repository: %s", pkl_files)
 
 
 
 
 
 
 
 
 
80
  except Exception as list_err:
81
  logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err))
82
  # Fallback to common filenames
83
- model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl"]
 
84
 
85
  model_path = None
86
  for filename in model_filenames:
@@ -92,16 +101,33 @@ class ColorizeModel:
92
  token=hf_token
93
  )
94
  logger.info("Found model file: %s", filename)
 
 
 
 
 
95
  break
96
  except Exception as dl_err:
97
  logger.debug("Failed to download %s: %s", filename, str(dl_err))
98
  continue
99
 
100
  if model_path and os.path.exists(model_path):
101
- # Load the model using FastAI's load_learner
102
- logger.info("Loading model from: %s", model_path)
103
- self.learn = load_learner(model_path)
104
- logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
 
 
 
 
 
 
 
 
 
 
 
 
105
  else:
106
  # If no model file found, raise error with more details
107
  raise RuntimeError(
 
69
  try:
70
  repo_files = list_repo_files(repo_id=self.model_id, token=hf_token)
71
  logger.info("Repository files: %s", repo_files)
72
+ # Look for .pkl files (FastAI) or .pt files (PyTorch)
73
  pkl_files = [f for f in repo_files if f.endswith('.pkl')]
74
+ pt_files = [f for f in repo_files if f.endswith('.pt')]
75
+
76
+ if pkl_files:
 
77
  model_filenames = pkl_files
78
  logger.info("Found .pkl files in repository: %s", pkl_files)
79
+ model_type = "fastai"
80
+ elif pt_files:
81
+ model_filenames = pt_files
82
+ logger.info("Found .pt files in repository: %s", pt_files)
83
+ model_type = "pytorch"
84
+ else:
85
+ # Fallback to common filenames
86
+ model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl", "generator.pt"]
87
+ model_type = "fastai" # Default assumption
88
  except Exception as list_err:
89
  logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err))
90
  # Fallback to common filenames
91
+ model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl", "generator.pt"]
92
+ model_type = "fastai"
93
 
94
  model_path = None
95
  for filename in model_filenames:
 
101
  token=hf_token
102
  )
103
  logger.info("Found model file: %s", filename)
104
+ # Determine model type from extension
105
+ if filename.endswith('.pt'):
106
+ model_type = "pytorch"
107
+ elif filename.endswith('.pkl'):
108
+ model_type = "fastai"
109
  break
110
  except Exception as dl_err:
111
  logger.debug("Failed to download %s: %s", filename, str(dl_err))
112
  continue
113
 
114
  if model_path and os.path.exists(model_path):
115
+ if model_type == "pytorch":
116
+ # Load PyTorch model - this is a GAN generator
117
+ logger.info("Loading PyTorch model from: %s", model_path)
118
+ # Note: This requires knowing the model architecture
119
+ # For now, we'll try to load it and see if it works
120
+ logger.warning("PyTorch model loading not fully implemented. This model may not work correctly.")
121
+ raise RuntimeError(
122
+ f"Repository '{self.model_id}' contains a PyTorch model (generator.pt), "
123
+ f"not a FastAI model. FastAI models must be .pkl files created with FastAI's export. "
124
+ f"Please use a FastAI-compatible colorization model, or switch to a different model backend."
125
+ )
126
+ else:
127
+ # Load the model using FastAI's load_learner
128
+ logger.info("Loading FastAI model from: %s", model_path)
129
+ self.learn = load_learner(model_path)
130
+ logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
131
  else:
132
  # If no model file found, raise error with more details
133
  raise RuntimeError(