feat(auth): accept Firebase Auth id_token (Authorization Bearer) in addition to App Check; add Postman collection and test script; default MODEL_ID to ControlNet color
2ae242d
| """ | |
| FastAPI application for image colorization using ColorizeNet model | |
| with Firebase App Check integration | |
| """ | |
| import os | |
| import uuid | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Request | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| import firebase_admin | |
| from firebase_admin import credentials, app_check, auth as firebase_auth | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import io | |
| from app.colorize_model import ColorizeModel | |
| from app.config import settings | |
# Process-wide logging configuration: INFO and above, each record prefixed with
# timestamp, logger name and level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger for this module.
logger = logging.getLogger(__name__)
# The ASGI application; title/description/version are descriptive metadata
# that FastAPI publishes in its OpenAPI schema.
app = FastAPI(
    description="Image colorization API using ColorizeNet model",
    title="Colorize API",
    version="1.0.0",
)
# CORS middleware.
# Allowed origins are read from CORS_ALLOW_ORIGINS as a comma-separated list
# (e.g. "https://app.example.com,https://admin.example.com"). The default "*",
# and an empty value, keep the previous allow-all behaviour; set explicit
# origins in production.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged by the CORS spec. Auth in this service travels in request headers
# (Authorization / X-Firebase-AppCheck), not cookies — confirm whether
# credentialed CORS is actually needed.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize Firebase Admin SDK.
# FIREBASE_CREDENTIALS_PATH names a service-account JSON file. If that file is
# missing or cannot be used, initialization falls back to the SDK's default
# credentials. Everything here is best-effort: failures are logged, never
# raised, so the API can still start (token verification will then likely
# fail at request time).
# NOTE(review): the default path is a specific service-account key file name —
# make sure that key file is not committed alongside the code.
firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "colorize-662df-firebase-adminsdk-fbsvc-e080668793.json")


def _init_firebase(cred_path: str) -> None:
    """Initialize the default Firebase app, preferring the credentials file.

    Args:
        cred_path: Path to a service-account JSON file.

    Never raises: every failure is logged as a warning.
    """
    if os.path.exists(cred_path):
        try:
            firebase_admin.initialize_app(credentials.Certificate(cred_path))
            logger.info("Firebase Admin SDK initialized")
            return
        except Exception as e:
            logger.warning("Failed to initialize Firebase: %s", str(e))
    else:
        logger.warning(
            "Firebase credentials file not found at %s; falling back to default credentials.",
            cred_path,
        )
    try:
        # No explicit credential: the SDK resolves its own default credentials.
        firebase_admin.initialize_app()
    except Exception as e:
        # e.g. the default app already exists or no default credentials are
        # available — logged instead of crashing the import.
        logger.warning("Firebase default initialization failed: %s", str(e))


_init_firebase(firebase_cred_path)
# Storage locations for incoming uploads and colorized results, relative to the
# current working directory; created at import time if absent.
UPLOAD_DIR = Path("uploads")
RESULT_DIR = Path("results")
for _storage_dir in (UPLOAD_DIR, RESULT_DIR):
    _storage_dir.mkdir(exist_ok=True)
del _storage_dir

# Expose both directories as static file trees. These mounts carry no
# verify_request dependency, so their files are publicly readable.
# NOTE(review): these paths overlap the /results/... and /uploads/... URLs that
# get_result_file / get_upload_file appear to serve; whichever is registered
# first takes precedence — confirm which is intended.
app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
# Global model handle: assigned by startup_event, and left as None when loading
# fails (handlers check for None before use).
colorize_model: "Optional[ColorizeModel]" = None
async def startup_event():
    """Load the ColorizeNet model into the module-level ``colorize_model``.

    Failures are logged (with traceback) and swallowed so the application still
    starts and health_check can report ``model_loaded: False``.

    NOTE(review): no startup registration (``@app.on_event("startup")`` or a
    lifespan handler) is visible in this view — confirm it is wired up,
    otherwise the model is never loaded and colorize_image always returns 503.
    """
    global colorize_model
    try:
        logger.info("Loading ColorizeNet model...")
        colorize_model = ColorizeModel(settings.MODEL_ID)
        logger.info("ColorizeNet model loaded successfully")
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone discards —
        # essential for diagnosing model download / device errors.
        logger.exception("Failed to load ColorizeNet model: %s", str(e))
        # Don't raise - allow health check to work even if model fails
async def shutdown_event():
    """Drop the model reference on application shutdown.

    The global is reset to None rather than deleted: ``del`` would remove the
    module-level name itself, so any later read (health_check, colorize_image)
    would raise NameError instead of seeing "model not loaded".

    NOTE(review): no shutdown registration is visible in this view — confirm
    this hook is wired up.
    """
    global colorize_model
    colorize_model = None
    logger.info("Application shutdown")
