from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Request, Form
from typing import Optional
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download
import uuid
import os
import io
import json
import logging
import time
from PIL import Image
import torch
from torchvision import transforms
from app.database import (
    get_database,
    log_api_call,
    log_image_upload,
    log_colorization,
    log_media_click,
    close_connection,
)

try:
    from firebase_admin import auth as firebase_auth
except ImportError:
    firebase_auth = None

# Module logger. The /colorize error path logs through it; it used to be
# referenced without being defined, so every failure became a NameError.
logger = logging.getLogger(__name__)

# -------------------------------------------------
# 🚀 FastAPI App
# -------------------------------------------------
app = FastAPI(title="Text-Guided Image Colorization API")

# -------------------------------------------------
# 🔐 Firebase Initialization (ENV-based)
# -------------------------------------------------
# Best effort: any failure is reported and the app keeps running without Firebase.
try:
    import firebase_admin
    from firebase_admin import credentials, app_check

    firebase_json = os.getenv("FIREBASE_CREDENTIALS")
    if firebase_json:
        print("🔥 Loading Firebase credentials from ENV...")
        firebase_dict = json.loads(firebase_json)
        cred = credentials.Certificate(firebase_dict)
        firebase_admin.initialize_app(cred)
    else:
        print("⚠️ No Firebase credentials found. Firebase disabled.")
except Exception as e:
    print("❌ Firebase initialization failed:", e)

# -------------------------------------------------
# 📁 Directories (FIXED FOR HUGGINGFACE SPACES)
# -------------------------------------------------
UPLOAD_DIR = "/tmp/uploads"
RESULTS_DIR = "/tmp/results"

os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

MEDIA_CLICK_DEFAULT_CATEGORY = os.getenv("DEFAULT_CATEGORY_FALLBACK", "69368fcd2e46bd68ae1889b2")

# Public base URL used to build absolute links in responses. The default is the
# original Space URL; set PUBLIC_BASE_URL to serve links from another host.
PUBLIC_BASE_URL = os.getenv(
    "PUBLIC_BASE_URL",
    "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space",
).rstrip("/")

# -------------------------------------------------
# 🧠 Load GAN Colorization Model
# -------------------------------------------------
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"

print("⬇️ Downloading model...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)

print("📦 Loading model weights...")
# NOTE(review): torch.load unpickles the downloaded checkpoint. If it is a plain
# state dict, consider weights_only=True so a tampered file cannot run code.
state_dict = torch.load(model_path, map_location="cpu")

# NOTE: Replace with real model architecture
# from model import ColorizeNet
# model = ColorizeNet()
# model.load_state_dict(state_dict)
# model.eval()


def colorize_image(img: Image.Image) -> Image.Image:
    """Return a placeholder "colorized" RGB copy of *img*.

    Dummy colorizer (replace with real model.predict): the image is converted
    to grayscale and that single channel is replicated into three, so the
    output is an RGB image that still looks gray.
    """
    gray = transforms.ToTensor()(img.convert("L")).unsqueeze(0)  # (1, 1, H, W)
    rgb = gray.repeat(1, 3, 1, 1)  # (1, 3, H, W)
    # squeeze(0) removes only the batch dim. A bare squeeze() would also drop
    # H or W when it equals 1 and hand ToPILImage a wrongly shaped tensor.
    return transforms.ToPILImage()(rgb.squeeze(0))


# -------------------------------------------------
# 🗄️ MongoDB Initialization
# -------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Initialize MongoDB on startup (failures are reported, not fatal)."""
    try:
        db = get_database()
        if db is not None:
            print("✅ MongoDB initialized successfully!")
    except Exception as e:
        print(f"⚠️ MongoDB initialization failed: {e}")


@app.on_event("shutdown")
async def shutdown_event():
    """Close the database connection on shutdown."""
    close_connection()
    print("Application shutdown")


# -------------------------------------------------
# Shared request helpers
# -------------------------------------------------
def _client_ip(request: Request) -> Optional[str]:
    """Return the client host of *request*, or None when it is unknown."""
    return request.client.host if request.client else None


def _is_image_upload(file: UploadFile) -> bool:
    """Return True if the upload declares an ``image/*`` content type.

    A missing content type is treated as "not an image" instead of raising
    AttributeError.
    """
    return (file.content_type or "").startswith("image/")


def _resolve_category_id(*candidates: Optional[str]) -> Optional[str]:
    """Return the first non-blank candidate, stripped, or None.

    Accepts both the ``category_id`` and ``categoryId`` form spellings. A
    whitespace-only first value no longer hides a valid second one.
    """
    for candidate in candidates:
        if isinstance(candidate, str) and candidate.strip():
            return candidate.strip()
    return None


# -------------------------------------------------
# 🩺 Health Check
# -------------------------------------------------
@app.get("/health")
def health_check(request: Request):
    """Report service liveness and log the call.

    NOTE(review): ``model_loaded`` is hard-coded to True. The colorizer above
    is still a placeholder that does not use the loaded weights.
    """
    response = {"status": "healthy", "model_loaded": True}

    # Log API call
    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=response,
        ip_address=_client_ip(request),
    )

    return response


