LogicGoInfotechSpaces commited on
Commit
5a116a6
·
verified ·
1 Parent(s): df14637

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +50 -349
app/main.py CHANGED
@@ -1,383 +1,84 @@
1
- """
2
- FastAPI application for FastAI GAN Image Colorization
3
- with Firebase Authentication and Gradio UI
4
- """
5
- import os
6
- # Set environment variables BEFORE any imports
7
- os.environ["OMP_NUM_THREADS"] = "1"
8
- os.environ["HF_HOME"] = "/tmp/hf_cache"
9
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
10
- os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
11
- os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/hf_cache"
12
- os.environ["XDG_CACHE_HOME"] = "/tmp/hf_cache"
13
- os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib_config"
14
- os.environ["GRADIO_TEMP_DIR"] = "/tmp/gradio"
15
-
16
  import io
17
  import uuid
18
- import logging
19
- from pathlib import Path
20
- from typing import Optional
21
-
22
- from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request
23
- from fastapi.responses import FileResponse, JSONResponse
24
- from fastapi.middleware.cors import CORSMiddleware
25
- from fastapi.staticfiles import StaticFiles
26
- import firebase_admin
27
- from firebase_admin import credentials, app_check, auth as firebase_auth
28
  from PIL import Image
 
 
 
 
29
  import torch
30
  import uvicorn
