|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
from typing import Optional |
|
|
|
|
|
import requests |
|
|
|
|
|
|
|
|
def log(msg: str) -> None:
    """Emit one progress/diagnostic line for the user.

    The message goes to whatever ``sys.stdout`` is at call time, followed by a newline.
    """
    print(msg, file=sys.stdout, end="\n")
|
|
|
|
|
|
|
|
def upload_image(base_url: str, image_path: str, headers: dict) -> Optional[str]:
    """Upload a local image file to ``<base_url>/upload`` as multipart form data.

    The MIME type is guessed from the file name. Anything that is not ``image/*``
    (or cannot be guessed) is sent as ``image/jpeg``.

    Args:
        base_url: API root without a trailing slash.
        image_path: Path of the image file to upload.
        headers: Extra HTTP headers (e.g. ``Authorization``) sent with the request.

    Returns:
        The ``image_id`` from the JSON response, or None (after logging why) when
        the server answers non-200, the body is not a JSON object, or the
        ``image_id`` is missing/empty.

    Raises:
        OSError: if ``image_path`` cannot be opened.
        requests.RequestException: on transport errors (connection, timeout, ...).
    """
    import mimetypes

    log("Uploading image: %s" % image_path)

    mime, _ = mimetypes.guess_type(image_path)
    if not mime or not mime.startswith("image/"):
        mime = "image/jpeg"
    filename = os.path.basename(image_path) or "image.jpg"

    # The file handle must stay open while requests streams the multipart body.
    with open(image_path, "rb") as f:
        files = {"file": (filename, f, mime)}
        r = requests.post("%s/upload" % base_url, files=files, headers=headers, timeout=120)

    if r.status_code != 200:
        log("Upload failed: %s %s" % (r.status_code, r.text))
        return None

    try:
        payload = r.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        log("Upload failed: non-JSON response: %s" % r.text)
        return None

    image_id = payload.get("image_id") if isinstance(payload, dict) else None
    if not image_id:
        log("Upload failed: no image_id in response: %s" % r.text)
        return None

    log("Uploaded. image_id=%s" % image_id)
    return image_id
|
|
|
|
|
|
|
|
def edit_image(base_url: str, image_id: str, prompt: str, headers: dict) -> Optional[str]:
    """Submit an edit job for a previously uploaded image.

    POSTs form fields ``image_id`` and ``prompt`` to ``<base_url>/edit``.

    Args:
        base_url: API root without a trailing slash.
        image_id: Identifier returned by the upload endpoint.
        prompt: Free-text edit instruction.
        headers: Extra HTTP headers sent with the request.

    Returns:
        The ``task_id`` from the JSON response, or None (after logging why) when
        the server answers non-200, the body is not a JSON object, or
        ``task_id`` is missing/empty.

    Raises:
        requests.RequestException: on transport errors (connection, timeout, ...).
    """
    log("Editing image with prompt: %s" % prompt)

    r = requests.post(
        "%s/edit" % base_url,
        data={"image_id": image_id, "prompt": prompt},
        headers=headers,
        timeout=300,
    )
    if r.status_code != 200:
        log("Edit failed: %s %s" % (r.status_code, r.text))
        return None

    # Parse the body once; a non-JSON body previously raised out of this function.
    try:
        payload = r.json()
    except ValueError:
        log("Edit failed: non-JSON response: %s" % r.text)
        return None
    if not isinstance(payload, dict):
        log("Edit failed: unexpected response: %s" % r.text)
        return None

    task_id = payload.get("task_id")
    log("Edit submitted. task_id=%s status=%s" % (task_id, payload.get("status")))
    if not task_id:
        log("Edit failed: no task_id in response")
        return None
    return task_id
|
|
|
|
|
|
|
|
def get_result(base_url: str, task_id: str, headers: dict) -> Optional[dict]:
    """Fetch the current state of an edit task from ``<base_url>/result/<task_id>``.

    Args:
        base_url: API root without a trailing slash.
        task_id: Identifier returned by the edit endpoint.
        headers: Extra HTTP headers sent with the request.

    Returns:
        The decoded JSON object, or None (after logging why) when the server
        answers non-200 or the body is not a JSON object.

    Raises:
        requests.RequestException: on transport errors (connection, timeout, ...).
    """
    r = requests.get("%s/result/%s" % (base_url, task_id), headers=headers, timeout=120)
    if r.status_code != 200:
        log("Result failed: %s %s" % (r.status_code, r.text))
        return None

    try:
        payload = r.json()
    except ValueError:
        log("Result failed: non-JSON response: %s" % r.text)
        return None
    # Enforce the Optional[dict] contract so callers can safely call .get().
    if not isinstance(payload, dict):
        log("Result failed: unexpected response: %s" % r.text)
        return None
    return payload
|
|
|
|
|
|
|
|
def download_image(base_url: str, result_image_id: str, out_path: str, headers: dict) -> bool:
    """Retrieve an edited image and save it locally.

    Sends a GET to ``<base_url>/result/image/<result_image_id>``. On HTTP 200 the
    raw response body is written to ``out_path`` in binary mode.

    Returns:
        True once the file has been written; False (after logging the status and
        body) for any non-200 response.
    """
    url = "%s/result/image/%s" % (base_url, result_image_id)
    resp = requests.get(url, headers=headers, timeout=300)

    if resp.status_code == 200:
        with open(out_path, "wb") as fh:
            fh.write(resp.content)
        log("Saved: %s" % out_path)
        return True

    log("Download failed: %s %s" % (resp.status_code, resp.text))
    return False
