|
|
import os |
|
|
import uuid |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from typing import Optional |
|
|
from datetime import datetime |
|
|
|
|
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException |
|
|
from fastapi.responses import JSONResponse, FileResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Gradio and the transformers_gradio registry are optional dependencies.
# When either import fails the service still starts and runs in mock mode.
try:
    import transformers_gradio
    import gradio as gr
except ImportError:
    GRADIO_AVAILABLE = False
    print("Warning: gradio/transformers_gradio not available. Using mock mode.")
else:
    GRADIO_AVAILABLE = True
|
|
|
|
|
# Application instance; every route and middleware below is registered on it.
app = FastAPI(
    title="Nano Banana Image Edit API",
    description="API for Qwen Image Edit model - Upload images and edit them with prompts",
    version="1.0.0",
)
|
|
|
|
|
|
|
|
# CORS: the API is open to any origin. Wildcard origins/methods/headers must
# not be combined with allow_credentials=True (the FastAPI CORS documentation
# requires explicit values in that case), and no endpoint in this file relies
# on cookies or other browser credentials, so credentials are disabled.
# NOTE(review): if a browser client does need credentialed requests, list its
# origins explicitly in allow_origins and re-enable allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# In-memory registry keyed by UUID string. It holds two kinds of records,
# distinguished by their "type" key: uploaded images ("image") and edit
# tasks ("edit_task"). Contents are lost when the process exits.
tasks = dict()

# Set by load_model(); both remain None until a model has been loaded.
gradio_demo = gradio_fn = None
|
|
|
|
|
def load_model():
    """Load the Qwen Image Edit model using Gradio.

    Stores the loaded demo in the module-level ``gradio_demo`` and the
    inference callable found on it (if any) in ``gradio_fn``.

    Returns:
        True only when a callable inference function was found. False when
        Gradio is unavailable, when loading raised, or when the loaded demo
        exposes no callable.
    """
    global gradio_demo, gradio_fn

    if not GRADIO_AVAILABLE:
        return False

    try:
        print("Loading Qwen/Qwen-Image-Edit model via Gradio...")
        gradio_demo = gr.load(name="Qwen/Qwen-Image-Edit", src=transformers_gradio.registry)

        # Prefer a callable exposed directly on the demo object.
        candidate = getattr(gradio_demo, 'fn', None)
        if not callable(candidate):
            candidate = None
            # Otherwise fall back to the first block that carries a callable.
            blocks = getattr(gradio_demo, 'blocks', None)
            if blocks:
                for block in blocks.values():
                    block_fn = getattr(block, 'fn', None)
                    if callable(block_fn):
                        candidate = block_fn
                        break

        gradio_fn = candidate
        if gradio_fn is None:
            # Do not claim success: without a callable the API stays in mock mode.
            print("Model loaded, but no callable inference function was found")
            return False

        print("Model loaded successfully")
        return True
    except Exception as e:
        # Loading is best-effort; any failure leaves the service in mock mode.
        print(f"Error loading model: {e}")
        return False
|
|
|
|
|
|
|
|
# Response body of POST /upload.
class UploadResponse(BaseModel):
    # UUID string under which the uploaded image was registered.
    image_id: str
    # Human-readable status message.
    message: str
    # ISO-8601 timestamp generated when the response is built.
    timestamp: str
|
|
|
|
|
# Edit request payload: an uploaded image's id plus the edit prompt.
# NOTE(review): no endpoint visible in this file references this model;
# POST /edit takes the same two values as form fields instead.
class EditRequest(BaseModel):
    # Identifier of a previously uploaded image.
    image_id: str
    # Text describing the desired edit.
    prompt: str
|
|
|
|
|
# Response body of POST /edit.
class EditResponse(BaseModel):
    # UUID string identifying the edit task.
    task_id: str
    # Task state reported to the client.
    status: str
    # Human-readable status message.
    message: str
    # ISO-8601 timestamp generated when the response is built.
    timestamp: str
|
|
|
|
|
# Response body of GET /result/{task_id}.
class ResultResponse(BaseModel):
    # UUID string identifying the edit task.
    task_id: str
    # One of "completed", "failed" or "processing" as set by the endpoint.
    status: str
    # Identifier of the produced image; only filled in for completed tasks.
    result_image_id: Optional[str] = None
    # Relative URL of the produced image; only filled in for completed tasks.
    result_image_url: Optional[str] = None
    # Failure description; only filled in for failed tasks.
    error: Optional[str] = None
    # Creation time of the task as an ISO-8601 string.
    timestamp: str
