|
|
""" |
|
|
FastAPI application for Text-Guided Image Colorization using Hugging Face Inference API |
|
|
Uses fal-ai provider for memory-efficient inference |
|
|
""" |
|
|
import os |
|
|
import io |
|
|
import uuid |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request, Body |
|
|
from fastapi.responses import FileResponse, JSONResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
import firebase_admin |
|
|
from firebase_admin import credentials, app_check, auth as firebase_auth |
|
|
from PIL import Image |
|
|
import uvicorn |
|
|
import gradio as gr |
|
|
import httpx |
|
|
from pydantic import BaseModel, EmailStr |
|
|
|
|
|
|
|
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
from app.config import settings |
|
|
from app.database import get_database, log_api_call, log_image_upload, log_colorization, close_connection |
|
|
|
|
|
|
|
|
# Configure root logging once at import time so every logger created via
# logging.getLogger() shares the same level and line format.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Make sure every scratch directory this service uses under /tmp exists before
# anything mounts, reads or writes it (idempotent thanks to exist_ok=True).
# NOTE(review): hf_cache and matplotlib_config are only created here; whether
# HF_HOME / MPLCONFIGDIR point at them is not visible in this file — confirm.
for _scratch_dir in (
    "/tmp/hf_cache",
    "/tmp/matplotlib_config",
    "/tmp/colorize_uploads",
    "/tmp/colorize_results",
):
    Path(_scratch_dir).mkdir(parents=True, exist_ok=True)
del _scratch_dir
|
|
|
|
|
|
|
|
# The FastAPI application object; API routes below attach to it, and the
# Gradio UI is mounted onto it at the end of the module.
app = FastAPI(
    version="1.0.0",
    title="Text-Guided Image Colorization API",
    description="Image colorization using SDXL + ControlNet with automatic captioning",
)

# Fully open CORS policy.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# combination the CORS spec does not permit for credentialed requests, and
# Starlette's docs advise explicit origins when credentials are enabled.
# Confirm the intended origin policy (an explicit allowlist is safer).
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
# Candidate locations for the Firebase Admin service-account JSON, tried in
# order. FIREBASE_CREDENTIALS_PATH comes first; unset (None) entries are
# skipped by the loop below.
# NOTE(review): a service-account key filename is hard-coded here — make sure
# the key file itself is never committed alongside the source.
firebase_cred_paths = [
    os.getenv("FIREBASE_CREDENTIALS_PATH"),
    "/tmp/firebase-adminsdk.json",
    "/data/firebase-adminsdk.json",
    "colorize-662df-firebase-adminsdk-fbsvc-bfd21c77c6.json",
    os.path.join(os.path.dirname(__file__), "..", "colorize-662df-firebase-adminsdk-fbsvc-bfd21c77c6.json"),
]

firebase_initialized = False

# 1) Initialize from the first existing credential file that loads cleanly.
for cred_path in firebase_cred_paths:
    if not cred_path:
        continue
    cred_path = os.path.abspath(cred_path)
    if os.path.exists(cred_path):
        try:
            cred = credentials.Certificate(cred_path)
            firebase_admin.initialize_app(cred)
            logger.info("Firebase Admin SDK initialized from: %s", cred_path)
            firebase_initialized = True
            break
        except Exception as e:
            logger.warning("Failed to initialize Firebase from %s: %s", cred_path, str(e))
            continue

# 2) Otherwise accept the service-account JSON inline via FIREBASE_CREDENTIALS.
if not firebase_initialized:
    firebase_json = os.getenv("FIREBASE_CREDENTIALS")
    if firebase_json:
        try:
            import json

            firebase_dict = json.loads(firebase_json)
            cred = credentials.Certificate(firebase_dict)
            firebase_admin.initialize_app(cred)
            logger.info("Firebase Admin SDK initialized from environment variable")
            firebase_initialized = True
        except Exception as e:
            logger.warning("Failed to initialize Firebase from environment: %s", str(e))

# 3) Last resort: initialize without an explicit credential.
# NOTE(review): if this call succeeds, firebase_admin._apps becomes non-empty,
# so code that tests ``firebase_admin._apps`` (e.g. verify_request) treats
# Firebase as initialized despite the warning below — confirm that is intended.
if not firebase_initialized:
    logger.warning("Firebase credentials file not found. App Check will be disabled.")
    try:
        firebase_admin.initialize_app()
    except Exception as e:
        # Best effort: log the failure instead of silently swallowing it, and
        # catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate.
        logger.warning("Firebase default initialization failed: %s", str(e))
|
|
|
|
|
|
|
|
# On-disk storage for user uploads and colorized outputs.
UPLOAD_DIR = Path("/tmp") / "colorize_uploads"
RESULT_DIR = Path("/tmp") / "colorize_results"

