|
|
import argparse
import io
import json
import logging
import mimetypes
import os
import sys
import time
from typing import Optional

import requests
|
|
|
|
|
|
|
|
def configure_logging(verbose: bool) -> None:
    """Configure root logging for this CLI.

    Args:
        verbose: If true, DEBUG records are emitted; otherwise INFO and above.

    Note:
        ``logging.basicConfig`` does nothing when the root logger already
        has handlers, so only the first configuration call takes effect.
    """
    if verbose:
        chosen_level = logging.DEBUG
    else:
        chosen_level = logging.INFO
    line_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=line_format, level=chosen_level)
|
|
|
|
|
|
|
|
def get_base_url(cli_base_url: Optional[str]) -> str:
    """Resolve the API root URL, without any trailing slash.

    Precedence: the ``--base-url`` CLI value, then the ``BASE_URL``
    environment variable, then ``http://localhost:7860``. Empty strings
    are treated as "not provided".

    Args:
        cli_base_url: Value passed on the command line, or ``None``.

    Returns:
        The chosen URL with trailing ``/`` characters removed.
    """
    for candidate in (cli_base_url, os.getenv("BASE_URL")):
        if candidate:
            return candidate.rstrip("/")
    return "http://localhost:7860"
|
|
|
|
|
|
|
|
def wait_for_model(
    base_url: str,
    timeout_seconds: int = 300,
    poll_interval: float = 3.0,
    request_timeout: float = 15.0,
) -> None:
    """Poll ``{base_url}/health`` until it reports ``model_loaded``.

    Each successful health response (a JSON object) is logged. Transport
    errors, non-2xx statuses and undecodable bodies are logged as warnings,
    and polling continues until the deadline.

    Args:
        base_url: API root URL without a trailing slash.
        timeout_seconds: Total polling budget in seconds.
        poll_interval: Seconds to wait between attempts. The wait is capped
            at the remaining budget, so it never overshoots the deadline.
        request_timeout: Per-request timeout passed to ``requests.get``.

    Raises:
        RuntimeError: If the deadline passes without a truthy
            ``model_loaded``. The message includes the last parsed health
            payload, or ``None`` if none was received.
    """
    health_url = f"{base_url}/health"
    # Monotonic clock: immune to wall-clock jumps (NTP, DST) during long waits.
    deadline = time.monotonic() + timeout_seconds
    logging.info("Waiting for model to load at %s", health_url)
    last_status = None

    while time.monotonic() < deadline:
        try:
            resp = requests.get(health_url, timeout=request_timeout)
            data = resp.json() if resp.ok else None
        except (requests.RequestException, ValueError) as exc:
            # ValueError covers JSON decode failures on older and newer requests.
            logging.warning("Health check error: %s", str(exc))
        else:
            if not resp.ok:
                logging.warning("Health check HTTP %s", resp.status_code)
            else:
                last_status = data
                if not isinstance(data, dict):
                    # A non-object body cannot carry model_loaded; keep polling.
                    logging.warning("Health check returned non-object JSON: %r", data)
                elif data.get("model_loaded"):
                    logging.info("Model loaded: %s", json.dumps(data))
                    return
                else:
                    logging.info("Health: %s", json.dumps(data))

        # Never sleep past the deadline.
        remaining = deadline - time.monotonic()
        if remaining > 0:
            time.sleep(min(poll_interval, remaining))

    raise RuntimeError("Model did not load before timeout. Last health: %s" % (last_status,))