31
- import gradio as gr
32
-
33
- # FastAI imports
34
- from fastai.vision.all import *
35
- from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files
36
-
37
- from app.config import settings
38
-
39
- # Configure logging
40
- logging.basicConfig(
41
- level=logging.INFO,
42
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
43
- )
44
- logger = logging.getLogger(__name__)
45
-
46
- # Create writable directories
47
- Path("/tmp/hf_cache").mkdir(parents=True, exist_ok=True)
48
- Path("/tmp/matplotlib_config").mkdir(parents=True, exist_ok=True)
49
- Path("/tmp/colorize_uploads").mkdir(parents=True, exist_ok=True)
50
- Path("/tmp/colorize_results").mkdir(parents=True, exist_ok=True)
51
-
52
- # Initialize FastAPI app
53
- app = FastAPI(
54
- title="FastAI Image Colorizer API",
55
- description="Image colorization using FastAI GAN model with Firebase authentication",
56
- version="1.0.0"
57
- )
58
-
59
- # CORS middleware
60
- app.add_middleware(
61
- CORSMiddleware,
62
- allow_origins=["*"],
63
- allow_credentials=True,
64
- allow_methods=["*"],
65
- allow_headers=["*"],
66
- )
67
-
68
- # Initialize Firebase Admin SDK
69
- firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "/tmp/firebase-adminsdk.json")
70
- if os.path.exists(firebase_cred_path):
71
- try:
72
- cred = credentials.Certificate(firebase_cred_path)
73
- firebase_admin.initialize_app(cred)
74
- logger.info("Firebase Admin SDK initialized")
75
- except Exception as e:
76
- logger.warning("Failed to initialize Firebase: %s", str(e))
77
- try:
78
- firebase_admin.initialize_app()
79
- except:
80
- pass
81
- else:
82
- logger.warning("Firebase credentials file not found. App Check will be disabled.")
83
- try:
84
- firebase_admin.initialize_app()
85
- except:
86
- pass
87
-
88
- # Storage directories
89
- UPLOAD_DIR = Path("/tmp/colorize_uploads")
90
- RESULT_DIR = Path("/tmp/colorize_results")
91
-
92
- # Mount static files
93
- app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
94
- app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
95
-
96
- # Initialize FastAI model
97
- learn = None
98
- model_load_error: Optional[str] = None
99
-
100
- @app.on_event("startup")
101
- async def startup_event():
102
- """Load FastAI model on startup"""
103
- global learn, model_load_error
104
- try:
105
- model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
106
- logger.info("🔄 Loading FastAI GAN Colorization Model: %s", model_id)
107
-
108
- # Try using from_pretrained_fastai first
109
- try:
110
- learn = from_pretrained_fastai(model_id)
111
- logger.info("✅ Model loaded successfully via from_pretrained_fastai!")
112
- model_load_error = None
113
- except Exception as e1:
114
- logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
115
- # Fallback: manually download and load the model file
116
- hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
117
-
118
- # List repository files to find the actual model file
119
- model_filenames = []
120
- model_type = "fastai"
121
-
122
- try:
123
- repo_files = list_repo_files(repo_id=model_id, token=hf_token)
124
- logger.info("Repository files: %s", repo_files)
125
- pkl_files = [f for f in repo_files if f.endswith('.pkl')]
126
- pt_files = [f for f in repo_files if f.endswith('.pt')]
127
-
128
- if pkl_files:
129
- model_filenames = pkl_files
130
- logger.info("Found .pkl files in repository: %s", pkl_files)
131
- model_type = "fastai"
132
- elif pt_files:
133
- model_filenames = pt_files
134
- logger.info("Found .pt files in repository: %s", pt_files)
135
- model_type = "pytorch"
136
- else:
137
- model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl", "generator.pt"]
138
- model_type = "fastai"
139
- except Exception as list_err:
140
- logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err))
141
- model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl", "generator.pt"]
142
- model_type = "fastai"
143
-
144
- # Try to download and load the model file
145
- cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
146
- model_path = None
147
- for filename in model_filenames:
148
- try:
149
- model_path = hf_hub_download(
150
- repo_id=model_id,
151
- filename=filename,
152
- cache_dir=cache_dir,
153
- token=hf_token
154
- )
155
- logger.info("Found model file: %s", filename)
156
- if filename.endswith('.pt'):
157
- model_type = "pytorch"
158
- elif filename.endswith('.pkl'):
159
- model_type = "fastai"
160
- break
161
- except Exception as dl_err:
162
- logger.debug("Failed to download %s: %s", filename, str(dl_err))
163
- continue
164
-
165
- if model_path and os.path.exists(model_path):
166
- if model_type == "pytorch":
167
- error_msg = (
168
- f"Repository '{model_id}' contains a PyTorch model (.pt file), "
169
- f"not a FastAI model. FastAI models must be .pkl files created with FastAI's export. "
170
- f"Please use a FastAI-compatible colorization model, or switch to a different model backend."
171
- )
172
- logger.error(error_msg)
173
- model_load_error = error_msg
174
- raise RuntimeError(error_msg)
175
- else:
176
- logger.info("Loading FastAI model from: %s", model_path)
177
- learn = load_learner(model_path)
178
- logger.info("✅ Model loaded successfully from %s", model_path)
179
- model_load_error = None
180
- else:
181
- error_msg = (
182
- f"Could not find model file in repository '{model_id}'. "
183
- f"Tried: {', '.join(model_filenames)}. "
184
- f"Original error: {str(e1)}"
185
- )
186
- logger.error(error_msg)
187
- model_load_error = error_msg
188
- raise RuntimeError(error_msg)
189
-
190
- except Exception as e:
191
- error_msg = str(e)
192
- if not model_load_error:
193
- model_load_error = error_msg
194
- logger.error("❌ Failed to load model: %s", error_msg)
195
- # Don't raise - allow health check to work
196
-
197
- @app.on_event("shutdown")
198
- async def shutdown_event():
199
- """Cleanup on shutdown"""
200
- global learn
201
- if learn:
202
- del learn
203
- logger.info("Application shutdown")
204
-
205
- def _extract_bearer_token(authorization_header: str | None) -> str | None:
206
- if not authorization_header:
207
- return None
208
- parts = authorization_header.split(" ", 1)
209
- if len(parts) == 2 and parts[0].lower() == "bearer":
210
- return parts[1].strip()
211
- return None
212
-
213
- async def verify_request(request: Request):
214
- """
215
- Verify Firebase authentication
216
- Accept either:
217
- - Firebase Auth id_token via Authorization: Bearer <id_token>
218
- - Firebase App Check token via X-Firebase-AppCheck (when ENABLE_APP_CHECK=true)
219
- """
220
- # If Firebase is not initialized or auth is explicitly disabled, allow
221
- if not firebase_admin._apps or os.getenv("DISABLE_AUTH", "false").lower() == "true":
222
- return True
223
-
224
- # Try Firebase Auth id_token first if present
225
- bearer = _extract_bearer_token(request.headers.get("Authorization"))
226
- if bearer:
227
- try:
228
- decoded = firebase_auth.verify_id_token(bearer)
229
- request.state.user = decoded
230
- logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
231
- return True
232
- except Exception as e:
233
- logger.warning("Auth token verification failed: %s", str(e))
234
-
235
- # If App Check is enabled, require valid App Check token
236
- if settings.ENABLE_APP_CHECK:
237
- app_check_token = request.headers.get("X-Firebase-AppCheck")
238
- if not app_check_token:
239
- raise HTTPException(status_code=401, detail="Missing App Check token")
240
- try:
241
- app_check_claims = app_check.verify_token(app_check_token)
242
- logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
243
- return True
244
- except Exception as e:
245
- logger.warning("App Check token verification failed: %s", str(e))
246
- raise HTTPException(status_code=401, detail="Invalid App Check token")
247
 