|
|
|
|
|
# Generic error payload.
# NOTE(review): no endpoint visible in this file references this model.
class ErrorResponse(BaseModel):
    # Short error description.
    error: str
    # Optional additional detail.
    detail: Optional[str] = None
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Initialize model on startup"""
    # Nothing to load when the optional Gradio stack is missing.
    if not GRADIO_AVAILABLE:
        return
    load_model()
|
|
|
|
|
@app.get("/")
async def root():
    """Root endpoint"""
    # Static index of the main routes exposed by this service.
    endpoint_index = {
        "upload": "/upload",
        "edit": "/edit",
        "result": "/result/{task_id}",
        "health": "/health",
    }
    return {
        "message": "Nano Banana Image Edit API",
        "version": "1.0.0",
        "endpoints": endpoint_index,
    }
|
|
|
|
|
@app.get("/health")
async def health():
    """Health check endpoint"""
    # The model counts as loaded only when Gradio imported and a callable was found.
    model_ready = GRADIO_AVAILABLE and gradio_fn is not None
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "model_available": GRADIO_AVAILABLE,
    }
|
|
|
|
|
@app.post("/upload", response_model=UploadResponse)
async def upload_image(file: UploadFile = File(...)):
    """
    Upload an image file

    The upload is decoded with Pillow, re-saved under ``uploads/`` and
    registered in the in-memory ``tasks`` registry.

    Returns:
        image_id: Unique identifier for the uploaded image

    Raises:
        HTTPException: 400 when the content type is not ``image/*`` or the
            bytes cannot be decoded as an image; 500 when reading or storing
            the file fails.
    """
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image_data = await file.read()

        try:
            # verify() leaves the Image object unusable afterwards, so it runs
            # on a throwaway handle and the data is reopened for real use.
            Image.open(BytesIO(image_data)).verify()
            image = Image.open(BytesIO(image_data))
        except Exception as e:
            # Undecodable or corrupt bytes are a client error, not a server
            # fault. Pillow raises several exception types here, hence the
            # broad catch at this narrow boundary.
            raise HTTPException(
                status_code=400,
                detail=f"Invalid image file: {str(e)}"
            ) from e

        image_id = str(uuid.uuid4())

        os.makedirs("uploads", exist_ok=True)
        image_path = f"uploads/{image_id}.{image.format.lower()}"
        image.save(image_path)

        tasks[image_id] = {
            "type": "image",
            "path": image_path,
            "format": image.format,
            "size": image.size,
            "uploaded_at": datetime.now().isoformat()
        }

        return UploadResponse(
            image_id=image_id,
            message="Image uploaded successfully",
            timestamp=datetime.now().isoformat()
        )
    except HTTPException:
        # Keep the 400 raised above from being rewritten as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
|
|
|
|
|
def _unwrap_model_output(result, source):
    """Return the PIL image contained in *result*, or a copy of *source* if there is none."""
    if isinstance(result, tuple):
        candidate = result[0]
    elif isinstance(result, dict):
        candidate = result.get('image', result.get('output', source))
    else:
        candidate = result

    if isinstance(candidate, Image.Image):
        return candidate
    return source.copy()


@app.post("/edit", response_model=EditResponse)
async def edit_image(
    image_id: str = Form(...),
    prompt: str = Form(...)
):
    """
    Edit an image using a text prompt

    When no model is loaded, or the model call fails or returns something
    that is not a PIL image, an unmodified copy of the source image is
    stored as the result.

    Parameters:
        image_id: ID of the uploaded image
        prompt: Text prompt describing the desired edit

    Returns:
        task_id: Unique identifier for the editing task

    Raises:
        HTTPException: 404 when ``image_id`` is not a registered upload;
            500 when the image cannot be opened or the result cannot be
            saved (the failure is also recorded under the task id).
    """
    if image_id not in tasks or tasks[image_id]["type"] != "image":
        raise HTTPException(status_code=404, detail="Image not found")

    task_id = str(uuid.uuid4())

    try:
        image_path = tasks[image_id]["path"]

        # The context manager closes the source file once the result is on
        # disk; saving happens inside it because the result may be the
        # source object itself or depend on its lazily loaded pixel data.
        with Image.open(image_path) as image:
            if GRADIO_AVAILABLE and gradio_fn is not None:
                try:
                    edited_image = _unwrap_model_output(gradio_fn(image, prompt), image)
                except Exception as e:
                    # Best-effort: a model failure falls back to the original image.
                    print(f"Error processing with model: {e}")
                    edited_image = image.copy()
            else:
                # Mock mode: no model available.
                edited_image = image.copy()

            os.makedirs("results", exist_ok=True)
            result_image_id = str(uuid.uuid4())
            result_path = f"results/{result_image_id}.png"
            edited_image.save(result_path)

        tasks[task_id] = {
            "type": "edit_task",
            "image_id": image_id,
            "prompt": prompt,
            "result_image_id": result_image_id,
            "result_path": result_path,
            "status": "completed",
            "created_at": datetime.now().isoformat()
        }

        return EditResponse(
            task_id=task_id,
            status="completed",
            message="Image edited successfully",
            timestamp=datetime.now().isoformat()
        )
    except Exception as e:
        # Record the failure so GET /result/{task_id} can report it.
        tasks[task_id] = {
            "type": "edit_task",
            "image_id": image_id,
            "prompt": prompt,
            "status": "failed",
            "error": str(e),
            "created_at": datetime.now().isoformat()
        }
        raise HTTPException(status_code=500, detail=f"Error editing image: {str(e)}")
|
|
|
|
|
@app.get("/result/{task_id}", response_model=ResultResponse)
async def get_result(task_id: str):
    """
    Get the result of an image editing task

    Parameters:
        task_id: ID of the editing task

    Returns:
        Result information including image URL
    """
    task = tasks.get(task_id)
    if task is None or task["type"] != "edit_task":
        raise HTTPException(status_code=404, detail="Task not found")

    state = task["status"]

    # Start from the "processing" shape and specialise it for terminal states.
    fields = {
        "task_id": task_id,
        "status": "processing",
        "timestamp": task["created_at"],
    }
    if state == "failed":
        fields["status"] = "failed"
        fields["error"] = task.get("error", "Unknown error")
    elif state == "completed":
        image_ref = task.get("result_image_id")
        fields["status"] = "completed"
        fields["result_image_id"] = image_ref
        fields["result_image_url"] = f"/result/image/{image_ref}"

    return ResultResponse(**fields)
|
|
|
|
|
def _find_result_task(result_image_id: str) -> dict:
    """Return the edit-task record that produced *result_image_id*.

    Raises:
        HTTPException: 404 when no edit task references the id, the record
            has no result path, or the file is no longer on disk.
    """
    record = next(
        (
            t for t in tasks.values()
            if t.get("type") == "edit_task" and t.get("result_image_id") == result_image_id
        ),
        None,
    )

    if not record or "result_path" not in record:
        raise HTTPException(status_code=404, detail="Result image not found")

    if not os.path.exists(record["result_path"]):
        raise HTTPException(status_code=404, detail="Image file not found")

    return record


@app.get("/result/image/{result_image_id}")
async def get_result_image(result_image_id: str):
    """
    Get the edited image file

    Parameters:
        result_image_id: ID of the result image

    Returns:
        Image file

    Raises:
        HTTPException: 404 when the result image is unknown or its file is missing.
    """
    task = _find_result_task(result_image_id)

    return FileResponse(
        task["result_path"],
        media_type="image/png",
        filename=f"edited_{result_image_id}.png"
    )


@app.get("/result/image/{result_image_id}/base64")
async def get_result_image_base64(result_image_id: str):
    """
    Get the edited image as base64 encoded string

    Parameters:
        result_image_id: ID of the result image

    Returns:
        JSON with base64 encoded image

    Raises:
        HTTPException: 404 when the result image is unknown or its file is missing.
    """
    task = _find_result_task(result_image_id)

    with open(task["result_path"], "rb") as f:
        image_data = f.read()

    base64_data = base64.b64encode(image_data).decode("utf-8")

    return {
        "result_image_id": result_image_id,
        "image_base64": base64_data,
        "format": "png"
    }
|
|
|
|
|
# Script entry point: serve the app on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
|
|