# -------------------------------------------------
# 🔐 Firebase Token Validator
# -------------------------------------------------
def verify_app_check_token(token: Optional[str]):
    """Reject missing or obviously malformed App Check tokens with HTTP 401.

    Only a length sanity check (at least 20 characters) is performed. The
    token is NOT cryptographically verified.
    NOTE(review): ``firebase_admin.app_check`` is imported above but unused;
    consider ``app_check.verify_token`` when Firebase is initialized.

    Raises:
        HTTPException: 401 if the token is missing or shorter than 20 chars.
    """
    if not token or len(token) < 20:
        raise HTTPException(status_code=401, detail="Invalid Firebase App Check token")
    return True


def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]:
    """Return supplied user_id if provided and not empty, otherwise None (will auto-generate in log_media_click)."""
    if supplied_user_id and supplied_user_id.strip():
        return supplied_user_id.strip()
    return None


# -------------------------------------------------
# 📤 Upload Image
# -------------------------------------------------
@app.post("/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: Optional[str] = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
):
    """Store an uploaded image under UPLOAD_DIR and return its public URL.

    The raw bytes are written as-is under a fresh ``<uuid>.jpg`` name. The
    upload, API call and media click are logged to the database.

    Raises:
        HTTPException: 401 for a bad App Check token, 400 for a non-image upload.
    """
    verify_app_check_token(x_firebase_appcheck)

    ip_address = _client_ip(request)
    effective_user_id = _resolve_user_id(request, user_id)
    effective_category_id = _resolve_category_id(category_id, categoryId)

    if not _is_image_upload(file):
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=400,
            error="Invalid file type",
            ip_address=ip_address,
        )
        raise HTTPException(status_code=400, detail="Invalid file type")

    image_stem = str(uuid.uuid4())
    image_id = f"{image_stem}.jpg"
    file_path = os.path.join(UPLOAD_DIR, image_id)

    img_bytes = await file.read()
    file_size = len(img_bytes)

    with open(file_path, "wb") as f:
        f.write(img_bytes)

    response_data = {
        "success": True,
        "image_id": image_stem,
        "file_url": f"{PUBLIC_BASE_URL}/uploads/{image_id}",
    }

    # Log to MongoDB
    log_image_upload(
        image_id=image_stem,
        filename=file.filename or image_id,
        file_size=file_size,
        content_type=file.content_type or "image/jpeg",
        user_id=effective_user_id,
        ip_address=ip_address,
    )

    log_api_call(
        endpoint="/upload",
        method="POST",
        status_code=200,
        request_data={"filename": file.filename, "content_type": file.content_type},
        response_data=response_data,
        user_id=effective_user_id,
        ip_address=ip_address,
    )

    log_media_click(
        user_id=effective_user_id,
        category_id=effective_category_id,
        endpoint_path=str(request.url.path),
        default_category_id=MEDIA_CLICK_DEFAULT_CATEGORY,
    )

    return response_data