|
|
|
|
|
|
|
|
def upload_image(
    base_url: str,
    image_path: str,
    auth_bearer: Optional[str],
    app_check: Optional[str],
    timeout: float = 120,
) -> dict:
    """POST an image to ``{base_url}/upload`` as multipart field ``file``.

    Args:
        base_url: API root URL without a trailing slash.
        image_path: Local path of the image to send.
        auth_bearer: Optional token sent as ``Authorization: Bearer <token>``.
        app_check: Optional token sent in the ``X-Firebase-AppCheck`` header.
        timeout: Request timeout in seconds.

    Returns:
        The decoded JSON response body.

    Raises:
        RuntimeError: On a non-2xx status, or when the body is not JSON.
        OSError: If ``image_path`` cannot be opened.
    """
    url = f"{base_url}/upload"
    headers = {}
    if auth_bearer:
        headers["Authorization"] = f"Bearer {auth_bearer}"
    if app_check:
        headers["X-Firebase-AppCheck"] = app_check

    # Label the part with its real image type instead of always claiming
    # JPEG. Fall back to the previous default for unknown/non-image extensions.
    guessed_type, _ = mimetypes.guess_type(image_path)
    content_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"

    with open(image_path, "rb") as f:
        files = {"file": (os.path.basename(image_path), f, content_type)}
        resp = requests.post(url, files=files, headers=headers, timeout=timeout)

    if not resp.ok:
        raise RuntimeError("Upload failed: HTTP %s %s" % (resp.status_code, resp.text))

    try:
        data = resp.json()
    except ValueError as exc:
        raise RuntimeError("Upload returned non-JSON body: %s" % (resp.text[:500],)) from exc

    logging.info("Upload response: %s", json.dumps(data))
    return data
|
|
|
|
|
|
|
|
def colorize_image(
    base_url: str,
    image_path: str,
    auth_bearer: Optional[str],
    app_check: Optional[str],
    timeout: float = 900,
) -> dict:
    """POST an image to ``{base_url}/colorize`` as multipart field ``file``.

    Args:
        base_url: API root URL without a trailing slash.
        image_path: Local path of the image to send.
        auth_bearer: Optional token sent as ``Authorization: Bearer <token>``.
        app_check: Optional token sent in the ``X-Firebase-AppCheck`` header.
        timeout: Request timeout in seconds. It defaults high because
            colorization runs server-side within this single request.

    Returns:
        The decoded JSON response body. The caller in this script reads
        ``result_id`` from it.

    Raises:
        RuntimeError: On a non-2xx status, or when the body is not JSON.
        OSError: If ``image_path`` cannot be opened.
    """
    url = f"{base_url}/colorize"
    headers = {}
    if auth_bearer:
        headers["Authorization"] = f"Bearer {auth_bearer}"
    if app_check:
        headers["X-Firebase-AppCheck"] = app_check

    # Label the part with its real image type instead of always claiming
    # JPEG. Fall back to the previous default for unknown/non-image extensions.
    guessed_type, _ = mimetypes.guess_type(image_path)
    content_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"

    with open(image_path, "rb") as f:
        files = {"file": (os.path.basename(image_path), f, content_type)}
        resp = requests.post(url, files=files, headers=headers, timeout=timeout)

    if not resp.ok:
        raise RuntimeError("Colorize failed: HTTP %s %s" % (resp.status_code, resp.text))

    try:
        data = resp.json()
    except ValueError as exc:
        raise RuntimeError("Colorize returned non-JSON body: %s" % (resp.text[:500],)) from exc

    logging.info("Colorize response: %s", json.dumps(data))
    return data
