# Author: LogicGoInfotechSpaces
# Commit 2ebd872: Add success/failure logging to colorizations collection
from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Request, Form
from typing import Optional
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download
import uuid
import os
import io
import json
from PIL import Image
import torch
from torchvision import transforms
from app.database import (
get_database,
log_api_call,
log_image_upload,
log_colorization,
log_media_click,
close_connection,
)
# Optional dependency: ``firebase_auth`` is left as None when the
# firebase_admin package cannot be imported.
firebase_auth = None
try:
    from firebase_admin import auth as firebase_auth
except ImportError:
    pass
# -------------------------------------------------
# 🚀 FastAPI App
# -------------------------------------------------
# Single application instance; every route and startup/shutdown hook below is
# registered on it via ``@app.<...>`` decorators.
app = FastAPI(title="Text-Guided Image Colorization API")
# -------------------------------------------------
# 🔐 Firebase Initialization (ENV-based)
# -------------------------------------------------
# Firebase is optional. Any exception raised below (e.g. the package is not
# installed, FIREBASE_CREDENTIALS is not valid JSON, or the credentials are
# rejected) is printed and swallowed so the API still starts without it.
try:
    import firebase_admin
    from firebase_admin import credentials, app_check

    firebase_json = os.getenv("FIREBASE_CREDENTIALS")
    if not firebase_json:
        print("⚠️ No Firebase credentials found. Firebase disabled.")
    else:
        print("🔥 Loading Firebase credentials from ENV...")
        # The env var holds the service-account JSON document itself, not a path.
        firebase_dict = json.loads(firebase_json)
        cred = credentials.Certificate(firebase_dict)
        firebase_admin.initialize_app(cred)
except Exception as e:
    print("❌ Firebase initialization failed:", e)