248
- # Neither token required nor provided → allow (App Check disabled)
249
- return True
 
 
250
 
251
- @app.get("/api")
252
- async def api_info():
253
- """API info endpoint"""
254
- return {
255
- "app": "FastAI Image Colorizer API",
256
- "version": "1.0.0",
257
- "health": "/health",
258
- "colorize": "/colorize",
259
- "gradio": "/"
260
- }
261
 
262
- @app.get("/health")
263
- async def health_check():
264
- """Health check endpoint"""
265
- response = {
266
- "status": "healthy",
267
- "model_loaded": learn is not None,
268
- "model_id": os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
269
- }
270
- if model_load_error:
271
- response["model_error"] = model_load_error
272
- return response
273
 
274
- def colorize_pil(image: Image.Image) -> Image.Image:
275
- """Run model prediction and return colorized image"""
276
- if learn is None:
277
- raise RuntimeError("Model not loaded")
278
  if image.mode != "RGB":
279
  image = image.convert("RGB")
280
- pred = learn.predict(image)
281
- # Handle different return types from FastAI
282
- if isinstance(pred, (list, tuple)):
283
- colorized = pred[0] if len(pred) > 0 else image
284
- else:
285
- colorized = pred
286
-
287
- # Ensure we have a PIL Image
288
- if not isinstance(colorized, Image.Image):
289
- if isinstance(colorized, torch.Tensor):
290
- # Convert tensor to PIL
291
- if colorized.dim() == 4:
292
- colorized = colorized[0]
293
- if colorized.dim() == 3:
294
- colorized = colorized.permute(1, 2, 0).cpu()
295
- if colorized.dtype in (torch.float32, torch.float16):
296
- colorized = torch.clamp(colorized, 0, 1)
297
- colorized = (colorized * 255).byte()
298
- colorized = Image.fromarray(colorized.numpy(), 'RGB')
299
- else:
300
- raise ValueError(f"Unexpected tensor shape: {colorized.shape}")
301
- else:
302
- raise ValueError(f"Unexpected prediction type: {type(colorized)}")
303
-
304
- if colorized.mode != "RGB":
305
- colorized = colorized.convert("RGB")
306
-
307
- return colorized
308
 
309
  @app.post("/colorize")
310
- async def colorize_api(
311
- file: UploadFile = File(...),
312
- verified: bool = Depends(verify_request)
313
- ):
314
  """
315
- Upload a black & white image -> returns colorized image.
316
- Requires Firebase authentication unless DISABLE_AUTH=true
317
  """
318
- if learn is None:
319
- raise HTTPException(status_code=503, detail="Colorization model not loaded")
320
-
321
- if not file.content_type or not file.content_type.startswith("image/"):
322
- raise HTTPException(status_code=400, detail="File must be an image")
323
-
324
  try:
325
  img_bytes = await file.read()
326
  image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
327
-
328
- logger.info("Colorizing image...")
329
- colorized = colorize_pil(image)
330
-
331
  output_filename = f"{uuid.uuid4()}.png"
332
- output_path = RESULT_DIR / output_filename
333
- colorized.save(output_path, "PNG")
334
-
335
- logger.info("Colorized image saved: %s", output_filename)
336
-
337
- # Return the image file
338
- return FileResponse(
339
- output_path,
340
- media_type="image/png",
341
- filename=f"colorized_{output_filename}"
342
- )
343
  except Exception as e:
344
- logger.error("Error colorizing image: %s", str(e))
345
- raise HTTPException(status_code=500, detail=f"Error colorizing image: {str(e)}")
346
 
347
  # ==========================================================
348
- # Gradio Interface (for Space UI)
349
  # ==========================================================
350
- def gradio_colorize(image):
351
- """Gradio colorization function"""
352
- if image is None:
353
- return None
354
- try:
355
- if learn is None:
356
- return None
357
- return colorize_pil(image)
358
- except Exception as e:
359
- logger.error("Gradio colorization error: %s", str(e))
360
- return None
361
 
