#!/usr/bin/env python3 """ FastAPI JASCO server (Postman-ready) with robust Hugging Face auth and safe handling of HF fast-transfer (hf_transfer). - If HF_HUB_ENABLE_HF_TRANSFER=1 but 'hf_transfer' isn't installed, we fall back to standard downloads and log a warning instead of failing. - POST /predict supports multipart (drums_file) and JSON (drums_b64). - GET /hf-status shows auth and model access. Run: export HUGGINGFACE_HUB_TOKEN=hf_xxx # or HF_TOKEN/HFTOKEN/HUGGINGFACEHUB_API_TOKEN uvicorn main:app --host 0.0.0.0 --port 7860 """ # ----------------------------- # Environment (HF Spaces-friendly) # ----------------------------- import os from pathlib import Path from requests import Request, Response from pydantic import BaseModel, Field import numpy as np from scipy.io import wavfile from fastapi.responses import FileResponse def _pick_cache_dir() -> Path: for c in [Path("/data/cache"), Path("/tmp/audiocraft_cache"), Path.cwd() / "cache"]: try: c.mkdir(parents=True, exist_ok=True) (c / ".w").touch(); (c / ".w").unlink() return c except Exception: pass return Path.cwd() CACHE_DIR = _pick_cache_dir() for sub in ["models", "huggingface", "transformers", "drum_cache", "cache"]: (CACHE_DIR / sub).mkdir(parents=True, exist_ok=True) os.environ["AUDIOCRAFT_CACHE_DIR"] = str(CACHE_DIR) os.environ["XDG_CACHE_HOME"] = str(CACHE_DIR) os.environ["TORCH_HOME"] = str(CACHE_DIR / "cache") os.environ["HF_HOME"] = str(CACHE_DIR / "huggingface") os.environ["TRANSFORMERS_CACHE"] = str(CACHE_DIR / "transformers") os.environ["NUMBA_DISABLE_JIT"] = "1" # Do NOT force-enable fast transfer; handle it dynamically below. 
# os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # (removed)

# -----------------------------
# Imports
# -----------------------------
import io
import re
import json
import wave
import time
import base64
import random
import hashlib
import zipfile
from tempfile import NamedTemporaryFile
from typing import Optional, List, Tuple, Union  # fix: 'Optional' was imported twice

import numpy as np
import torch
from fastapi import FastAPI, Request, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse

# Hugging Face auth helpers
from huggingface_hub import login as hf_login, HfApi, HfFolder
from huggingface_hub.utils import HfHubHTTPError

# JASCO / AudioCraft
from audiocraft.data.audio_utils import f32_pcm, normalize_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import JASCO

