LogicGoInfotechSpaces committed on
Commit
2f136a8
·
1 Parent(s): 5c1f200

Fix model loading: use /data for HF cache directory instead of /app/data

Browse files
Files changed (3)
  1. Dockerfile +3 -3
  2. app/colorize_model.py +23 -12
  3. app/main.py +14 -5
Dockerfile CHANGED
@@ -31,11 +31,11 @@ RUN mkdir -p /data/uploads /data/results && chmod -R 777 /data
31
  # This allows the credentials to be passed as a secret and written to file at runtime
32
  RUN echo '#!/bin/sh' > /entrypoint.sh && \
33
  echo 'set -e' >> /entrypoint.sh && \
 
 
34
  echo 'if [ -n "$FIREBASE_CREDENTIALS" ]; then' >> /entrypoint.sh && \
35
- echo ' mkdir -p /data' >> /entrypoint.sh && \
36
- echo ' touch /data/firebase-adminsdk.json' >> /entrypoint.sh && \
37
- echo ' chmod 600 /data/firebase-adminsdk.json' >> /entrypoint.sh && \
38
  echo ' printf "%s" "$FIREBASE_CREDENTIALS" > /data/firebase-adminsdk.json' >> /entrypoint.sh && \
 
39
  echo 'fi' >> /entrypoint.sh && \
40
  echo 'exec "$@"' >> /entrypoint.sh && \
41
  chmod +x /entrypoint.sh
 
31
  # This allows the credentials to be passed as a secret and written to file at runtime
32
  RUN echo '#!/bin/sh' > /entrypoint.sh && \
33
  echo 'set -e' >> /entrypoint.sh && \
34
+ echo 'mkdir -p /data/uploads /data/results' >> /entrypoint.sh && \
35
+ echo 'chmod -R 777 /data' >> /entrypoint.sh && \
36
  echo 'if [ -n "$FIREBASE_CREDENTIALS" ]; then' >> /entrypoint.sh && \
 
 
 
37
  echo ' printf "%s" "$FIREBASE_CREDENTIALS" > /data/firebase-adminsdk.json' >> /entrypoint.sh && \
38
+ echo ' chmod 600 /data/firebase-adminsdk.json' >> /entrypoint.sh && \
39
  echo 'fi' >> /entrypoint.sh && \
40
  echo 'exec "$@"' >> /entrypoint.sh && \
41
  chmod +x /entrypoint.sh
app/colorize_model.py CHANGED
@@ -30,25 +30,36 @@ class ColorizeModel:
30
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
31
  logger.info("Using device: %s", self.device)
32
  self.dtype = torch.float16 if self.device == "cuda" else torch.float32
33
- self.hf_token = os.getenv("HF_TOKEN") or None
 
34
 
35
  # Configure writable cache to avoid permission issues on Spaces
36
- # Prefer user home cache: ~/.cache/huggingface
37
- default_home_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
38
- hf_cache_dir = os.getenv("HF_HOME", default_home_cache)
39
- os.environ.setdefault("HF_HOME", hf_cache_dir)
40
- os.environ.setdefault("HUGGINGFACE_HUB_CACHE", hf_cache_dir)
41
- os.environ.setdefault("TRANSFORMERS_CACHE", hf_cache_dir)
 
 
 
42
  try:
43
  os.makedirs(hf_cache_dir, exist_ok=True)
44
- except Exception:
45
- # Fallback to a local data dir if home is not writable
46
- hf_cache_dir = os.path.abspath(os.path.join(".", "data", "hf_cache"))
 
 
 
47
  os.environ["HF_HOME"] = hf_cache_dir
48
  os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache_dir
49
  os.environ["TRANSFORMERS_CACHE"] = hf_cache_dir
50
- os.makedirs(hf_cache_dir, exist_ok=True)
51
- logger.info("HF cache directory: %s", hf_cache_dir)
 
 
 
 
52
 
53
  # Avoid libgomp warning by setting a valid integer
54
  os.environ.setdefault("OMP_NUM_THREADS", "1")
 
30
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
31
  logger.info("Using device: %s", self.device)
32
  self.dtype = torch.float16 if self.device == "cuda" else torch.float32
33
+ # Check for Hugging Face token (try both environment variable names)
34
+ self.hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or None
35
 