# Expose both directories as public static file trees.
# NOTE(review): this /results mount is registered before the
# ``/results/{filename}`` route declared later in this module, so the mount
# most likely answers those requests first — confirm which is intended.
# NOTE(review): /uploads serves every uploaded file without authentication.
app.mount("/results", StaticFiles(directory=f"{RESULT_DIR}"), name="results")
app.mount("/uploads", StaticFiles(directory=f"{UPLOAD_DIR}"), name="uploads")
|
|
|
|
|
|
|
|
# Shared Hugging Face Inference API client. The startup hook assigns it; it
# stays None when initialization fails.
inference_client: "Optional[InferenceClient]" = None

# Error text from the last failed client initialization (None when healthy).
model_load_error: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
def _to_lab(img: Image.Image) -> Image.Image:
    """Return ``img`` in LAB mode, going through RGB for any non-LAB input."""
    if img.mode == "LAB":
        return img
    # Normalising to RGB first gives grayscale ("L"), palette or RGBA inputs
    # the same well-defined RGB -> LAB conversion path.
    return img.convert("RGB").convert("LAB")


def apply_color(image: Image.Image, color_map: Image.Image) -> Image.Image:
    """Transfer the colour of ``color_map`` onto the lightness of ``image``.

    Both images are converted to LAB. The L (lightness) channel is taken from
    ``image`` and the A/B (colour) channels from ``color_map``.

    Args:
        image: Image whose lightness/detail is preserved.
        color_map: Image supplying the colour channels. It is resized to
            ``image.size`` when the sizes differ, because ``Image.merge``
            requires all bands to share one size.

    Returns:
        The merged image converted back to RGB, with the size of ``image``.
    """
    if color_map.size != image.size:
        color_map = color_map.resize(image.size)

    lightness, _, _ = _to_lab(image).split()
    _, a_map, b_map = _to_lab(color_map).split()

    merged_lab = Image.merge("LAB", (lightness, a_map, b_map))
    return merged_lab.convert("RGB")
|
|
|
|
|
|
|
|
def _build_unlikely_phrases() -> tuple[str, ...]:
    """Build the ordered phrases stripped by ``remove_unlikely_words``.

    Phrases are removed one after another with ``str.replace``, so order
    matters: a phrase must come before any shorter phrase it contains.
    Otherwise the shorter one is removed first and the longer one can never
    match. For example, stripping "1950" first would leave a dangling
    "year " behind and make "year 1950" unreachable.
    """
    years = [str(year) for year in range(1900, 2000)]
    spaced_years = [" ".join(year) for year in years]  # "1950" -> "1 9 5 0"
    prefixes = ("year", "circa")

    year_phrases = (
        # Prefixed forms first ("year 1950s", "circa 1950s", "year 1950", ...).
        [f"{prefix} {year}s" for prefix in prefixes for year in years]
        + [f"{prefix} {year}" for prefix in prefixes for year in years]
        # Then bare decades before bare years ("1950s" before "1950").
        + [f"{year}s" for year in years]
        + years
        # Same ordering for the digit-by-digit spelling ("1 9 5 0").
        + [f"{prefix} {spaced} s" for prefix in prefixes for spaced in spaced_years]
        + [f"{prefix} {spaced}" for prefix in prefixes for spaced in spaced_years]
        + [f"{spaced} s" for spaced in spaced_years]
        + spaced_years
    )

    # NOTE(review): several entries below look duplicated (", photo" and " ,");
    # they may originally have differed in the amount of whitespace — confirm.
    manual = [
        "black and white,", "black and white", "black & white,", "black & white", "circa",
        "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
        "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
        "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
        "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
        "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
        "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
        "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
        "black-and-white photo,", "black-and-white photo", "black - and - white photography",
        "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
        "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
        "black - and - white photograph,", "black - and - white photograph", "black on white,",
        "black on white", "black-and-white", "historical image,", "historical picture,",
        "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
        "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
        "historical photo", "historical setting,",
        "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
        "taken in", "shot on leica", "shot on leica sl2", "sl2",
        "taken with a leica camera", "leica sl2", "leica", "setting",
        "overcast day", "overcast weather", "slight overcast", "overcast",
        "picture taken in", "photo taken in",
        ", photo", ", photo", ", photo", ", photo", ", photograph",
        ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
    ]

    return tuple(year_phrases + manual)


# Built once at import time; the list does not depend on the prompt.
_UNLIKELY_PHRASES = _build_unlikely_phrases()