# -----------------------------
# App boilerplate
# -----------------------------
app = FastAPI(title="JASCO /predict (HF auth)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# -----------------------------
# Hugging Face auth utilities
# -----------------------------
def _get_hf_token() -> Optional[str]:
    """Return the first HF token found among the supported env var names, or None."""
    for k in ("HUGGINGFACE_HUB_TOKEN", "HUGGINGFACEHUB_API_TOKEN", "HF_TOKEN", "HFTOKEN"):
        v = os.getenv(k)
        if v:
            return v.strip()
    return None


HF_TOKEN = _get_hf_token()


def ensure_hf_login():
    """Login once; provide clear logging.

    No-op if no token (but gated models will fail)."""
    global HF_TOKEN
    if not HF_TOKEN:
        print("[HF] No token found in env (HUGGINGFACE_HUB_TOKEN / HUGGINGFACEHUB_API_TOKEN / HF_TOKEN / HFTOKEN).")
        return
    try:
        hf_login(token=HF_TOKEN, add_to_git_credential=False)
        HfFolder.save_token(HF_TOKEN)  # persist under HF_HOME
        who = HfApi().whoami(token=HF_TOKEN)
        print(f"[HF] Logged in as: {who.get('name') or who.get('email') or who.get('username')}")
    except Exception as e:
        # Best-effort: log and continue; gated model downloads will surface the failure later.
        print(f"[HF] Login failed: {e}")


@app.get("/hf-status")
def hf_status():
    """Report token presence, whoami, and access to the default JASCO repo."""
    token = _get_hf_token()
    out = {"token_present": bool(token)}
    try:
        if token:
            who = HfApi().whoami(token=token)
            out["whoami"] = who
        else:
            out["whoami"] = None
    except Exception as e:
        out["whoami_error"] = str(e)
    model_id = "facebook/jasco-chords-drums-melody-400M"
    try:
        api = HfApi()
        info = api.model_info(model_id, token=token) if token else api.model_info(model_id)
        out["model_access"] = True
        out["model_private"] = getattr(info, "private", None)
        out["gated"] = bool(getattr(info, "gated", False))
    except Exception as e:
        out["model_access"] = False
        out["error"] = str(e)
    return out


# -----------------------------
# Chords helpers
# -----------------------------
def _default_chord_map():
    """Chord-symbol -> index mapping used when the model ships without one."""
    chords = [
        "N","C","Cm","C7","Cmaj7","Cm7","D","Dm","D7","Dmaj7","Dm7",
        "E","Em","E7","Emaj7","Em7","F","Fm","F7","Fmaj7","Fm7",
    ]
    return {ch: i for i, ch in enumerate(chords)}


def _validate_chord(ch: str, mapping: dict) -> str:
    """Return the chord symbol if known, else the 'UNK' placeholder."""
    return ch if ch in mapping else "UNK"


def chords_string_to_list(chords: str) -> List[Tuple[str, float]]:
    """Parse "(C, 0.0), (Am, 2.5)"-style input into [(symbol, start_time), ...].

    Unknown symbols become "UNK"; any parse failure yields an empty list.
    """
    if not chords or chords.strip() == "":
        return []
    try:
        clean = chords.replace("[", "").replace("]", "").replace(" ", "")
        pairs = re.findall(r"\(([^,]+),([^)]+)\)", clean)
        mapping = _default_chord_map()
        return [(_validate_chord(ch.strip(), mapping), float(t.strip())) for ch, t in pairs]
    except Exception:
        return []


# -----------------------------
# Audio decoding (WAV stdlib)
# -----------------------------
def _read_wav_bytes(raw: Optional[bytes]) -> Tuple[int, Optional[torch.Tensor]]:
    """Decode WAV bytes into (sample_rate, loudness-normalized float tensor).

    Returns (32000, None) when input is missing or undecodable.
    """
    if not raw:
        return 32000, None
    try:
        with wave.open(io.BytesIO(raw), "rb") as wf:
            sr = wf.getframerate()
            ch = wf.getnchannels()
            sw = wf.getsampwidth()
            frames = wf.getnframes()
            buf = wf.readframes(frames)
        if sw == 2:
            data = np.frombuffer(buf, dtype=np.int16).astype(np.float32) / 32768.0
        elif sw == 1:
            data = (np.frombuffer(buf, dtype=np.uint8).astype(np.float32) - 128) / 128.0
        elif sw == 4:
            # NOTE(review): assumes 4-byte samples are IEEE float32; a 32-bit
            # int PCM WAV would be misread here — confirm expected inputs.
            data = np.frombuffer(buf, dtype=np.float32)
        else:
            return 32000, None
        if ch > 1:
            data = data.reshape(-1, ch).T
        else:
            data = data[None, :]
        drums = f32_pcm(torch.from_numpy(data)).t()
        if drums.dim() == 1:
            drums = drums[None]
        drums = normalize_audio(drums, "loudness", loudness_headroom_db=16, sample_rate=sr)
        return sr, drums
    except Exception as e:
        print(f"[audio] WAV decode failed: {e}")
        return 32000, None


def _read_uploadfile_to_bytes(file: Optional[UploadFile]) -> Optional[bytes]:
    """Read an UploadFile fully into bytes; None on absence or failure."""
    if file is None:
        return None
    try:
        return file.file.read()
    except Exception:
        return None


def _read_b64_to_bytes(b64str: Optional[str]) -> Optional[bytes]:
    """Decode a base64 string (optionally a data: URL) into bytes; None on failure."""
    if not b64str:
        return None
    try:
        s = b64str.strip()
        if s.startswith("data:"):
            s = s.split(",", 1)[1]
        return base64.b64decode(s)
    except Exception:
        return None


# -----------------------------
# Model
# -----------------------------
MODEL = None  # lazily-loaded singleton; keyed by model name via MODEL.name


def _ensure_mapping_file() -> Path:
    """Create (once) and return the pickled chord->index mapping file path."""
    import pickle
    mapping_file = CACHE_DIR / "chord_to_index_mapping.pkl"
    if not mapping_file.exists():
        with open(mapping_file, "wb") as f:
            pickle.dump(_default_chord_map(), f)
    return mapping_file


def load_model(name: str):
    """
    Load JASCO, ensuring HF auth for gated repos.
    Falls back if hf_transfer is unavailable.

    Raises HTTPException(401) on access problems, HTTPException(500) otherwise.
    """
    global MODEL
    if MODEL is not None and getattr(MODEL, "name", None) == name:
        return MODEL

    # Ensure HF login
    ensure_hf_login()

    # Preflight access for clearer errors
    try:
        api = HfApi()
        token = _get_hf_token()
        _ = api.model_info(name, token=token) if token else api.model_info(name)
    except HfHubHTTPError as e:
        msg = (
            f"Cannot access model '{name}'. This repo may be gated or private.\n"
            f"- Ensure your token has access and terms are accepted.\n"
            f"- Provide token via HUGGINGFACE_HUB_TOKEN (or HF_TOKEN/HFTOKEN/HUGGINGFACEHUB_API_TOKEN).\n"
            f"Hugging Face error: {e}"
        )
        raise HTTPException(status_code=401, detail=msg)

    # Per-model cache directory, announced via env for audiocraft/transformers.
    cache_path = CACHE_DIR / name.replace("/", "_")
    cache_path.mkdir(parents=True, exist_ok=True)
    os.environ["AUDIOCRAFT_CACHE_DIR"] = str(cache_path)
    os.environ["TRANSFORMERS_CACHE"] = str(cache_path / "transformers")

    mapping_file = _ensure_mapping_file()
    try:
        model = JASCO.get_pretrained(name, device="cpu", chords_mapping_path=str(mapping_file))
        model.name = name
        import pickle
        # Some audiocraft versions don't attach the mapping; load it ourselves.
        if not hasattr(model, "chord_to_index"):
            with open(mapping_file, "rb") as f:
                model.chord_to_index = pickle.load(f)
    except HfHubHTTPError as e:
        raise HTTPException(status_code=401, detail=f"Model load failed due to HF auth/access: {e}")
    except Exception as e:
        # fix: message previously read "Model load failed11"
        raise HTTPException(status_code=500, detail=f"Model load failed: {e}")

    MODEL = model
    return MODEL


def set_gen_params(model, **kwargs):
    """Apply generation params, silently dropping keys the model doesn't accept."""
    valid = None
    if hasattr(model, "get_generation_params"):
        try:
            valid = set(model.get_generation_params().keys())
        except Exception:
            pass
    filtered, unknown = {}, []
    for k, v in kwargs.items():
        if valid is not None and k not in valid:
            unknown.append(k)
        else:
            filtered[k] = v
    print(f"[gen] request={kwargs}")
    if valid is not None:
        print(f"[gen] applied={filtered} unknown={unknown}")
    model.set_generation_params(**filtered)


def _tensor_fp(t: Optional[torch.Tensor]) -> str:
    """Short SHA1 fingerprint of a tensor's contents, for log correlation."""
    if t is None:
        return "NONE"
    try:
        x = t.detach().cpu().contiguous().float()
        return hashlib.sha1(x.numpy().tobytes()).hexdigest()[:10]
    except Exception:
        return "ERR"


# -----------------------------
# Endpoints
# -----------------------------
@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "cache_dir": str(CACHE_DIR),
    }