36
  # Configure writable cache to avoid permission issues on Spaces
37
+ # Use /data directory which is writable in Hugging Face Spaces
38
+ data_dir = os.getenv("DATA_DIR", "/data")
39
+ hf_cache_dir = os.path.join(data_dir, "hf_cache")
40
+
41
+ # Set cache environment variables
42
+ os.environ["HF_HOME"] = hf_cache_dir
43
+ os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache_dir
44
+ os.environ["TRANSFORMERS_CACHE"] = hf_cache_dir
45
+
46
  try:
47
  os.makedirs(hf_cache_dir, exist_ok=True)
48
+ logger.info("HF cache directory: %s", hf_cache_dir)
49
+ except Exception as e:
50
+ # Fallback to user home if /data is not available (local dev)
51
+ logger.warning("Failed to create cache in /data: %s, trying home directory", str(e))
52
+ default_home_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
53
+ hf_cache_dir = os.getenv("HF_HOME", default_home_cache)
54
  os.environ["HF_HOME"] = hf_cache_dir
55
  os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache_dir
56
  os.environ["TRANSFORMERS_CACHE"] = hf_cache_dir
57
+ try:
58
+ os.makedirs(hf_cache_dir, exist_ok=True)
59
+ logger.info("HF cache directory (fallback): %s", hf_cache_dir)
60
+ except Exception as e2:
61
+ logger.error("Failed to create cache directory: %s", str(e2))
62
+ raise RuntimeError(f"Cannot create Hugging Face cache directory: {str(e2)}")
63
 
64
  # Avoid libgomp warning by setting a valid integer
65
  os.environ.setdefault("OMP_NUM_THREADS", "1")
app/main.py CHANGED
@@ -75,6 +75,7 @@ app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
75
 
76
  # Initialize ColorizeNet model
77
  colorize_model = None
 
78
 
79
  @app.get("/")
80
  async def root():
@@ -89,13 +90,17 @@ async def root():
89
  @app.on_event("startup")
90
  async def startup_event():
91
  """Initialize the colorization model on startup"""
92
- global colorize_model
93
  try:
94
- logger.info("Loading ColorizeNet model...")
 
95
  colorize_model = ColorizeModel(settings.MODEL_ID)
96
  logger.info("ColorizeNet model loaded successfully")
 
97
  except Exception as e:
98
- logger.error("Failed to load ColorizeNet model: %s", str(e))
 
 
99
  # Don't raise - allow health check to work even if model fails
100
 
101
  @app.on_event("shutdown")
@@ -156,10 +161,14 @@ async def verify_request(request: Request):
156
  @app.get("/health")
157
  async def health_check():
158
  """Health check endpoint"""
159
- return {
160
  "status": "healthy",
161
- "model_loaded": colorize_model is not None
 
162
  }
 
 
 
163
 
164
  @app.post("/upload")
165
  async def upload_image(
 
75
 
76
  # Initialize ColorizeNet model
77
  colorize_model = None
78
+ model_load_error: Optional[str] = None
79
 
80
  @app.get("/")
81
  async def root():
 
90
  @app.on_event("startup")
91
  async def startup_event():
92
  """Initialize the colorization model on startup"""
93
+ global colorize_model, model_load_error
94
  try:
95
+ logger.info("Loading ColorizeNet model with MODEL_ID: %s", settings.MODEL_ID)
96
+ logger.info("HF_TOKEN present: %s", "Yes" if os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") else "No")
97
  colorize_model = ColorizeModel(settings.MODEL_ID)
98
  logger.info("ColorizeNet model loaded successfully")
99
+ model_load_error = None
100
  except Exception as e:
101
+ error_msg = str(e)
102
+ logger.error("Failed to load ColorizeNet model: %s", error_msg)
103
+ model_load_error = error_msg
104
  # Don't raise - allow health check to work even if model fails
105
 
106
  @app.on_event("shutdown")
 
161
  @app.get("/health")
162
  async def health_check():
163
  """Health check endpoint"""
164
+ response = {
165
  "status": "healthy",
166
+ "model_loaded": colorize_model is not None,
167
+ "model_id": settings.MODEL_ID
168
  }
169
+ if model_load_error:
170
+ response["model_error"] = model_load_error
171
+ return response
172
 
173
  @app.post("/upload")
174
  async def upload_image(