def remove_unlikely_words(prompt: str) -> str:
    """Strip "old photo" vocabulary (years, b&w, grainy, leica, ...) from a prompt.

    Every phrase in ``_UNLIKELY_PHRASES`` is deleted, in order, wherever it
    occurs as a substring. Matching is case-sensitive and does not respect
    word boundaries.

    Args:
        prompt: Caption/prompt text to clean.

    Returns:
        The prompt with all listed phrases removed. Surrounding whitespace is
        left as-is.
    """
    for phrase in _UNLIKELY_PHRASES:
        prompt = prompt.replace(phrase, "")
    return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Initialize MongoDB and the Hugging Face Inference API client.

    Failures are logged rather than raised, so the app still starts. A failed
    client initialization leaves ``inference_client`` unset and records the
    error text in ``model_load_error``, which /health reports.
    """
    global inference_client, model_load_error

    # MongoDB is optional: log and carry on if it is unavailable.
    try:
        db = get_database()
        # Compare with None explicitly. If get_database() returns a pymongo
        # Database (presumably — confirm in app.database), bool(db) raises
        # NotImplementedError, which would log a successful connection as a
        # failure.
        if db is not None:
            logger.info("✅ MongoDB initialized successfully!")
    except Exception as e:
        logger.warning("⚠️ MongoDB initialization failed: %s", str(e))

    try:
        logger.info("🔄 Initializing Hugging Face Inference API client...")

        # The environment variable takes precedence over the settings object.
        hf_token = os.getenv("HF_TOKEN") or settings.HF_TOKEN
        if not hf_token:
            raise ValueError("HF_TOKEN environment variable is required for Inference API")

        inference_client = InferenceClient(
            provider="fal-ai",
            api_key=hf_token,
        )

        logger.info("✅ Inference API client initialized successfully!")
        model_load_error = None
    except Exception as e:
        error_msg = str(e)
        logger.error("❌ Failed to initialize Inference API client: %s", error_msg)
        model_load_error = error_msg
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Tear down shared resources when the application stops.

    Drops the module-level inference client (when one is set), closes the
    database connection and logs the shutdown.
    """
    global inference_client
    # Clear the reference only when a client is present; a falsy value is
    # left untouched.
    inference_client = None if inference_client else inference_client
    close_connection()
    logger.info("Application shutdown")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RegisterRequest(BaseModel):
    # JSON body accepted by POST /auth/register.
    email: EmailStr  # validated by pydantic's EmailStr type
    password: str  # plain-text password, passed to firebase_auth.create_user
    display_name: Optional[str] = None  # optional user-facing name
|
|
|
|
|
class LoginRequest(BaseModel):
    # JSON body accepted by POST /auth/login.
    email: EmailStr  # validated by pydantic's EmailStr type
    password: str  # plain-text password, checked by the Firebase Auth REST API
|
|
|
|
|
class TokenResponse(BaseModel):
    # Response model of /auth/register and /auth/login.
    # NOTE(review): on some fallback code paths this field carries a Firebase
    # custom token (create_custom_token) rather than an ID token — confirm
    # clients handle that case.
    id_token: str
    refresh_token: Optional[str] = None  # set only when obtained from the Firebase REST sign-in
    expires_in: int  # token lifetime in seconds (3600 is used when not reported)
    token_type: str = "Bearer"
    user: dict  # uid, email, display_name, email_verified
