# Command-line smoke test for the Nano Banana Image Edit API.
# Source: LogicGoInfotechSpaces — "Add project files" (commit 0ed44c0).
import argparse
import os
import sys
import time
from typing import Optional
import requests
def log(msg: str) -> None:
    """Write one progress/status line for the user to standard output."""
    print(msg, file=sys.stdout)
def upload_image(base_url: str, image_path: str, headers: dict) -> Optional[str]:
    """Upload a local image to ``<base_url>/upload`` and return its server-side id.

    The file is sent as the multipart form field ``file``. Its MIME type is
    guessed from the filename and falls back to ``image/jpeg`` when no guess
    is available or the guess is not an ``image/*`` type.

    Args:
        base_url: API root URL without a trailing slash.
        image_path: Path of the image file to upload.
        headers: Extra HTTP headers (e.g. ``Authorization``) for the request.

    Returns:
        The ``image_id`` field of the JSON response. Returns ``None`` (after
        logging why) when the status is not 200, the body is not a JSON
        object, or the object has no truthy ``image_id``. Transport errors
        raised by ``requests`` propagate to the caller.
    """
    import mimetypes

    log("Uploading image: %s" % image_path)
    mime, _ = mimetypes.guess_type(image_path)
    if not mime or not mime.startswith("image/"):
        mime = "image/jpeg"
    filename = os.path.basename(image_path) or "image.jpg"
    with open(image_path, "rb") as f:
        files = {"file": (filename, f, mime)}
        r = requests.post("%s/upload" % base_url, files=files, headers=headers, timeout=120)
    if r.status_code != 200:
        log("Upload failed: %s %s" % (r.status_code, r.text))
        return None
    try:
        payload = r.json()
    except ValueError:
        # A 200 with a non-JSON body (e.g. an HTML error page) would otherwise
        # crash the script with a decode traceback.
        log("Upload returned non-JSON body: %s" % r.text)
        return None
    # Guard against a JSON list/scalar, which has no .get().
    image_id = payload.get("image_id") if isinstance(payload, dict) else None
    if not image_id:
        log("Upload response has no image_id: %s" % (payload,))
        return None
    log("Uploaded. image_id=%s" % image_id)
    return image_id
def edit_image(base_url: str, image_id: str, prompt: str, headers: dict) -> Optional[str]:
    """Submit an edit request for a previously uploaded image.

    POSTs form fields ``image_id`` and ``prompt`` to ``<base_url>/edit``.

    Args:
        base_url: API root URL without a trailing slash.
        image_id: Id returned by the upload endpoint.
        prompt: Text instruction describing the edit.
        headers: Extra HTTP headers (e.g. ``Authorization``) for the request.

    Returns:
        The ``task_id`` field of the JSON response (``None`` if absent).
        Returns ``None`` after logging when the status is not 200 or the
        body is not a JSON object. Transport errors from ``requests``
        propagate to the caller.
    """
    log("Editing image with prompt: %s" % prompt)
    r = requests.post(
        "%s/edit" % base_url,
        data={"image_id": image_id, "prompt": prompt},
        headers=headers,
        timeout=300,
    )
    if r.status_code != 200:
        log("Edit failed: %s %s" % (r.status_code, r.text))
        return None
    # Decode the body once instead of calling r.json() per field.
    try:
        payload = r.json()
    except ValueError:
        log("Edit returned non-JSON body: %s" % r.text)
        return None
    if not isinstance(payload, dict):
        log("Edit response is not a JSON object: %s" % (payload,))
        return None
    task_id = payload.get("task_id")
    log("Edit submitted. task_id=%s status=%s" % (task_id, payload.get("status")))
    return task_id
def get_result(base_url: str, task_id: str, headers: dict) -> Optional[dict]:
    """Fetch the JSON status/result document for an edit task.

    Args:
        base_url: API root URL without a trailing slash.
        task_id: Id returned by the edit endpoint.
        headers: Extra HTTP headers (e.g. ``Authorization``) for the request.

    Returns:
        The decoded JSON object from ``<base_url>/result/<task_id>``. Returns
        ``None`` (after logging) when the status is not 200 or the body is not
        a JSON object. Transport errors from ``requests`` propagate.
    """
    r = requests.get("%s/result/%s" % (base_url, task_id), headers=headers, timeout=120)
    if r.status_code != 200:
        log("Result failed: %s %s" % (r.status_code, r.text))
        return None
    try:
        payload = r.json()
    except ValueError:
        log("Result returned non-JSON body: %s" % r.text)
        return None
    # Honour the Optional[dict] contract: callers use .get() on the result.
    if not isinstance(payload, dict):
        log("Result response is not a JSON object: %s" % (payload,))
        return None
    return payload