# -------------------------------------------------
# 📁 Directories (FIXED FOR HUGGINGFACE SPACES)
# -------------------------------------------------
# Scratch directories for uploaded originals and colorized results; files in
# them are served back by the /uploads, /results and /download routes. The
# /tmp defaults suit Hugging Face Spaces; set UPLOAD_DIR / RESULTS_DIR in the
# environment to relocate them (e.g. for local development).
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "/tmp/uploads")
RESULTS_DIR = os.getenv("RESULTS_DIR", "/tmp/results")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)
# Passed to log_media_click() as ``default_category_id`` (presumably the
# category recorded when a request supplies none — confirm in app.database).
MEDIA_CLICK_DEFAULT_CATEGORY = os.getenv("DEFAULT_CATEGORY_FALLBACK", "69368fcd2e46bd68ae1889b2")
# -------------------------------------------------
# 🧠 Load GAN Colorization Model
# -------------------------------------------------
# Hugging Face Hub location of the pretrained generator checkpoint.
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"
# Runs at import time: fetches the checkpoint via hf_hub_download (which
# caches it locally, so later starts may not need the network).
print("⬇️ Downloading model...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
print("📦 Loading model weights...")
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code embedded in it. If generator.pt is a plain state_dict,
# consider torch.load(..., weights_only=True) — confirm the checkpoint format
# before switching. map_location="cpu" keeps tensors on the CPU.
state_dict = torch.load(model_path, map_location="cpu")
# NOTE: Replace with real model architecture
# `state_dict` is currently unused: colorize_image() below is a placeholder
# that never reads these weights. Wiring them up would look roughly like:
# from model import ColorizeNet
# model = ColorizeNet()
# model.load_state_dict(state_dict)
# model.eval()
def colorize_image(img: Image.Image):
    """Placeholder colorizer: return *img* as a 3-channel grayscale image.

    The GAN weights loaded at module level are not used yet. The input is
    converted to single-channel luminance ("L") and that channel is copied
    into all three output channels, so callers get an image of the same
    width and height that can be saved as RGB.

    Args:
        img: Any PIL image.

    Returns:
        A new PIL image built from a (3, H, W) tensor.
    """
    gray = transforms.ToTensor()(img.convert("L"))  # shape (1, H, W)
    # Replicate straight to (3, H, W). The former unsqueeze/repeat/squeeze()
    # round-trip removed *every* size-1 axis, so images one pixel tall or
    # wide came back with the wrong shape and 1x1 images made ToPILImage
    # raise.
    rgb = gray.repeat(3, 1, 1)
    return transforms.ToPILImage()(rgb)
# -------------------------------------------------
# 🗄️ MongoDB Initialization
# -------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Call get_database() once at startup to initialize MongoDB.

    Best-effort: any exception is printed rather than raised, so the API
    still starts when the database is unavailable. Nothing is printed when
    get_database() returns None.
    """
    try:
        if get_database() is not None:
            print("✅ MongoDB initialized successfully!")
    except Exception as e:
        print(f"⚠️ MongoDB initialization failed: {e}")
@app.on_event("shutdown")
async def shutdown_event():
    """On application shutdown, call app.database.close_connection().

    The shutdown message is printed only after close_connection() returns.
    """
    close_connection()
    print("Application shutdown")
# -------------------------------------------------
# 🩺 Health Check
# -------------------------------------------------
@app.get("/health")
def health_check(request: Request):
    """Liveness probe: record the call with log_api_call and report healthy.

    NOTE(review): ``model_loaded`` is a constant True; it does not reflect
    whether any model weights are actually in use.
    """
    client = request.client
    payload = {"status": "healthy", "model_loaded": True}
    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=payload,
        ip_address=client.host if client else None,
    )
    return payload
# -------------------------------------------------
# 🔐 Firebase Token Validator
# -------------------------------------------------
def verify_app_check_token(token: str):
    """Reject a missing or implausibly short Firebase App Check token.

    NOTE(review): this only checks that the token is present and at least 20
    characters long. It does NOT verify the token with Firebase
    (``firebase_admin.app_check`` is imported at module level but never
    called), so treat it as a placeholder rather than real authentication.

    Returns:
        True when the token passes the length check.

    Raises:
        HTTPException: 401 when the token is empty/None or shorter than 20.
    """
    plausible = bool(token) and len(token) >= 20
    if not plausible:
        raise HTTPException(status_code=401, detail="Invalid Firebase App Check token")
    return True
def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]:
    """Normalize a caller-supplied user id.

    Returns the id with surrounding whitespace removed, or None when it is
    missing or blank (log_media_click is then expected to generate one).
    ``request`` is accepted for interface symmetry but not read.
    """
    cleaned = supplied_user_id.strip() if supplied_user_id else ""
    return cleaned or None
# -------------------------------------------------
# 📤 Upload Image
# -------------------------------------------------
@app.post("/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
):
    """Save an uploaded image under UPLOAD_DIR and return its public URL.

    The bytes are written as ``<uuid>.jpg`` and the call is recorded with
    log_image_upload, log_api_call and log_media_click.

    Args:
        request: Incoming request; client IP and URL path are logged.
        file: Multipart upload whose declared content type must start with
            ``image/``.
        x_firebase_appcheck: App Check token header, checked by
            verify_app_check_token.
        user_id: Optional caller id; blank values are treated as missing.
        category_id, categoryId: Optional category under either field name;
            the first value that is non-blank after stripping is used.

    Returns:
        ``{"success": True, "image_id": <uuid>, "file_url": <url>}``.

    Raises:
        HTTPException: 401 from the token check; 400 when the upload is not
            declared as an image.
    """
    verify_app_check_token(x_firebase_appcheck)
    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)
    # Skip values that are blank after stripping, so a whitespace-only
    # ``category_id`` no longer hides a real ``categoryId``.
    effective_category_id = next(
        (value.strip() for value in (category_id, categoryId) if value and value.strip()),
        None,
    )

    # UploadFile.content_type can be None (part sent without a Content-Type);
    # treat that as "not an image" instead of failing with AttributeError.
    if not (file.content_type or "").startswith("image/"):
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=400,
            error="Invalid file type",
            user_id=effective_user_id,
            ip_address=ip_address,
        )
        raise HTTPException(status_code=400, detail="Invalid file type")

    image_id = f"{uuid.uuid4()}.jpg"
    image_id_clean = image_id.replace(".jpg", "")
    file_path = os.path.join(UPLOAD_DIR, image_id)
    img_bytes = await file.read()
    file_size = len(img_bytes)
    with open(file_path, "wb") as f:
        f.write(img_bytes)

    base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
    response_data = {
        "success": True,
        "image_id": image_id_clean,
        "file_url": f"{base_url}/uploads/{image_id}",
    }

    # Log to MongoDB
    log_image_upload(
        image_id=image_id_clean,
        filename=file.filename or image_id,
        file_size=file_size,
        content_type=file.content_type or "image/jpeg",
        user_id=effective_user_id,
        ip_address=ip_address,
    )
    log_api_call(
        endpoint="/upload",
        method="POST",
        status_code=200,
        request_data={"filename": file.filename, "content_type": file.content_type},
        response_data=response_data,
        user_id=effective_user_id,
        ip_address=ip_address,
    )
    log_media_click(
        user_id=effective_user_id,
        category_id=effective_category_id,
        endpoint_path=str(request.url.path),
        default_category_id=MEDIA_CLICK_DEFAULT_CATEGORY,
    )
    return response_data
# -------------------------------------------------
# 🎨 Colorize Image
# -------------------------------------------------
@app.post("/colorize")
async def colorize(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
):
    """Colorize an uploaded image, save it under RESULTS_DIR, return URLs.

    Every outcome is recorded in the colorizations collection via
    log_colorization (status "success" or "failed") and in log_api_call.
    log_media_click is called on success only.

    Args:
        request: Incoming request; client IP and URL path are logged.
        file: Multipart upload whose declared content type must start with
            ``image/``.
        x_firebase_appcheck: App Check token header, checked by
            verify_app_check_token.
        user_id: Optional caller id; blank values are treated as missing.
        category_id, categoryId: Optional category under either field name;
            the first value that is non-blank after stripping is used.

    Returns:
        ``{"success": True, "result_id", "download_url", "api_download"}``.

    Raises:
        HTTPException: 401 from the token check; 400 when the upload is not
            declared as an image; 500 when decoding, colorizing or saving
            the image fails.
    """
    import logging
    import time

    start_time = time.time()
    verify_app_check_token(x_firebase_appcheck)
    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)
    # Skip values that are blank after stripping, so a whitespace-only
    # ``category_id`` no longer hides a real ``categoryId``.
    effective_category_id = next(
        (value.strip() for value in (category_id, categoryId) if value and value.strip()),
        None,
    )

    # UploadFile.content_type can be None; treat that as "not an image".
    if not (file.content_type or "").startswith("image/"):
        error_msg = "Invalid file type"
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address,
        )
        # Log failed colorization
        log_colorization(
            result_id=None,
            model_type="gan",
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg,
        )
        raise HTTPException(status_code=400, detail=error_msg)

    # Guard only the image work, so a failure while *logging* a success can't
    # record the same request as both "success" and "failed".
    try:
        img = Image.open(io.BytesIO(await file.read()))
        output_img = colorize_image(img)
        processing_time = time.time() - start_time
        result_id = f"{uuid.uuid4()}.jpg"
        output_path = os.path.join(RESULTS_DIR, result_id)
        output_img.save(output_path)
    except Exception as e:
        error_msg = str(e)
        # No module-level ``logger`` exists; referencing one raised NameError
        # here, which skipped the failure logging and the intended 500 below.
        logging.getLogger(__name__).exception("Error colorizing image: %s", error_msg)
        # Log failed colorization to colorizations collection
        log_colorization(
            result_id=None,
            model_type="gan",
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg,
        )
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address,
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}") from e

    base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
    result_id_clean = result_id.replace(".jpg", "")
    response_data = {
        "success": True,
        "result_id": result_id_clean,
        "download_url": f"{base_url}/results/{result_id}",
        "api_download": f"{base_url}/download/{result_id_clean}",
    }
    # Log to MongoDB (colorization_db -> colorizations)
    log_colorization(
        result_id=result_id_clean,
        model_type="gan",
        processing_time=processing_time,
        user_id=effective_user_id,
        ip_address=ip_address,
        status="success",
    )
    log_api_call(
        endpoint="/colorize",
        method="POST",
        status_code=200,
        request_data={"filename": file.filename, "content_type": file.content_type},
        response_data=response_data,
        user_id=effective_user_id,
        ip_address=ip_address,
    )
    log_media_click(
        user_id=effective_user_id,
        category_id=effective_category_id,
        endpoint_path=str(request.url.path),
        default_category_id=MEDIA_CLICK_DEFAULT_CATEGORY,
    )
    return response_data
# -------------------------------------------------
# ⬇️ Download via API (Secure)
# -------------------------------------------------
@app.get("/download/{file_id}")
def download_result(
    request: Request,
    file_id: str,
    x_firebase_appcheck: str = Header(None)
):
    """Return ``RESULTS_DIR/<file_id>.jpg`` as image/jpeg after a token check.

    Raises:
        HTTPException: 401 from verify_app_check_token; 404 when no such
            result file exists.
    """
    verify_app_check_token(x_firebase_appcheck)
    client = request.client
    ip_address = client.host if client else None
    endpoint = f"/download/{file_id}"
    path = os.path.join(RESULTS_DIR, f"{file_id}.jpg")

    if os.path.exists(path):
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=200,
            request_data={"file_id": file_id},
            ip_address=ip_address,
        )
        return FileResponse(path, media_type="image/jpeg")

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=404,
        error="Result not found",
        ip_address=ip_address,
    )
    raise HTTPException(status_code=404, detail="Result not found")
# -------------------------------------------------
# 🌐 Public Result File
# -------------------------------------------------
@app.get("/results/{filename}")
def get_result(request: Request, filename: str):
    """Serve a colorized result from RESULTS_DIR by file name (no auth).

    Args:
        request: Incoming request; the client IP is logged.
        filename: Name of a file directly inside RESULTS_DIR.

    Returns:
        The file as an image/jpeg FileResponse.

    Raises:
        HTTPException: 404 unless ``filename`` names a regular file located
            directly in RESULTS_DIR.
    """
    ip_address = request.client.host if request.client else None
    results_root = os.path.abspath(RESULTS_DIR)
    path = os.path.abspath(os.path.join(results_root, filename))
    # Names such as "." or ".." resolve to directories: os.path.exists()
    # accepted them and FileResponse then failed with a 500. Serve only
    # regular files that sit directly inside RESULTS_DIR.
    if os.path.dirname(path) != results_root or not os.path.isfile(path):
        log_api_call(
            endpoint=f"/results/{filename}",
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=ip_address,
        )
        raise HTTPException(status_code=404, detail="Result not found")
    log_api_call(
        endpoint=f"/results/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address,
    )
    return FileResponse(path, media_type="image/jpeg")
# -------------------------------------------------
# 🌐 Public Uploaded File
# -------------------------------------------------
@app.get("/uploads/{filename}")
def get_upload(request: Request, filename: str):
    """Serve an uploaded original from UPLOAD_DIR by file name (no auth).

    NOTE(review): /upload stores every upload with a ``.jpg`` name whatever
    its real format, so the image/jpeg media type may not match the bytes
    (e.g. PNG uploads).

    Args:
        request: Incoming request; the client IP is logged.
        filename: Name of a file directly inside UPLOAD_DIR.

    Returns:
        The file as an image/jpeg FileResponse.

    Raises:
        HTTPException: 404 unless ``filename`` names a regular file located
            directly in UPLOAD_DIR.
    """
    ip_address = request.client.host if request.client else None
    upload_root = os.path.abspath(UPLOAD_DIR)
    path = os.path.abspath(os.path.join(upload_root, filename))
    # Names such as "." or ".." resolve to directories: os.path.exists()
    # accepted them and FileResponse then failed with a 500. Serve only
    # regular files that sit directly inside UPLOAD_DIR.
    if os.path.dirname(path) != upload_root or not os.path.isfile(path):
        log_api_call(
            endpoint=f"/uploads/{filename}",
            method="GET",
            status_code=404,
            error="File not found",
            ip_address=ip_address,
        )
        raise HTTPException(status_code=404, detail="File not found")
    log_api_call(
        endpoint=f"/uploads/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address,
    )
    return FileResponse(path, media_type="image/jpeg")