|
|
|
|
|
|
|
|
|
|
|
def _extract_bearer_token(authorization_header: str | None) -> str | None: |
|
|
if not authorization_header: |
|
|
return None |
|
|
parts = authorization_header.split(" ", 1) |
|
|
if len(parts) == 2 and parts[0].lower() == "bearer": |
|
|
return parts[1].strip() |
|
|
return None |
|
|
|
|
|
|
|
|
async def verify_request(request: Request):
    """
    FastAPI dependency that authenticates a request with Firebase.

    Checks, in order:
    1. Firebase App Check token (X-Firebase-AppCheck header).
    2. Firebase Auth ID token (Authorization: Bearer header). On success the
       decoded claims are stored on ``request.state.user``.

    Returns True when the request is allowed. Verification is skipped
    entirely when no Firebase app is initialized or DISABLE_AUTH=true.

    Raises:
        HTTPException(401): only when ``settings.ENABLE_APP_CHECK`` is truthy
            and no token verifies. With ENABLE_APP_CHECK off, requests with
            missing or invalid tokens are still allowed (they just lack
            ``request.state.user``).
    """
    # Short-circuit: no Firebase app, or auth explicitly disabled via env.
    if not firebase_admin._apps or os.getenv("DISABLE_AUTH", "false").lower() == "true":
        return True

    # 1) App Check token.
    # NOTE(review): app_check.verify_token / firebase_auth.verify_id_token are
    # synchronous SDK calls made from this async dependency; if they fetch
    # signing keys over the network they block the event loop — confirm.
    app_check_token = request.headers.get("X-Firebase-AppCheck")
    if app_check_token:
        try:
            app_check_claims = app_check.verify_token(app_check_token)
            logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
            return True
        except Exception as e:
            logger.warning("App Check token verification failed: %s", str(e))
            # With enforcement on, an invalid App Check token is rejected
            # outright; the bearer token is not consulted.
            if settings.ENABLE_APP_CHECK:
                raise HTTPException(status_code=401, detail="Invalid App Check token")

    # 2) Firebase Auth ID token.
    bearer = _extract_bearer_token(request.headers.get("Authorization"))
    if bearer:
        try:
            decoded = firebase_auth.verify_id_token(bearer)
            request.state.user = decoded
            logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
            return True
        except Exception as e:
            logger.warning("Auth token verification failed: %s", str(e))

    # Nothing verified.
    if settings.ENABLE_APP_CHECK:
        # Also reached when a bearer token was sent but failed verification,
        # so the "Missing App Check token" detail can be misleading.
        if not app_check_token:
            raise HTTPException(status_code=401, detail="Missing App Check token")
        # NOTE(review): with an App Check token present and enforcement on,
        # the code above has already returned or raised, so this line appears
        # unreachable.
        raise HTTPException(status_code=401, detail="Invalid App Check token")

    # Enforcement off: allow the request even though nothing verified.
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _sign_in_after_registration(email: str, password: str, api_key: str) -> Optional[dict]:
    """Sign a freshly created user in via the Firebase Auth REST API.

    Returns the parsed ``signInWithPassword`` response (with ``idToken``,
    ``refreshToken`` and ``expiresIn``), or None if the call fails for any
    reason. Registration itself therefore never fails because of this step.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={api_key}",
                json={"email": email, "password": password, "returnSecureToken": True},
            )
        if response.status_code != 200:
            logger.warning("Post-registration sign-in failed with HTTP %s", response.status_code)
            return None
        data = response.json()
    except (httpx.HTTPError, ValueError) as e:
        logger.warning("Post-registration sign-in failed: %s", str(e))
        return None
    if not isinstance(data, dict) or "idToken" not in data:
        logger.warning("Post-registration sign-in returned no idToken")
        return None
    return data


@app.post("/auth/register", response_model=TokenResponse)
async def register_user(user_data: RegisterRequest):
    """
    Register a new user with email and password.

    When a Firebase Web API key is configured (FIREBASE_API_KEY env var or
    settings), the new user is signed in through the Firebase Auth REST API.
    The response then carries a real ID token plus refresh token. Without an
    API key, or if that sign-in fails, a Firebase custom token is returned
    instead; clients must exchange it for an ID token before use.
    """
    if not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    try:
        user_record = firebase_auth.create_user(
            email=user_data.email,
            password=user_data.password,
            display_name=user_data.display_name,
            email_verified=False
        )
        logger.info("User registered: %s (uid: %s)", user_data.email, user_record.uid)

        user_info = {
            "uid": user_record.uid,
            "email": user_record.email,
            "display_name": user_record.display_name,
            "email_verified": user_record.email_verified
        }

        # Prefer a real ID token. Firebase custom tokens are not ID tokens, so
        # firebase_auth.verify_id_token (used by verify_request) rejects them
        # until the client exchanges them with signInWithCustomToken.
        firebase_api_key = os.getenv("FIREBASE_API_KEY") or settings.FIREBASE_API_KEY
        session = None
        if firebase_api_key:
            session = await _sign_in_after_registration(
                user_data.email, user_data.password, firebase_api_key
            )
        if session is not None:
            return TokenResponse(
                id_token=session["idToken"],
                refresh_token=session.get("refreshToken"),
                expires_in=int(session.get("expiresIn", 3600)),
                token_type="Bearer",
                user=user_info,
            )

        # Fallback: no API key, or sign-in failed. Return a custom token.
        custom_token = firebase_auth.create_custom_token(user_record.uid)
        return TokenResponse(
            id_token=custom_token.decode('utf-8'),
            token_type="Bearer",
            expires_in=3600,
            user=user_info,
        )

    except firebase_auth.EmailAlreadyExistsError:
        raise HTTPException(status_code=400, detail="Email already registered")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid input: {str(e)}")
    except Exception as e:
        logger.error("Registration error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Registration failed: {str(e)}")
|
|
|
|
|
|
|
|
@app.post("/auth/login", response_model=TokenResponse)
async def login_user(credentials: LoginRequest):
    """
    Login with email and password.

    Uses the Firebase Auth REST API (signInWithPassword) to verify the
    password and obtain an ID token plus refresh token.

    Responses:
        503: Firebase is not initialized, or no Web API key is configured.
        401: Firebase rejects the credentials.
        500: The auth service is unreachable or returns an unexpected payload.
    """
    if not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    firebase_api_key = os.getenv("FIREBASE_API_KEY") or settings.FIREBASE_API_KEY
    if not firebase_api_key:
        # Security: never issue a token without verifying the password. The
        # Admin SDK cannot check passwords, so minting a custom token for a
        # looked-up e-mail would let anyone log in as any user.
        logger.error("Login rejected: Firebase API key not configured")
        raise HTTPException(status_code=503, detail="Firebase API key not configured")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={firebase_api_key}",
                json={
                    "email": credentials.email,
                    "password": credentials.password,
                    "returnSecureToken": True
                }
            )

        if response.status_code != 200:
            # The error body is normally JSON, but do not let a malformed body
            # turn a credential failure into a 500.
            try:
                error_msg = response.json().get("error", {}).get("message", "Authentication failed")
            except (ValueError, AttributeError):
                error_msg = "Authentication failed"
            raise HTTPException(status_code=401, detail=error_msg)

        data = response.json()
        logger.info("User login successful: %s", credentials.email)

        user_record = firebase_auth.get_user(data["localId"])

        return TokenResponse(
            id_token=data["idToken"],
            refresh_token=data.get("refreshToken"),
            expires_in=int(data.get("expiresIn", 3600)),
            token_type="Bearer",
            user={
                "uid": user_record.uid,
                "email": user_record.email,
                "display_name": user_record.display_name,
                "email_verified": user_record.email_verified
            }
        )

    except HTTPException:
        # Pass deliberate HTTP errors (e.g. the 401 above) through unchanged;
        # the generic handler below would otherwise convert them into a 500.
        raise
    except httpx.HTTPError as e:
        logger.error("HTTP error during login: %s", str(e))
        raise HTTPException(status_code=500, detail="Authentication service unavailable")
    except Exception as e:
        logger.error("Login error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Login failed: {str(e)}")
|
|
|
|
|
|
|
|
@app.get("/auth/me")
async def get_current_user(request: Request, verified: bool = Depends(verify_request)):
    """Return profile details for the caller identified by a Firebase ID token.

    Responds 503 when Firebase is not initialized, 401 when the request has no
    verified ID token (``request.state.user`` unset), and 404 when the user
    record cannot be fetched.
    """
    if not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    # verify_request only sets request.state.user after a successful
    # ID-token check; App Check alone does not identify a user.
    has_user = hasattr(request, 'state') and hasattr(request.state, 'user')
    if not has_user:
        raise HTTPException(status_code=401, detail="Not authenticated")

    uid = request.state.user.get("uid")
    try:
        record = firebase_auth.get_user(uid)
        return {
            "uid": record.uid,
            "email": record.email,
            "display_name": record.display_name,
            "email_verified": record.email_verified,
            "created_at": record.user_metadata.creation_timestamp,
        }
    except Exception as e:
        logger.error("Error getting user: %s", str(e))
        raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
|
|
|
@app.post("/auth/refresh")
async def refresh_token(refresh_token: str = Body(..., embed=True)):
    """Exchange a Firebase refresh token for a new ID token.

    Calls the Firebase Secure Token REST endpoint.

    Responses:
        503: No Firebase Web API key is configured.
        401: Firebase rejects the refresh token.
        500: The token service is unreachable or returns an unexpected payload.
    """
    firebase_api_key = os.getenv("FIREBASE_API_KEY") or settings.FIREBASE_API_KEY
    if not firebase_api_key:
        raise HTTPException(status_code=503, detail="Firebase API key not configured")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://securetoken.googleapis.com/v1/token?key={firebase_api_key}",
                # The endpoint is documented as taking a form-encoded body.
                data={
                    "grant_type": "refresh_token",
                    "refresh_token": refresh_token
                }
            )

        if response.status_code != 200:
            try:
                error_msg = response.json().get("error", {}).get("message", "Token refresh failed")
            except (ValueError, AttributeError):
                error_msg = "Token refresh failed"
            raise HTTPException(status_code=401, detail=error_msg)

        data = response.json()
        return {
            "id_token": data["id_token"],
            "refresh_token": data.get("refresh_token"),
            "expires_in": int(data.get("expires_in", 3600)),
            "token_type": "Bearer"
        }

    except HTTPException:
        # Keep the 401 above intact; the generic handler would make it a 500.
        raise
    except httpx.HTTPError as e:
        logger.error("HTTP error during token refresh: %s", str(e))
        raise HTTPException(status_code=500, detail="Token refresh service unavailable")
    except Exception as e:
        logger.error("Token refresh error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Token refresh failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api")
async def api_info(request: Request):
    """Describe the service and list its public endpoints; the call is logged."""
    auth_routes = {
        "register": "/auth/register",
        "login": "/auth/login",
        "me": "/auth/me",
        "refresh": "/auth/refresh"
    }
    endpoints = {
        "health": "/health",
        "upload": "/upload",
        "colorize": "/colorize",
        "download": "/download/{file_id}",
        "results": "/results/{filename}",
        "uploads": "/uploads/{filename}",
        "auth": auth_routes,
        "gradio": "/"
    }
    response_data = {
        "app": "Text-Guided Image Colorization API",
        "version": "1.0.0",
        "endpoints": endpoints
    }

    has_user = hasattr(request, 'state') and hasattr(request.state, 'user')
    user_id = request.state.user.get("uid") if has_user else None
    client_ip = request.client.host if request.client else None

    log_api_call(
        endpoint="/api",
        method="GET",
        status_code=200,
        response_data=response_data,
        user_id=user_id,
        ip_address=client_ip
    )
    return response_data
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check(request: Request):
    """Report liveness plus the inference client's state, and log the call.

    ``model_error`` is included only when client initialization recorded an
    error message.
    """
    payload = dict(
        status="healthy",
        model_loaded=inference_client is not None,
        model_type="hf_inference_api",
        provider="fal-ai",
    )
    if model_load_error:
        payload["model_error"] = model_load_error

    client_ip = request.client.host if request.client else None
    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=payload,
        ip_address=client_ip
    )
    return payload
|
|
|
|
|
|
|
|
def colorize_image_sdxl(
    image: Image.Image,
    positive_prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    seed: int = 123,
    num_inference_steps: int = 8
) -> Tuple[Image.Image, str]:
    """
    Colorize an image via the Hugging Face Inference API (``image_to_image``).

    The input is converted to RGB, resized to 512x512 and sent as PNG bytes to
    ``settings.INFERENCE_MODEL`` through the module-level ``inference_client``.
    The result is resized back to the input's original size.

    Args:
        image: PIL image to colorize.
        positive_prompt: Prompt text. A generic colorization prompt is used
            when it is empty or None.
        negative_prompt: Phrases to avoid. They are appended to the prompt as
            ``"<prompt>. Avoid: <negative_prompt>"``, not sent as a separate
            API parameter.
        seed: Accepted for API compatibility; currently NOT forwarded to the
            inference call.
        num_inference_steps: Accepted for API compatibility; currently NOT
            forwarded to the inference call.

    Returns:
        Tuple of (colorized PIL image, caption). The caption is the first 100
        characters of the final prompt.

    Raises:
        RuntimeError: If the client is not initialized, or if the API call or
            result decoding fails. The original error is chained as
            ``__cause__``.
    """
    if inference_client is None:
        raise RuntimeError("Inference API client not initialized")

    original_size = image.size

    # Sent at a fixed 512x512; the aspect ratio is restored only by the final
    # resize back to original_size.
    control_image = image.convert("RGB").resize((512, 512))
    buffer = io.BytesIO()
    control_image.save(buffer, format="PNG")
    input_image = buffer.getvalue()

    base_prompt = positive_prompt or "colorize this image with vibrant natural colors, high quality"
    if negative_prompt:
        final_prompt = f"{base_prompt}. Avoid: {negative_prompt}"
    else:
        final_prompt = base_prompt

    model_name = settings.INFERENCE_MODEL
    logger.info("Calling Inference API with model %s, prompt: %s", model_name, final_prompt)

    try:
        result_image = inference_client.image_to_image(
            input_image,
            prompt=final_prompt,
            model=model_name,
        )
        # huggingface_hub documents a PIL image return value; raw encoded
        # bytes are decoded defensively.
        if not isinstance(result_image, Image.Image):
            result_image = Image.open(io.BytesIO(result_image))
        colorized = result_image.resize(original_size)
    except Exception as e:
        logger.error("Inference API error: %s", e)
        raise RuntimeError(f"Failed to colorize image: {str(e)}") from e

    caption = final_prompt[:100]
    return colorized, caption
|
|
|
|
|
|
|
|
# Pillow format name -> file extension used when storing an upload. Formats
# not listed fall back to their lower-cased Pillow format name.
_UPLOAD_EXTENSIONS = {
    "JPEG": "jpg",
    "MPO": "jpg",
    "PNG": "png",
    "WEBP": "webp",
    "GIF": "gif",
    "BMP": "bmp",
    "TIFF": "tiff",
}


def _detect_image_extension(data: bytes) -> Optional[str]:
    """Return a safe extension for ``data`` if Pillow can parse it, else None.

    The extension comes from the decoded image format, never from the
    client-supplied filename. An upload therefore cannot be stored (and later
    served from /uploads) as e.g. ``.html`` or ``.svg``.
    """
    try:
        with Image.open(io.BytesIO(data)) as probe:
            probe.verify()
            image_format = probe.format
    except Exception:
        return None
    if not image_format:
        return None
    return _UPLOAD_EXTENSIONS.get(image_format, image_format.lower())


@app.post("/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    verified: bool = Depends(verify_request)
):
    """
    Upload an image and get the uploaded image URL.

    The bytes must decode as an image with Pillow. The stored file's extension
    is derived from the decoded format, not from the uploaded filename.
    Access is gated by ``verify_request`` (Firebase App Check / ID token).
    """
    user_id = None
    if hasattr(request, 'state') and hasattr(request.state, 'user'):
        user_id = request.state.user.get("uid")

    ip_address = request.client.host if request.client else None

    # Cheap pre-check on the client-declared MIME type.
    if not file.content_type or not file.content_type.startswith("image/"):
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=400,
            error="File must be an image",
            user_id=user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        img_bytes = await file.read()
        file_size = len(img_bytes)

        # The declared content type is client-controlled: confirm the payload
        # really is an image and take the extension from its actual format.
        file_extension = _detect_image_extension(img_bytes)
        if file_extension is None:
            log_api_call(
                endpoint="/upload",
                method="POST",
                status_code=400,
                error="Invalid image file",
                user_id=user_id,
                ip_address=ip_address
            )
            raise HTTPException(status_code=400, detail="Invalid image file")

        image_stem = str(uuid.uuid4())
        image_id = f"{image_stem}.{file_extension}"
        file_path = UPLOAD_DIR / image_id
        with open(file_path, "wb") as f:
            f.write(img_bytes)

        logger.info("Image uploaded: %s", image_id)

        # Public URL base; falls back to the hosted Space when unset/localhost.
        base_url = os.getenv("BASE_URL", settings.BASE_URL)
        if not base_url or base_url == "http://localhost:8000":
            base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"

        response_data = {
            "success": True,
            "image_id": image_stem,
            "image_url": f"{base_url}/uploads/{image_id}",
            "filename": image_id
        }

        log_image_upload(
            image_id=image_stem,
            filename=file.filename or image_id,
            file_size=file_size,
            content_type=file.content_type or "image/jpeg",
            user_id=user_id,
            ip_address=ip_address
        )

        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type},
            response_data=response_data,
            user_id=user_id,
            ip_address=ip_address
        )

        return JSONResponse(response_data)
    except HTTPException:
        # Keep deliberate client errors (400 above) instead of reporting a 500.
        raise
    except Exception as e:
        error_msg = str(e)
        logger.error("Error uploading image: %s", error_msg)
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error uploading image: {error_msg}")
|
|
|
|
|
|
|
|
@app.post("/colorize")
async def colorize_api(
    request: Request,
    file: UploadFile = File(...),
    positive_prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    seed: int = 123,
    num_inference_steps: int = 8,
    verified: bool = Depends(verify_request)
):
    """
    Upload an image and get back links to the colorized result.

    Colorization is delegated to ``colorize_image_sdxl`` (Hugging Face
    Inference API). That blocking remote call runs in a worker thread so the
    event loop keeps serving other requests meanwhile.

    Responses:
        503: The inference client is not initialized.
        400: The upload is not a decodable image.
        500: Colorization or saving fails.
    """
    import asyncio
    import time

    start_time = time.time()

    user_id = None
    if hasattr(request, 'state') and hasattr(request.state, 'user'):
        user_id = request.state.user.get("uid")

    ip_address = request.client.host if request.client else None

    if inference_client is None:
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=503,
            error="Inference API client not initialized",
            user_id=user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=503, detail="Inference API client not initialized")

    if not file.content_type or not file.content_type.startswith("image/"):
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error="File must be an image",
            user_id=user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        img_bytes = await file.read()

        # An undecodable upload is a client error, not a server failure.
        try:
            image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except Exception as exc:
            log_api_call(
                endpoint="/colorize",
                method="POST",
                status_code=400,
                error="Invalid image file",
                user_id=user_id,
                ip_address=ip_address
            )
            raise HTTPException(status_code=400, detail="Invalid image file") from exc

        logger.info("Colorizing image with SDXL + ControlNet...")
        # colorize_image_sdxl performs a synchronous network call; run it off
        # the event loop.
        colorized, caption = await asyncio.to_thread(
            colorize_image_sdxl,
            image,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            num_inference_steps=num_inference_steps
        )

        processing_time = time.time() - start_time

        output_filename = f"{uuid.uuid4()}.png"
        output_path = RESULT_DIR / output_filename
        colorized.save(output_path, "PNG")

        logger.info("Colorized image saved: %s", output_filename)

        # Public URL base; falls back to the hosted Space when unset/localhost.
        base_url = os.getenv("BASE_URL", settings.BASE_URL)
        if not base_url or base_url == "http://localhost:8000":
            base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"

        result_id = output_filename.replace(".png", "")

        response_data = {
            "success": True,
            "result_id": result_id,
            "download_url": f"{base_url}/results/{output_filename}",
            "api_download_url": f"{base_url}/download/{result_id}",
            "filename": output_filename,
            "caption": caption
        }

        log_colorization(
            result_id=result_id,
            prompt=positive_prompt,
            model_type="sdxl",
            processing_time=processing_time,
            user_id=user_id,
            ip_address=ip_address
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={
                "filename": file.filename,
                "positive_prompt": positive_prompt,
                "negative_prompt": negative_prompt,
                "seed": seed,
                "num_inference_steps": num_inference_steps
            },
            response_data=response_data,
            user_id=user_id,
            ip_address=ip_address
        )

        return JSONResponse(response_data)
    except HTTPException:
        # Keep deliberate client errors (400 above) instead of reporting a 500.
        raise
    except Exception as e:
        error_msg = str(e)
        logger.error("Error colorizing image: %s", error_msg)
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}")
|
|
|
|
|
|
|
|
@app.get("/download/{file_id}")
def download_result(
    request: Request,
    file_id: str,
    verified: bool = Depends(verify_request)
):
    """Serve ``<file_id>.png`` from the results directory after auth checks.

    Both outcomes are logged. Responds 404 when no such result exists.
    """
    ip_address = request.client.host if request.client else None
    has_user = hasattr(request, 'state') and hasattr(request.state, 'user')
    user_id = request.state.user.get("uid") if has_user else None

    endpoint = f"/download/{file_id}"
    path = RESULT_DIR / f"{file_id}.png"

    if path.exists():
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=200,
            request_data={"file_id": file_id},
            user_id=user_id,
            ip_address=ip_address
        )
        return FileResponse(path, media_type="image/png")

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=404,
        error="Result not found",
        user_id=user_id,
        ip_address=ip_address
    )
    raise HTTPException(status_code=404, detail="Result not found")
|
|
|
|
|
|
|
|
@app.get("/results/{filename}")
def get_result(request: Request, filename: str):
    """Public endpoint to access colorized images.

    Only regular files located directly inside the results directory are
    served. Anything else (``..``, ``.``, a directory, a missing file) yields
    404.
    """
    # NOTE(review): a StaticFiles app is mounted at "/results" at module level
    # before this route is declared. Routing generally follows registration
    # order, so the mount likely answers these requests — confirm this
    # handler is reachable.
    ip_address = request.client.host if request.client else None
    endpoint = f"/results/{filename}"

    # Resolve and require the file to sit directly in RESULT_DIR. This
    # rejects traversal names with a clean 404, instead of FileResponse
    # failing on a directory with a 500.
    results_root = RESULT_DIR.resolve()
    path = (RESULT_DIR / filename).resolve()
    if path.parent != results_root or not path.is_file():
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="Result not found")

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )
    return FileResponse(path, media_type="image/png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_colorize(image, positive_prompt=None, negative_prompt=None, seed=123):
    """Gradio callback: colorize ``image`` and return ``(image, caption)``.

    Returns ``(None, "")`` when no image was supplied. Returns
    ``(None, <message>)`` when the inference client is missing or
    colorization fails, so errors show in the caption box instead of raising
    inside Gradio.
    """
    if image is None:
        return None, ""
    try:
        if inference_client is not None:
            colorized, caption = colorize_image_sdxl(
                image,
                positive_prompt=positive_prompt,
                negative_prompt=negative_prompt,
                seed=seed
            )
            return colorized, caption
        return None, "Inference API client not initialized"
    except Exception as e:
        logger.error("Gradio colorization error: %s", str(e))
        return None, str(e)
|
|
|
|
|
|
|
|
# Gradio UI wrapped around gradio_colorize; mounted onto the FastAPI app at "/".
title = "🎨 Text-Guided Image Colorization"
description = "Upload a grayscale image and generate a color version using Hugging Face Inference API (fal-ai provider)."

iface = gr.Interface(
    fn=gradio_colorize,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Positive Prompt", placeholder="Enter details to enhance the caption"),
        gr.Textbox(label="Negative Prompt", value=settings.NEGATIVE_PROMPT),
        # NOTE(review): colorize_image_sdxl accepts the seed but does not
        # forward it to the inference call, so this slider currently has no
        # effect on the output.
        gr.Slider(minimum=0, maximum=1000, value=123, label="Seed"),
    ],
    outputs=[
        gr.Image(type="pil", label="Colorized Image"),
        gr.Textbox(label="Caption"),
    ],
    title=title,
    description=description,
)

# Mounted after all API routes above have been declared.
app = gr.mount_gradio_app(app, iface, path="/")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces, PORT defaulting to 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
|
|
|
|
|
|