362
- title = "🎨 FastAI GAN Image Colorizer"
363
- description = "Upload a black & white photo to generate a colorized version using the FastAI GAN model."
364
 
365
  iface = gr.Interface(
366
- fn=gradio_colorize,
367
- inputs=gr.Image(type="pil", label="Upload B&W Image"),
368
- outputs=gr.Image(type="pil", label="Colorized Image"),
369
  title=title,
370
  description=description,
371
  )
372
 
373
- # Mount Gradio app at root (this will be the Space UI)
374
- # Note: This will override the root endpoint, so use /api for API info
375
- app = gr.mount_gradio_app(app, iface, path="/")
376
 
377
  # ==========================================================
378
- # Run Server
379
  # ==========================================================
380
  if __name__ == "__main__":
381
- port = int(os.getenv("PORT", "7860"))
382
- uvicorn.run(app, host="0.0.0.0", port=port)
383
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import io
2
  import uuid
3
+ import os
 
 
 
 
 
 
 
 
 
4
  from PIL import Image
5
+ from fastapi import FastAPI, UploadFile, File
6
+ from fastapi.responses import FileResponse
7
+ from huggingface_hub import from_pretrained_fastai
8
+ import gradio as gr
9
  import torch
10
  import uvicorn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # ==========================================================
13
+ # 🔧 CONFIGURATION
14
+ # ==========================================================
15
+ MODEL_ID = "Hammad712/GAN-Colorization-Model" # 👉 change this if you want another model
16
 
17
+ UPLOAD_DIR = "/tmp/uploads"
18
+ RESULT_DIR = "/tmp/results"
19
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
20
+ os.makedirs(RESULT_DIR, exist_ok=True)
 
 
 
 
 
 
21
 
22
+ # ==========================================================
23
+ # 🚀 LOAD MODEL
24
+ # ==========================================================
25
+ print(f"Loading model: {MODEL_ID}")
26
+ learn = from_pretrained_fastai(MODEL_ID)
27
+ print(" Model loaded successfully!")
 
 
 
 
 
28
 
29
+ # ==========================================================
30
+ # 🧠 Colorization Function
31
+ # ==========================================================
32
+ def colorize_image(image: Image.Image):
33
  if image.mode != "RGB":
34
  image = image.convert("RGB")
35
+ pred = learn.predict(image)[0]
36
+ return pred
37
+
38
+ # ==========================================================
39
+ # 🌐 FASTAPI APP
40
+ # ==========================================================
41
+ app = FastAPI(title="Image Colorization API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  @app.post("/colorize")
44
+ async def colorize_endpoint(file: UploadFile = File(...)):
 
 
 
45
  """
46
+ Upload a black & white image -> get colorized image
 
47
  """
 
 
 
 
 
 
48
  try:
49
  img_bytes = await file.read()
50
  image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
51
+
52
+ colorized = colorize_image(image)
 
 
53
  output_filename = f"{uuid.uuid4()}.png"
54
+ output_path = os.path.join(RESULT_DIR, output_filename)
55
+ colorized.save(output_path)
56
+
57
+ return FileResponse(output_path, media_type="image/png")
 
 
 
 
 
 
 
58
  except Exception as e:
59
+ return {"error": str(e)}
 
60
 
61
  # ==========================================================
62
+ # 🎨 GRADIO INTERFACE
63
  # ==========================================================
64
+ def gradio_interface(image):
65
+ return colorize_image(image)
 
 
 
 
 
 
 
 
 
66
 
67
+ title = "🎨 FastAI / HuggingFace Image Colorizer"
68
+ description = "Upload a black & white photo to get a colorized version."
69
 
70
  iface = gr.Interface(
71
+ fn=gradio_interface,
72
+ inputs=gr.Image(type="pil", label="Upload Image"),
73
+ outputs=gr.Image(type="pil", label="Colorized Output"),
74
  title=title,
75
  description=description,
76
  )
77
 
78
+ gradio_app = gr.mount_gradio_app(app, iface, path="/")
 
 
79
 
80
  # ==========================================================
81
+ # ▶️ RUN LOCALLY OR IN HUGGINGFACE SPACE
82
  # ==========================================================
83
  if __name__ == "__main__":
84
+ uvicorn.run(app, host="0.0.0.0", port=7860)