| def _extract_bearer_token(authorization_header: str | None) -> str | None: | |
| if not authorization_header: | |
| return None | |
| parts = authorization_header.split(" ", 1) | |
| if len(parts) == 2 and parts[0].lower() == "bearer": | |
| return parts[1].strip() | |
| return None | |
async def verify_request(request: Request):
    """
    FastAPI dependency that authorizes a request via Firebase.

    Checks, in order:
      1. ``Authorization: Bearer <id_token>`` — verified with Firebase Auth; on
         success the decoded claims are stored on ``request.state.user``.
      2. ``X-Firebase-AppCheck`` — required and verified only when
         ``settings.ENABLE_APP_CHECK`` is true.

    A bearer token that fails verification is logged and ignored. With App
    Check disabled the request is therefore allowed whether the bearer token
    was absent, valid, or invalid.

    Returns:
        True when the request is allowed.
    Raises:
        HTTPException(401): App Check is enabled and its token is missing or
            fails verification.

    NOTE(review): the firebase_admin verify calls are synchronous; if they
    perform network I/O (e.g. key fetches) they block the event loop.
    """
    headers = request.headers

    id_token = _extract_bearer_token(headers.get("Authorization"))
    if id_token:
        try:
            claims = firebase_auth.verify_id_token(id_token)
            # Make the decoded claims available to downstream handlers.
            request.state.user = claims
            logger.info("Firebase Auth id_token verified for uid: %s", claims.get("uid"))
            return True
        except Exception as exc:
            # Not fatal on purpose: App Check (if enabled) gets a chance next.
            logger.warning("Auth token verification failed: %s", str(exc))

    if not settings.ENABLE_APP_CHECK:
        # App Check disabled and no verified id_token: allow the request.
        return True

    app_check_token = headers.get("X-Firebase-AppCheck")
    if not app_check_token:
        raise HTTPException(status_code=401, detail="Missing App Check token")
    try:
        app_check_claims = app_check.verify_token(app_check_token)
        logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
        return True
    except Exception as exc:
        logger.warning("App Check token verification failed: %s", str(exc))
        raise HTTPException(status_code=401, detail="Invalid App Check token")
async def health_check():
    """Report liveness plus whether the colorization model is loaded.

    Always reports ``status: "healthy"``; ``model_loaded`` reflects whether the
    global ``colorize_model`` has been set.
    """
    model_ready = colorize_model is not None
    return {"status": "healthy", "model_loaded": model_ready}
async def upload_image(
    file: UploadFile = File(...),
    verified: bool = Depends(verify_request)
):
    """
    Upload an image and return the uploaded image URL.

    The file is saved under UPLOAD_DIR as ``<uuid4><ext>``. The client's
    extension is kept only when it is a common raster image type. Anything
    else (e.g. ``.html``, ``.svg``, or no extension) is stored as ``.jpg``, so
    the publicly served uploads directory never holds files with
    attacker-chosen, browser-executable extensions.

    Returns:
        dict with ``success``, ``image_id``, ``image_url`` and ``filename``.
    Raises:
        HTTPException(400): the declared content type is not ``image/*``.
        HTTPException(500): the upload could not be read or written.
    """
    # NOTE(review): content_type is client-supplied; the bytes are not checked
    # to really be an image, and no upload size limit is enforced here.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # Extensions considered safe to keep from the client-provided filename.
    safe_suffixes = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tif", ".tiff"}
    # filename may be None/empty; Path("") has an empty suffix.
    suffix = Path(file.filename or "").suffix.lower()
    file_extension = suffix if suffix in safe_suffixes else ".jpg"

    # Generate unique filename
    file_id = str(uuid.uuid4())
    filename = f"{file_id}{file_extension}"
    filepath = UPLOAD_DIR / filename

    # Save uploaded file
    try:
        contents = await file.read()
        filepath.write_bytes(contents)
    except Exception as e:
        logger.exception("Error uploading image: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Error uploading image: {str(e)}") from e
    logger.info("Image uploaded: %s", filename)

    # Public base URL for links. SPACE_HOST (Hugging Face Spaces) is presumably
    # a bare host name, so a scheme is added when missing; a trailing slash is
    # dropped to avoid "//" in the URL. Empty env values count as unset.
    base_url = os.getenv("BASE_URL") or os.getenv("SPACE_HOST") or "http://localhost:7860"
    if "://" not in base_url:
        base_url = f"https://{base_url}"
    base_url = base_url.rstrip("/")

    return {
        "success": True,
        "image_id": file_id,
        "image_url": f"{base_url}/uploads/{filename}",
        "filename": filename,
    }