def download_image(base_url: str, result_image_id: str, out_path: str, headers: dict) -> bool:
    """Download an edited image from ``<base_url>/result/image/<id>`` to disk.

    Args:
        base_url: API root URL without a trailing slash.
        result_image_id: Id of the edited image reported by the result endpoint.
        out_path: Local file path the raw response bytes are written to.
        headers: Extra HTTP headers (e.g. ``Authorization``) for the request.

    Returns:
        ``True`` once the file is written; ``False`` (after logging) on any
        non-200 status, in which case nothing is written.
    """
    url = "%s/result/image/%s" % (base_url, result_image_id)
    response = requests.get(url, headers=headers, timeout=300)
    if response.status_code == 200:
        with open(out_path, "wb") as fh:
            fh.write(response.content)
        log("Saved: %s" % out_path)
        return True
    log("Download failed: %s %s" % (response.status_code, response.text))
    return False
def health(base_url: str, headers: dict) -> bool:
    """Probe ``<base_url>/health`` and report whether the API is usable.

    Logs the reported ``status`` and ``model_loaded`` fields on success.

    Args:
        base_url: API root URL without a trailing slash.
        headers: Extra HTTP headers (e.g. ``Authorization``) for the request.

    Returns:
        ``True`` if the endpoint answered 200 with a JSON object. Otherwise
        ``False`` after logging why: connection failure, timeout or other
        ``requests`` transport error, non-200 status, or a body that is not
        a JSON object.
    """
    try:
        r = requests.get("%s/health" % base_url, headers=headers, timeout=30)
    except requests.ConnectionError:
        log("Cannot connect to %s" % base_url)
        return False
    except requests.RequestException as exc:
        # ReadTimeout and similar errors are not ConnectionError subclasses;
        # previously they escaped this bool-returning check as a traceback.
        log("Health request failed: %s" % exc)
        return False
    if r.status_code != 200:
        log("Health failed: %s %s" % (r.status_code, r.text))
        return False
    try:
        j = r.json()
    except ValueError:
        log("Health returned non-JSON body: %s" % r.text)
        return False
    if not isinstance(j, dict):
        log("Health response is not a JSON object: %s" % (j,))
        return False
    log("Health: %s (model_loaded=%s)" % (j.get("status"), j.get("model_loaded")))
    return True
def main():
    """Run the end-to-end API smoke test: health, upload, edit, result, download.

    Arguments come from the command line (see ``--help``). The process exits
    with status 1, after logging the reason, at the first step that fails.
    """
    parser = argparse.ArgumentParser(description="Test Nano Banana Image Edit API")
    parser.add_argument(
        "--base",
        dest="base_url",
        required=True,
        help="Base URL of the API, e.g. https://hf.space/... or http://127.0.0.1:7860",
    )
    parser.add_argument("--image", dest="image_path", required=True, help="Path to input image")
    parser.add_argument("--prompt", dest="prompt", default="enhance the image", help="Edit prompt")
    parser.add_argument("--out", dest="out_path", default="edited.png", help="Output file for edited image")
    parser.add_argument("--token", dest="token", default="", help="Hugging Face access token for private Spaces")
    args = parser.parse_args()

    # Normalise so "%s/endpoint" never produces a double slash.
    base_url = args.base_url.rstrip("/")

    # isfile() rather than exists(): a directory passes exists() and would then
    # fail inside open() with an unhandled IsADirectoryError.
    if not os.path.isfile(args.image_path):
        log("Image not found: %s" % args.image_path)
        sys.exit(1)

    headers = {}
    if args.token:
        headers["Authorization"] = "Bearer %s" % args.token

    if not health(base_url, headers):
        sys.exit(1)

    image_id = upload_image(base_url, args.image_path, headers)
    if not image_id:
        sys.exit(1)

    task_id = edit_image(base_url, image_id, args.prompt, headers)
    if not task_id:
        sys.exit(1)

    # Only one result fetch is made after a short pause. If the backend ever
    # becomes async, poll here.
    time.sleep(1.0)
    result = get_result(base_url, task_id, headers)
    if not result or result.get("status") != "completed":
        log("Result not completed: %s" % (result,))
        sys.exit(1)

    result_image_id = result.get("result_image_id")
    if not result_image_id:
        log("No result_image_id in response")
        sys.exit(1)

    ok = download_image(base_url, result_image_id, args.out_path, headers)
    if not ok:
        sys.exit(1)
    log("Done.")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()