# -------------------------------------------------
# 🎨 Colorize Image
# -------------------------------------------------
@app.post("/colorize")
async def colorize(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: Optional[str] = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
):
    """Colorize an uploaded image, save it under RESULTS_DIR and return links.

    Success and failure are both recorded via ``log_colorization`` and
    ``log_api_call``.

    Raises:
        HTTPException: 401 for a bad App Check token, 400 for a non-image
            upload, 500 for any error while processing or logging.
    """
    start_time = time.perf_counter()
    verify_app_check_token(x_firebase_appcheck)

    ip_address = _client_ip(request)
    effective_user_id = _resolve_user_id(request, user_id)
    effective_category_id = _resolve_category_id(category_id, categoryId)

    if not _is_image_upload(file):
        error_msg = "Invalid file type"
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error=error_msg,
            ip_address=ip_address,
        )
        # Log failed colorization
        log_colorization(
            result_id=None,
            model_type="gan",
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg,
        )
        raise HTTPException(status_code=400, detail=error_msg)

    try:
        img = Image.open(io.BytesIO(await file.read()))
        output_img = colorize_image(img)
        processing_time = time.perf_counter() - start_time

        result_id_clean = str(uuid.uuid4())
        result_id = f"{result_id_clean}.jpg"
        output_path = os.path.join(RESULTS_DIR, result_id)
        output_img.save(output_path)

        response_data = {
            "success": True,
            "result_id": result_id_clean,
            "download_url": f"{PUBLIC_BASE_URL}/results/{result_id}",
            "api_download": f"{PUBLIC_BASE_URL}/download/{result_id_clean}",
        }

        # Log to MongoDB (colorization_db -> colorizations)
        log_colorization(
            result_id=result_id_clean,
            model_type="gan",
            processing_time=processing_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="success",
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type},
            response_data=response_data,
            user_id=effective_user_id,
            ip_address=ip_address,
        )

        log_media_click(
            user_id=effective_user_id,
            category_id=effective_category_id,
            endpoint_path=str(request.url.path),
            default_category_id=MEDIA_CLICK_DEFAULT_CATEGORY,
        )

        return response_data
    except Exception as e:
        error_msg = str(e)
        # logger.exception also records the traceback for debugging.
        logger.exception("Error colorizing image: %s", error_msg)

        # Log failed colorization to colorizations collection
        log_colorization(
            result_id=None,
            model_type="gan",
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg,
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address,
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}") from e
# -------------------------------------------------
# 📂 Shared file-serving helpers
# -------------------------------------------------
def _is_bare_filename(name: str) -> bool:
    """Return True if *name* is a plain file name with no directory parts.

    Rejects empty names, "." and "..", and anything containing a path
    separator, so a crafted path parameter cannot escape the served directory.
    """
    return bool(name) and name not in (".", "..") and os.path.basename(name) == name


def _serve_stored_jpeg(
    request: Request,
    directory: str,
    filename: str,
    endpoint: str,
    request_data: dict,
    not_found_detail: str,
) -> FileResponse:
    """Return ``directory/filename`` as ``image/jpeg``, logging the call.

    Only regular files are served. Unsafe names and directories are treated as
    missing, which yields a logged 404 instead of a 500 from FileResponse.

    Raises:
        HTTPException: 404 with *not_found_detail* when the file is not servable.
    """
    ip_address = request.client.host if request.client else None
    path = os.path.join(directory, filename)

    if not (_is_bare_filename(filename) and os.path.isfile(path)):
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=404,
            error=not_found_detail,
            ip_address=ip_address,
        )
        raise HTTPException(status_code=404, detail=not_found_detail)

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=200,
        request_data=request_data,
        ip_address=ip_address,
    )
    return FileResponse(path, media_type="image/jpeg")


# -------------------------------------------------
# ⬇️ Download via API (Secure)
# -------------------------------------------------
@app.get("/download/{file_id}")
def download_result(
    request: Request,
    file_id: str,
    x_firebase_appcheck: Optional[str] = Header(None),
):
    """Serve ``RESULTS_DIR/<file_id>.jpg`` after the App Check length check.

    Raises:
        HTTPException: 401 for a bad token, 404 if the result is not found.
    """
    verify_app_check_token(x_firebase_appcheck)
    return _serve_stored_jpeg(
        request,
        RESULTS_DIR,
        f"{file_id}.jpg",
        endpoint=f"/download/{file_id}",
        request_data={"file_id": file_id},
        not_found_detail="Result not found",
    )


# -------------------------------------------------
# 🌐 Public Result File
# -------------------------------------------------
@app.get("/results/{filename}")
def get_result(request: Request, filename: str):
    """Serve a colorized result from RESULTS_DIR without authentication.

    Raises:
        HTTPException: 404 if the result is not found.
    """
    return _serve_stored_jpeg(
        request,
        RESULTS_DIR,
        filename,
        endpoint=f"/results/{filename}",
        request_data={"filename": filename},
        not_found_detail="Result not found",
    )


# -------------------------------------------------
# 🌐 Public Uploaded File
# -------------------------------------------------
@app.get("/uploads/{filename}")
def get_upload(request: Request, filename: str):
    """Serve an uploaded file from UPLOAD_DIR without authentication.

    NOTE(review): /upload writes the raw bytes under a .jpg name regardless of
    the real format, so non-JPEG uploads are served with a mismatched media type.

    Raises:
        HTTPException: 404 if the file is not found.
    """
    return _serve_stored_jpeg(
        request,
        UPLOAD_DIR,
        filename,
        endpoint=f"/uploads/{filename}",
        request_data={"filename": filename},
        not_found_detail="File not found",
    )