async def colorize_image(
    file: UploadFile = File(...),
    verified: bool = Depends(verify_request)
):
    """
    Colorize an uploaded image using ColorizeNet and return the result URLs.

    The input is decoded with PIL, converted to RGB when needed, and passed to
    ``colorize_model.colorize``. The output is saved as a JPEG (quality 95)
    under RESULT_DIR as ``<uuid4>.jpg``.

    Returns:
        dict with ``success``, ``result_id``, ``download_url`` (under
        /results), ``api_download_url`` (under /download) and ``filename``.
    Raises:
        HTTPException(503): the model is not loaded.
        HTTPException(400): the content type is not ``image/*`` or the bytes
            cannot be decoded as an image.
        HTTPException(500): any other failure while colorizing or saving.
    """
    if colorize_model is None:
        raise HTTPException(status_code=503, detail="Colorization model not loaded")
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents))
            # Image.open is lazy; load() forces a full decode so corrupt or
            # truncated data is reported as a client error (400), not a 500.
            image.load()
        except (OSError, Image.DecompressionBombError) as e:
            logger.warning("Rejected undecodable image: %s", str(e))
            raise HTTPException(status_code=400, detail="Invalid or unsupported image file") from e

        # Convert to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")

        # NOTE(review): inference runs synchronously inside this async handler
        # and blocks the event loop while it runs; offloading it to a worker
        # thread would require the model to be safe for concurrent use.
        logger.info("Colorizing image...")
        colorized_image = colorize_model.colorize(image)

        # Save colorized image
        file_id = str(uuid.uuid4())
        result_filename = f"{file_id}.jpg"
        colorized_image.save(RESULT_DIR / result_filename, "JPEG", quality=95)
        logger.info("Colorized image saved: %s", result_filename)
    except HTTPException:
        # Client errors raised above must not be rewrapped as 500s.
        raise
    except Exception as e:
        logger.exception("Error colorizing image: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {str(e)}") from e

    # Public base URL for links. SPACE_HOST (Hugging Face Spaces) is presumably
    # a bare host name, so a scheme is added when missing; a trailing slash is
    # dropped to avoid "//" in the URL. Empty env values count as unset.
    base_url = os.getenv("BASE_URL") or os.getenv("SPACE_HOST") or "http://localhost:7860"
    if "://" not in base_url:
        base_url = f"https://{base_url}"
    base_url = base_url.rstrip("/")

    return {
        "success": True,
        "result_id": file_id,
        "download_url": f"{base_url}/results/{result_filename}",
        "api_download_url": f"{base_url}/download/{file_id}",
        "filename": result_filename,
    }
async def download_result(
    file_id: str,
    verified: bool = Depends(verify_request)
):
    """
    Return the colorized JPEG for ``file_id`` as a file download.

    The response suggests the client-side name ``colorized_<file_id>.jpg``.

    Raises:
        HTTPException(404): no ``<file_id>.jpg`` exists in RESULT_DIR.
    """
    result_path = RESULT_DIR.joinpath(f"{file_id}.jpg")
    if result_path.exists():
        return FileResponse(
            result_path,
            media_type="image/jpeg",
            filename=f"colorized_{file_id}.jpg",
        )
    raise HTTPException(status_code=404, detail="Result not found")
async def get_result_file(filename: str):
    """
    Serve result files directly (public endpoint for browser access).

    Only regular files located directly inside RESULT_DIR are served. Names
    that resolve elsewhere (``..``, absolute paths, symlinks pointing out),
    directories, and names the OS cannot represent (e.g. embedded NUL) all
    yield 404. Such names are never handed to FileResponse.

    Raises:
        HTTPException(404): the name does not refer to a file in RESULT_DIR.
    """
    base_dir = RESULT_DIR.resolve()
    try:
        result_filepath = (base_dir / filename).resolve()
        servable = result_filepath.parent == base_dir and result_filepath.is_file()
    except (OSError, ValueError):
        servable = False
    if not servable:
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(
        result_filepath,
        media_type="image/jpeg"
    )
async def get_upload_file(filename: str):
    """
    Serve uploaded files directly.

    Only regular files located directly inside UPLOAD_DIR are served. Names
    that resolve elsewhere (``..``, absolute paths, symlinks pointing out),
    directories, and names the OS cannot represent (e.g. embedded NUL) all
    yield 404. Such names are never handed to FileResponse.

    NOTE(review): the response is always labelled image/jpeg even for other
    stored extensions (e.g. .png); browsers generally sniff images, but
    confirm this is intended.

    Raises:
        HTTPException(404): the name does not refer to a file in UPLOAD_DIR.
    """
    base_dir = UPLOAD_DIR.resolve()
    try:
        upload_filepath = (base_dir / filename).resolve()
        servable = upload_filepath.parent == base_dir and upload_filepath.is_file()
    except (OSError, ValueError):
        servable = False
    if not servable:
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(
        upload_filepath,
        media_type="image/jpeg"
    )
if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces, port from $PORT
    # (default 7860).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))