# Any of these JSON/form keys may carry the text prompt; first non-empty wins.
_TEXT_KEYS = ["text", "prompt", "description", "query", "message", "input", "content"]


class PredictRequest(BaseModel):
    """Documented request schema for /predict (the endpoint parses manually)."""
    model: str = Field(default="facebook/jasco-chords-drums-melody-400M")
    text: str = ""
    chords_sym: str = ""
    n_samples: int = Field(default=2)
    seed: Optional[int] = Field(default=None)
    cfg_coef_all: float = Field(default=1.25)
    cfg_coef_txt: float = 2.5
    ode_rtol: float = 1e-4
    ode_atol: float = 1e-4
    ode_solver: str = "euler"
    ode_steps: int = 10
    drums_b64: Optional[str] = None
    drums_upload: Optional[UploadFile] = None


class PredictResponse(BaseModel):
    status: str = "success"
    message: Optional[str] = None
    data: Optional[dict] = None


def tensor_to_wav_scipy(tensor, sample_rate, filename):
    """Write a float waveform tensor to *filename* as 16-bit PCM WAV.

    Accepts (samples,) or (channels, samples); values are clipped to [-1, 1].
    """
    # Convert to numpy and ensure correct format
    audio_data = tensor.detach().cpu().numpy()
    # Normalize to 16-bit range
    audio_data = np.clip(audio_data, -1.0, 1.0)
    audio_data = (audio_data * 32767).astype(np.int16)
    # fix: scipy.io.wavfile.write expects (Nsamples, Nchannels); torch audio
    # tensors are (channels, samples), so transpose 2-D input.
    if audio_data.ndim == 2:
        audio_data = audio_data.T
    # Save as WAV
    wavfile.write(filename, sample_rate, audio_data)