|
|
|
|
|
|
|
|
def download_result(
    base_url: str,
    result_id: str,
    output_path: str,
    auth_bearer: Optional[str],
    app_check: Optional[str],
    timeout: float = 300,
) -> None:
    """Stream ``{base_url}/download/{result_id}`` into ``output_path``.

    The body is written to ``<output_path>.part`` and moved into place only
    after the whole stream has been received. An interrupted download
    therefore never leaves a truncated file at ``output_path``.

    Args:
        base_url: API root URL without a trailing slash.
        result_id: Identifier returned by the colorize endpoint.
        output_path: Destination file path; an existing file is overwritten.
        auth_bearer: Optional token sent as ``Authorization: Bearer <token>``.
        app_check: Optional token sent in the ``X-Firebase-AppCheck`` header.
        timeout: Request timeout in seconds.

    Raises:
        RuntimeError: On a non-2xx status.
        requests.RequestException / OSError: On transport or filesystem errors.
    """
    url = f"{base_url}/download/{result_id}"
    headers = {}
    if auth_bearer:
        headers["Authorization"] = f"Bearer {auth_bearer}"
    if app_check:
        headers["X-Firebase-AppCheck"] = app_check

    tmp_path = f"{output_path}.part"

    # With stream=True the connection is held until the body is consumed or
    # the response is closed. The context manager closes it on every path,
    # including the error raise below.
    with requests.get(url, headers=headers, stream=True, timeout=timeout) as resp:
        if not resp.ok:
            raise RuntimeError("Download failed: HTTP %s %s" % (resp.status_code, resp.text))
        try:
            with open(tmp_path, "wb") as out:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive empty chunks
                        out.write(chunk)
        except BaseException:
            # Remove the partial file, then propagate the original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    os.replace(tmp_path, output_path)
    logging.info("Saved colorized image to: %s", output_path)
|
|
|
|
|
|
|
|
def main() -> int:
    """Run the end-to-end workflow: health wait, upload, colorize, download.

    Command-line arguments are parsed from ``sys.argv``. Each stage's
    failure is logged and turns into a non-zero exit code. A failed health
    wait is only a warning.

    Returns:
        0 when every stage succeeds, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="End-to-end test for Colorize API")
    parser.add_argument("--base-url", type=str, help="API base URL, e.g. https://<space>.hf.space")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--out", type=str, default="colorized_result.jpg", help="Path to save colorized image")
    parser.add_argument("--auth", type=str, default=os.getenv("ID_TOKEN", ""), help="Optional Firebase id_token")
    parser.add_argument("--app-check", type=str, default=os.getenv("APP_CHECK_TOKEN", ""), help="Optional App Check token")
    parser.add_argument("--skip-wait", action="store_true", help="Skip waiting for model to load")
    parser.add_argument("--verbose", action="store_true", help="Verbose logging")
    args = parser.parse_args()

    configure_logging(args.verbose)
    base_url = get_base_url(args.base_url)
    image_path = args.image

    # isfile rather than exists: a directory would pass exists() and then
    # fail later inside open() with a less helpful error.
    if not os.path.isfile(image_path):
        logging.error("Image not found: %s", image_path)
        return 1

    if not args.skip_wait:
        try:
            wait_for_model(base_url, timeout_seconds=600)
        except Exception as e:
            # Deliberately best-effort: still attempt the workflow.
            logging.warning("Continuing despite health wait failure: %s", str(e))

    # Blank/whitespace-only tokens mean "not provided".
    auth_bearer = args.auth.strip() or None
    app_check = args.app_check.strip() or None

    try:
        # Exercises /upload on its own; upload_image already logs the reply,
        # and nothing later in this workflow needs it.
        upload_image(base_url, image_path, auth_bearer, app_check)
    except Exception as e:
        logging.error("Upload error: %s", str(e))
        return 1

    try:
        colorize_resp = colorize_image(base_url, image_path, auth_bearer, app_check)
    except Exception as e:
        logging.error("Colorize error: %s", str(e))
        return 1

    # Guard against a JSON body that is not an object (e.g. a list), which
    # would otherwise raise an uncaught AttributeError on .get().
    result_id = colorize_resp.get("result_id") if isinstance(colorize_resp, dict) else None
    if not result_id:
        logging.error("No result_id in response: %s", json.dumps(colorize_resp))
        return 1

    try:
        download_result(base_url, result_id, args.out, auth_bearer, app_check)
    except Exception as e:
        logging.error("Download error: %s", str(e))
        return 1

    logging.info("Test workflow completed successfully.")
    return 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())
|
|
|
|
|
|