|
|
|
|
|
|
|
|
def health(base_url: str, headers: dict) -> bool:
    """Probe ``<base_url>/health`` and report whether the API is reachable.

    Args:
        base_url: API root without a trailing slash.
        headers: Extra HTTP headers sent with the request.

    Returns:
        True when the endpoint answers HTTP 200. False (after logging the reason)
        on any other status, a connection failure, a timeout, or any other
        ``requests`` error.
    """
    try:
        r = requests.get("%s/health" % base_url, headers=headers, timeout=30)
    except requests.ConnectionError:
        log("Cannot connect to %s" % base_url)
        return False
    except requests.RequestException as exc:
        # e.g. ReadTimeout: not a ConnectionError, so it used to escape as a traceback.
        log("Health request failed: %s" % exc)
        return False

    if r.status_code != 200:
        log("Health failed: %s %s" % (r.status_code, r.text))
        return False

    try:
        j = r.json()
    except ValueError:
        j = None
    if isinstance(j, dict):
        log("Health: %s (model_loaded=%s)" % (j.get("status"), j.get("model_loaded")))
    else:
        # The server answered 200, so it is up even if the body is not the expected JSON object.
        log("Health: HTTP 200 (unexpected body: %s)" % r.text)
    return True
|
|
|
|
|
|
|
|
def _wait_for_result(
    base_url: str,
    task_id: str,
    headers: dict,
    poll_interval: float,
    poll_timeout: float,
) -> Optional[dict]:
    """Poll get_result until the task finishes, fails, or ``poll_timeout`` seconds elapse.

    Waits ``poll_interval`` seconds before every request. Returns the last result
    seen: None if the request failed, otherwise the JSON object. The caller decides
    whether its ``status`` means success.
    """
    deadline = time.monotonic() + max(0.0, poll_timeout)
    while True:
        time.sleep(max(0.0, poll_interval))
        result = get_result(base_url, task_id, headers)
        if not isinstance(result, dict):
            return result  # get_result already logged the failure
        status = result.get("status")
        # NOTE(review): only "completed" is confirmed by this script. "failed"/"error" are
        # assumed failure states; any other status is treated as still running.
        if status in ("completed", "failed", "error"):
            return result
        if time.monotonic() >= deadline:
            return result
        log("Task %s status=%s; polling again..." % (task_id, status))


def main():
    """CLI entry point: health check -> upload -> edit -> wait for result -> download.

    Exits with status 1 at the first failing step (each step logs its reason)
    and prints "Done." on success.
    """
    parser = argparse.ArgumentParser(description="Test Nano Banana Image Edit API")
    parser.add_argument("--base", dest="base_url", required=True, help="Base URL of the API, e.g. https://hf.space/... or http://127.0.0.1:7860")
    parser.add_argument("--image", dest="image_path", required=True, help="Path to input image")
    parser.add_argument("--prompt", dest="prompt", default="enhance the image", help="Edit prompt")
    parser.add_argument("--out", dest="out_path", default="edited.png", help="Output file for edited image")
    parser.add_argument("--token", dest="token", default="", help="Hugging Face access token for private Spaces")
    parser.add_argument("--poll-interval", dest="poll_interval", type=float, default=1.0, help="Seconds to wait between result polls")
    parser.add_argument("--poll-timeout", dest="poll_timeout", type=float, default=300.0, help="Maximum seconds to wait for the edit to complete")
    args = parser.parse_args()

    base_url = args.base_url.rstrip("/")

    # isfile (not exists) so a directory path is rejected up front rather than failing in open().
    if not os.path.isfile(args.image_path):
        log("Image not found: %s" % args.image_path)
        sys.exit(1)

    headers = {}
    if args.token:
        headers["Authorization"] = "Bearer %s" % args.token

    # sys.exit raises SystemExit, which the handlers below deliberately do not catch.
    try:
        if not health(base_url, headers):
            sys.exit(1)

        image_id = upload_image(base_url, args.image_path, headers)
        if not image_id:
            sys.exit(1)

        task_id = edit_image(base_url, image_id, args.prompt, headers)
        if not task_id:
            sys.exit(1)

        # Poll instead of a single fixed 1 s wait, so asynchronous edits get time to finish.
        result = _wait_for_result(base_url, task_id, headers, args.poll_interval, args.poll_timeout)
        if not isinstance(result, dict) or result.get("status") != "completed":
            log("Result not completed: %s" % (result,))
            sys.exit(1)

        result_image_id = result.get("result_image_id")
        if not result_image_id:
            log("No result_image_id in response")
            sys.exit(1)

        if not download_image(base_url, result_image_id, args.out_path, headers):
            sys.exit(1)
    except requests.RequestException as exc:
        # Transport errors (timeouts, resets) become a clean exit code instead of a traceback.
        log("Request failed: %s" % exc)
        sys.exit(1)
    except OSError as exc:
        # e.g. the input image cannot be read or the output path cannot be written.
        log("I/O error: %s" % exc)
        sys.exit(1)

    log("Done.")
|
|
|
|
|
|
|
|
# Run the CLI flow only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|