class FileCleaner:
    """Tracks generated files and deletes them after *file_lifetime* seconds."""

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        self.files = []  # list of (added_at, Path), oldest first

    def add(self, path: Union[str, Path]):
        """Register a file for future deletion and purge anything expired."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        # fix: previously, an expired entry whose file was already gone was
        # never popped, and pop(0) could then drop the wrong tuple, leaking
        # bookkeeping entries forever. Pop unconditionally once expired.
        now = time.time()
        while self.files and now - self.files[0][0] > self.file_lifetime:
            _, path = self.files.pop(0)
            if path.exists():
                path.unlink()


file_cleaner = FileCleaner()


@app.post("/predict")
async def predict(request: Request):
    """
    Generate music and return a single WAV file (jasco_output.wav).

    Accepts:
      - multipart/form-data (fields + optional drums_file)
      - application/json (fields + optional drums_b64)
    """
    ct = (request.headers.get("content-type") or "application/json").lower()
    params = {
        "model": "facebook/jasco-chords-drums-melody-400M",
        "text": "",
        "chords_sym": "",
        "n_samples": 1,
        "seed": None,
        "cfg_coef_all": 1.25,
        "cfg_coef_txt": 2.5,
        "ode_rtol": 1e-4,
        "ode_atol": 1e-4,
        "ode_solver": "euler",
        "ode_steps": 10,
        "drums_b64": None,
    }
    drums_upload: Optional[UploadFile] = None

    try:
        # fix: Content-Type often carries parameters ("application/json;
        # charset=utf-8"); exact equality sent such requests down the form path.
        if ct.startswith("application/json"):
            data = await request.json()
            if not isinstance(data, dict):
                raise HTTPException(status_code=400, detail="JSON body must be an object")
            for k in _TEXT_KEYS:
                if k in data and data[k]:
                    params["text"] = data[k]; break
            params["model"] = data.get("model", params["model"])
            params["chords_sym"] = data.get("chords_sym", params["chords_sym"])
            params["n_samples"] = int(data.get("n_samples", params["n_samples"]))
            params["seed"] = data.get("seed", None)
            params["cfg_coef_all"] = float(data.get("cfg_coef_all", params["cfg_coef_all"]))
            params["cfg_coef_txt"] = float(data.get("cfg_coef_txt", params["cfg_coef_txt"]))
            params["ode_rtol"] = float(data.get("ode_rtol", params["ode_rtol"]))
            params["ode_atol"] = float(data.get("ode_atol", params["ode_atol"]))
            params["ode_solver"] = str(data.get("ode_solver", params["ode_solver"])).lower()
            params["ode_steps"] = int(data.get("ode_steps", params["ode_steps"]))
            params["drums_b64"] = data.get("drums_b64", None)
            raw_drums = _read_b64_to_bytes(params["drums_b64"])
        else:
            form = await request.form()
            fd = {k: (form.get(k)) for k in form.keys() if form.get(k)}
            for k in _TEXT_KEYS:
                if k in fd and fd[k]:
                    params["text"] = fd[k]; break
            params["model"] = fd.get("model", params["model"])
            params["chords_sym"] = fd.get("chords_sym", params["chords_sym"])
            params["n_samples"] = int(fd.get("n_samples", params["n_samples"]))
            params["seed"] = fd.get("seed", None)
            params["cfg_coef_all"] = float(fd.get("cfg_coef_all", params["cfg_coef_all"]))
            params["cfg_coef_txt"] = float(fd.get("cfg_coef_txt", params["cfg_coef_txt"]))
            params["ode_rtol"] = float(fd.get("ode_rtol", params["ode_rtol"]))
            params["ode_atol"] = float(fd.get("ode_atol", params["ode_atol"]))
            params["ode_solver"] = str(fd.get("ode_solver", params["ode_solver"])).lower()
            params["ode_steps"] = int(fd.get("ode_steps", params["ode_steps"]))
            params["drums_b64"] = fd.get("drums_b64", None)
            drums_upload = form.get("drums_file")
            raw_drums = _read_uploadfile_to_bytes(drums_upload) if drums_upload else _read_b64_to_bytes(params["drums_b64"])
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad request: {e}")

    print(json.dumps({
        "ct": ct,
        "text_len": len(params["text"] or ""),
        "text_preview": (params["text"] or "")[:120],
        "model": params["model"],
        "n_samples": params["n_samples"],
        "has_drums_bytes": raw_drums is not None,
    }))

    model = load_model(params["model"])  # may raise HTTPException(401/500)

    drums_sr, drums_tensor = _read_wav_bytes(raw_drums)
    print(f"[predict] drums_present={drums_tensor is not None} sr={drums_sr} drums_fp={_tensor_fp(drums_tensor)}")

    # Seed everything for reproducibility; wall-clock-derived when unspecified.
    base_seed = int(params["seed"]) if params["seed"] is not None else (int(time.time() * 1000) & 0xFFFFFFFF)
    random.seed(base_seed); np.random.seed(base_seed); torch.manual_seed(base_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(base_seed)

    set_gen_params(
        model,
        cfg_coef_all=float(params["cfg_coef_all"]),
        cfg_coef_txt=float(params["cfg_coef_txt"]),
        ode_rtol=float(params["ode_rtol"]),
        ode_atol=float(params["ode_atol"]),
        euler=(params["ode_solver"] == "euler"),
        euler_steps=int(params["ode_steps"]),
    )

    texts = [params["text"]] * max(1, int(params["n_samples"]))
    chords_list = chords_string_to_list(params["chords_sym"])
    print(f"[predictdebug] chords_list={chords_list}")
    print(f"[predictdebug] drums_tensor={drums_tensor}")
    print(f"[predictdebug] drums_sr={drums_sr}")
    print(f"[predictdebug] model={model}")
    print(f"[predictdebug] texts={texts}")

    try:
        outputs = model.generate_music(
            descriptions=texts,
            chords=chords_list,
            drums_wav=drums_tensor,
            melody_salience_matrix=None,
            drums_sample_rate=drums_sr,
            progress=False,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")

    print(f"[predictdebug] outputs={outputs}")  # Log the raw model outputs
    # Convert model outputs from GPU tensor to CPU float tensor for processing
    outputs = outputs.detach().cpu().float()
    print(f"[predictdebug] outputs converted to cpu={outputs}")  # Log the converted outputs

    # Only the first sample is returned; audio_write owns the encoding.
    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as f:
        tmp_path = f.name
    audio_write(
        tmp_path,
        outputs[0],
        model.sample_rate,  # fix: was MODEL.sample_rate — use the local consistently
        strategy="loudness",
        loudness_headroom_db=16,
        loudness_compressor=True,
        add_suffix=False,
    )
    # fix: the delete=False temp file was never cleaned up; schedule deletion.
    file_cleaner.add(tmp_path)

    return FileResponse(
        path=tmp_path,
        media_type="audio/wav",
        filename="jasco_output.wav",
    )


if __name__ == "__main__":
    ensure_hf_login()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)