# audiocraft-2 / main.py
# (Hugging Face Hub page header preserved as comments: "adityas129's picture",
#  "Update main.py", commit "7a48d5d verified" — scrape residue, not code.)
#!/usr/bin/env python3
"""
FastAPI JASCO server (Postman-ready) with robust Hugging Face auth and
safe handling of HF fast-transfer (hf_transfer).
- If HF_HUB_ENABLE_HF_TRANSFER=1 but 'hf_transfer' isn't installed, we fall back
to standard downloads and log a warning instead of failing.
- POST /predict supports multipart (drums_file) and JSON (drums_b64).
- GET /hf-status shows auth and model access.
Run:
export HUGGINGFACE_HUB_TOKEN=hf_xxx # or HF_TOKEN/HFTOKEN/HUGGINGFACEHUB_API_TOKEN
uvicorn main:app --host 0.0.0.0 --port 7860
"""
# -----------------------------
# Environment (HF Spaces-friendly)
# -----------------------------
import os
from pathlib import Path
from requests import Request, Response
from pydantic import BaseModel, Field
import numpy as np
from scipy.io import wavfile
from fastapi.responses import FileResponse
def _pick_cache_dir() -> Path:
for c in [Path("/data/cache"), Path("/tmp/audiocraft_cache"), Path.cwd() / "cache"]:
try:
c.mkdir(parents=True, exist_ok=True)
(c / ".w").touch(); (c / ".w").unlink()
return c
except Exception:
pass
return Path.cwd()
# Resolve a writable cache root once and pre-create the subdirectories used below.
CACHE_DIR = _pick_cache_dir()
for sub in ["models", "huggingface", "transformers", "drum_cache", "cache"]:
    (CACHE_DIR / sub).mkdir(parents=True, exist_ok=True)
# Point each caching library (audiocraft, XDG consumers, torch hub,
# huggingface_hub, transformers) at the writable root so downloads work in
# restricted containers. Must run before those libraries are imported.
os.environ["AUDIOCRAFT_CACHE_DIR"] = str(CACHE_DIR)
os.environ["XDG_CACHE_HOME"] = str(CACHE_DIR)
os.environ["TORCH_HOME"] = str(CACHE_DIR / "cache")
os.environ["HF_HOME"] = str(CACHE_DIR / "huggingface")
os.environ["TRANSFORMERS_CACHE"] = str(CACHE_DIR / "transformers")
# Numba JIT can misbehave in constrained environments; force pure-Python mode.
os.environ["NUMBA_DISABLE_JIT"] = "1"
# Do NOT force-enable fast transfer; handle it dynamically below.
# os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") # (removed)
# -----------------------------
# Imports
# -----------------------------
import io
import re
import json
import wave
import time
import base64
import random
import hashlib
import zipfile
from tempfile import NamedTemporaryFile
from typing import Optional, List, Tuple, Union, Optional
import numpy as np
import torch
from fastapi import FastAPI, Request, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
# Hugging Face auth helpers
from huggingface_hub import login as hf_login, HfApi, HfFolder
from huggingface_hub.utils import HfHubHTTPError
# JASCO / AudioCraft
from audiocraft.data.audio_utils import f32_pcm, normalize_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import JASCO
# -----------------------------
# App boilerplate
# -----------------------------
# FastAPI application with fully permissive CORS so Postman / browser clients
# on any origin can call the endpoints.
app = FastAPI(title="JASCO /predict (HF auth)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
)
# -----------------------------
# Hugging Face auth utilities
# -----------------------------
def _get_hf_token() -> Optional[str]:
for k in ("HUGGINGFACE_HUB_TOKEN", "HUGGINGFACEHUB_API_TOKEN", "HF_TOKEN", "HFTOKEN"):
v = os.getenv(k)
if v:
return v.strip()
return None
HF_TOKEN = _get_hf_token()
def ensure_hf_login():
    """Login once; provide clear logging. No-op if no token (but gated models will fail)."""
    global HF_TOKEN
    if not HF_TOKEN:
        print("[HF] No token found in env (HUGGINGFACE_HUB_TOKEN / HUGGINGFACEHUB_API_TOKEN / HF_TOKEN / HFTOKEN).")
        return
    try:
        hf_login(token=HF_TOKEN, add_to_git_credential=False)
        # Persist under HF_HOME so child processes / later runs reuse the token.
        HfFolder.save_token(HF_TOKEN)
        identity = HfApi().whoami(token=HF_TOKEN)
        display = identity.get('name') or identity.get('email') or identity.get('username')
        print(f"[HF] Logged in as: {display}")
    except Exception as e:
        # Login failure is logged, not raised: anonymous access may still work
        # for public models.
        print(f"[HF] Login failed: {e}")
@app.get("/hf-status")
def hf_status():
    """Diagnostics endpoint: token presence, whoami, and JASCO repo accessibility."""
    token = _get_hf_token()
    out = {"token_present": bool(token)}
    try:
        out["whoami"] = HfApi().whoami(token=token) if token else None
    except Exception as e:
        out["whoami_error"] = str(e)
    # Probe the default JASCO checkpoint to surface gating/auth problems early.
    model_id = "facebook/jasco-chords-drums-melody-400M"
    try:
        api = HfApi()
        info = api.model_info(model_id, token=token) if token else api.model_info(model_id)
    except Exception as e:
        out["model_access"] = False
        out["error"] = str(e)
    else:
        out["model_access"] = True
        out["model_private"] = getattr(info, "private", None)
        out["gated"] = bool(getattr(info, "gated", False))
    return out
# -----------------------------
# Chords helpers
# -----------------------------
def _default_chord_map():
chords = [
"N","C","Cm","C7","Cmaj7","Cm7","D","Dm","D7","Dmaj7","Dm7",
"E","Em","E7","Emaj7","Em7","F","Fm","F7","Fmaj7","Fm7",
]
return {ch:i for i,ch in enumerate(chords)}
def _validate_chord(ch: str, mapping: dict) -> str:
return ch if ch in mapping else "UNK"
def chords_string_to_list(chords: str):
    """Parse a "(chord, time), (chord, time)" string into [(symbol, time), ...].

    Unknown chord symbols become 'UNK'; empty or malformed input yields [].
    """
    if not chords or chords.strip() == "":
        return []
    try:
        # Strip brackets/spaces so the pair regex sees a compact "(C,0.0),(D,2.0)".
        compact = chords.replace("[", "").replace("]", "").replace(" ", "")
        mapping = _default_chord_map()
        parsed = []
        for symbol, timestamp in re.findall(r"\(([^,]+),([^)]+)\)", compact):
            parsed.append((_validate_chord(symbol.strip(), mapping), float(timestamp.strip())))
        return parsed
    except Exception:
        return []
# -----------------------------
# Audio decoding (WAV stdlib)
# -----------------------------
def _read_wav_bytes(raw: Optional[bytes]) -> Tuple[int, Optional[torch.Tensor]]:
    """Decode WAV bytes into (sample_rate, loudness-normalized drum tensor).

    Supports 8-bit unsigned, 16-bit signed, and 32-bit float PCM. Returns
    (32000, None) for empty input, unsupported sample widths, or any decode error.
    """
    if not raw:
        return 32000, None
    try:
        with wave.open(io.BytesIO(raw), "rb") as wf:
            sr = wf.getframerate()
            n_channels = wf.getnchannels()
            width = wf.getsampwidth()
            payload = wf.readframes(wf.getnframes())
        if width == 2:
            samples = np.frombuffer(payload, dtype=np.int16).astype(np.float32) / 32768.0
        elif width == 1:
            # 8-bit WAV is unsigned; re-center around zero before scaling.
            samples = (np.frombuffer(payload, dtype=np.uint8).astype(np.float32) - 128) / 128.0
        elif width == 4:
            samples = np.frombuffer(payload, dtype=np.float32)
        else:
            return 32000, None
        if n_channels > 1:
            samples = samples.reshape(-1, n_channels).T
        else:
            samples = samples[None, :]
        drums = f32_pcm(torch.from_numpy(samples)).t()
        if drums.dim() == 1:
            drums = drums[None]
        drums = normalize_audio(drums, "loudness", loudness_headroom_db=16, sample_rate=sr)
        return sr, drums
    except Exception as e:
        print(f"[audio] WAV decode failed: {e}")
        return 32000, None
def _read_uploadfile_to_bytes(file: Optional[UploadFile]) -> Optional[bytes]:
    """Return the full payload of an uploaded file, or None when absent/unreadable."""
    try:
        return file.file.read() if file is not None else None
    except Exception:
        return None
def _read_b64_to_bytes(b64str: Optional[str]) -> Optional[bytes]:
if not b64str:
return None
try:
s = b64str.strip()
if s.startswith("data:"):
s = s.split(",", 1)[1]
return base64.b64decode(s)
except Exception:
return None
# -----------------------------
# Model
# -----------------------------
# Lazily-populated JASCO model singleton; see load_model().
MODEL = None

def _ensure_mapping_file() -> Path:
    """Create (once) and return the path of the pickled chord->index mapping."""
    import pickle
    mapping_file = CACHE_DIR / "chord_to_index_mapping.pkl"
    if mapping_file.exists():
        return mapping_file
    with open(mapping_file, "wb") as f:
        pickle.dump(_default_chord_map(), f)
    return mapping_file
def load_model(name: str):
    """
    Load JASCO checkpoint `name` into the module-level MODEL singleton.

    Ensures HF auth for gated repos and preflights repo access so auth
    problems surface as a clear HTTP 401 instead of an opaque load error.
    Repeated calls with the same name return the cached model.

    Raises:
        HTTPException(401): repo is gated/private or token lacks access.
        HTTPException(500): any other model-load failure.
    """
    global MODEL
    if MODEL is not None and getattr(MODEL, "name", None) == name:
        return MODEL
    # Ensure HF login
    ensure_hf_login()
    # Preflight access for clearer errors
    try:
        api = HfApi()
        token = _get_hf_token()
        _ = api.model_info(name, token=token) if token else api.model_info(name)
    except HfHubHTTPError as e:
        msg = (
            f"Cannot access model '{name}'. This repo may be gated or private.\n"
            f"- Ensure your token has access and terms are accepted.\n"
            f"- Provide token via HUGGINGFACE_HUB_TOKEN (or HF_TOKEN/HFTOKEN/HUGGINGFACEHUB_API_TOKEN).\n"
            f"Hugging Face error: {e}"
        )
        raise HTTPException(status_code=401, detail=msg)
    # Per-model cache directory keeps different checkpoints separated on disk.
    cache_path = CACHE_DIR / name.replace("/", "_")
    cache_path.mkdir(parents=True, exist_ok=True)
    os.environ["AUDIOCRAFT_CACHE_DIR"] = str(cache_path)
    os.environ["TRANSFORMERS_CACHE"] = str(cache_path / "transformers")
    mapping_file = _ensure_mapping_file()
    try:
        model = JASCO.get_pretrained(name, device="cpu", chords_mapping_path=str(mapping_file))
        model.name = name
        import pickle
        # Some audiocraft builds don't attach the chord mapping; restore it
        # from the pickle so chord conditioning still works.
        if not hasattr(model, "chord_to_index"):
            with open(mapping_file, "rb") as f:
                model.chord_to_index = pickle.load(f)
    except HfHubHTTPError as e:
        raise HTTPException(status_code=401, detail=f"Model load failed due to HF auth/access: {e}")
    except Exception as e:
        # Fixed debug-residue message ("Model load failed11").
        raise HTTPException(status_code=500, detail=f"Model load failed: {e}")
    MODEL = model
    return MODEL
def set_gen_params(model, **kwargs):
    """Forward generation parameters to the model, filtering out unsupported keys.

    When the model exposes get_generation_params(), its keys act as a whitelist
    and unknown kwargs are dropped (and logged); otherwise everything passes through.
    """
    known = None
    if hasattr(model, "get_generation_params"):
        try:
            known = set(model.get_generation_params().keys())
        except Exception:
            pass
    accepted = {k: v for k, v in kwargs.items() if known is None or k in known}
    rejected = [k for k in kwargs if known is not None and k not in known]
    print(f"[gen] request={kwargs}")
    if known is not None:
        print(f"[gen] applied={accepted} unknown={rejected}")
    model.set_generation_params(**accepted)
def _tensor_fp(t: Optional[torch.Tensor]) -> str:
if t is None:
return "NONE"
try:
x = t.detach().cpu().contiguous().float()
return hashlib.sha1(x.numpy().tobytes()).hexdigest()[:10]
except Exception:
return "ERR"
# -----------------------------
# Endpoints
# -----------------------------
@app.get("/health")
def health_check():
    """Liveness probe: reports model-loaded state and the active cache directory."""
    return dict(
        status="healthy",
        model_loaded=MODEL is not None,
        cache_dir=str(CACHE_DIR),
    )
# Body fields checked (in priority order) for the text prompt by /predict.
_TEXT_KEYS = ["text", "prompt", "description", "query", "message", "input", "content"]
class PredictRequest(BaseModel):
    # NOTE(review): this schema is not referenced by the /predict handler below,
    # which parses JSON/form bodies by hand — kept for documentation/clients.
    model: str = Field(default="facebook/jasco-chords-drums-melody-400M")
    text: str = ""
    chords_sym: str = ""
    n_samples: int = Field(default=2)
    seed: Optional[int] = Field(default=None)
    cfg_coef_all: float = Field(default=1.25)
    cfg_coef_txt: float = 2.5
    ode_rtol: float = 1e-4
    ode_atol: float = 1e-4
    ode_solver: str = "euler"
    ode_steps: int = 10
    drums_b64: Optional[str] = None
    # NOTE(review): an UploadFile field inside a pydantic model is unusual and
    # may not validate — confirm whether any client actually sends this.
    drums_upload: Optional[UploadFile] = None
class PredictResponse(BaseModel):
    # NOTE(review): unused by the current handlers (predict returns a file
    # response, not JSON) — kept for API documentation.
    status: str = "success"
    message: Optional[str] = None
    data: Optional[dict] = None
def tensor_to_wav_scipy(tensor, sample_rate, filename):
    """Write a float tensor (expected in [-1, 1]) to *filename* as 16-bit PCM WAV."""
    samples = tensor.detach().cpu().numpy()
    # Clip to the valid range, then scale into int16 as required by PCM WAV.
    pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    wavfile.write(filename, sample_rate, pcm)
class FileCleaner:
    """Tracks generated temp files and deletes any older than *file_lifetime* seconds.

    Cleanup runs opportunistically on every add(). Entries are kept oldest-first,
    so the scan stops at the first still-fresh file.
    """

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        self.files = []  # (added_at_epoch_seconds, Path), oldest first

    def add(self, path: Union[str, Path]):
        """Register *path* for delayed deletion, purging expired entries first."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        # Pop expired entries from the front; stop at the first fresh one.
        now = time.time()
        while self.files:
            added_at, path = self.files[0]
            if now - added_at <= self.file_lifetime:
                break
            if path.exists():
                path.unlink()
            self.files.pop(0)

file_cleaner = FileCleaner()
@app.post("/predict")
async def predict(request: Request):
    """
    Generate music with JASCO and return a single WAV file (jasco_output.wav).

    Accepts:
      - multipart/form-data (fields + optional drums_file upload)
      - application/json (fields + optional drums_b64)

    Raises HTTPException 400 for malformed input, 401/500 from model loading,
    and 500 when generation itself fails.
    """
    # Default to JSON when the client omits Content-Type.
    ct = (request.headers.get("content-type") or "application/json").lower()
    params = {
        "model": "facebook/jasco-chords-drums-melody-400M",
        "text": "",
        "chords_sym": "",
        "n_samples": 1,
        "seed": None,
        "cfg_coef_all": 1.25,
        "cfg_coef_txt": 2.5,
        "ode_rtol": 1e-4,
        "ode_atol": 1e-4,
        "ode_solver": "euler",
        "ode_steps": 10,
        "drums_b64": None
    }
    drums_upload: Optional[UploadFile] = None
    try:
        # startswith() tolerates parameters such as "; charset=utf-8"
        # (exact equality previously misrouted such requests to the form branch).
        if ct.startswith("application/json"):
            data = await request.json()
            if not isinstance(data, dict):
                raise HTTPException(status_code=400, detail="JSON body must be an object")
            # First matching alias wins as the text prompt.
            for k in _TEXT_KEYS:
                if k in data and data[k]:
                    params["text"] = data[k]; break
            params["model"] = data.get("model", params["model"])
            params["chords_sym"] = data.get("chords_sym", params["chords_sym"])
            params["n_samples"] = int(data.get("n_samples", params["n_samples"]))
            params["seed"] = data.get("seed", None)
            params["cfg_coef_all"] = float(data.get("cfg_coef_all", params["cfg_coef_all"]))
            params["cfg_coef_txt"] = float(data.get("cfg_coef_txt", params["cfg_coef_txt"]))
            params["ode_rtol"] = float(data.get("ode_rtol", params["ode_rtol"]))
            params["ode_atol"] = float(data.get("ode_atol", params["ode_atol"]))
            params["ode_solver"] = str(data.get("ode_solver", params["ode_solver"])).lower()
            params["ode_steps"] = int(data.get("ode_steps", params["ode_steps"]))
            params["drums_b64"] = data.get("drums_b64", None)
            raw_drums = _read_b64_to_bytes(params["drums_b64"])
        else:
            form = await request.form()
            fd = {k: (form.get(k)) for k in form.keys() if form.get(k)}
            for k in _TEXT_KEYS:
                if k in fd and fd[k]:
                    params["text"] = fd[k]; break
            params["model"] = fd.get("model", params["model"])
            params["chords_sym"] = fd.get("chords_sym", params["chords_sym"])
            params["n_samples"] = int(fd.get("n_samples", params["n_samples"]))
            params["seed"] = fd.get("seed", None)
            params["cfg_coef_all"] = float(fd.get("cfg_coef_all", params["cfg_coef_all"]))
            params["cfg_coef_txt"] = float(fd.get("cfg_coef_txt", params["cfg_coef_txt"]))
            params["ode_rtol"] = float(fd.get("ode_rtol", params["ode_rtol"]))
            params["ode_atol"] = float(fd.get("ode_atol", params["ode_atol"]))
            params["ode_solver"] = str(fd.get("ode_solver", params["ode_solver"])).lower()
            params["ode_steps"] = int(fd.get("ode_steps", params["ode_steps"]))
            params["drums_b64"] = fd.get("drums_b64", None)
            drums_upload = form.get("drums_file")
            # Uploaded file takes precedence over an inline base64 payload.
            raw_drums = _read_uploadfile_to_bytes(drums_upload) if drums_upload else _read_b64_to_bytes(params["drums_b64"])
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad request: {e}")
    print(json.dumps({
        "ct": ct,
        "text_len": len(params["text"] or ""),
        "text_preview": (params["text"] or "")[:120],
        "model": params["model"],
        "n_samples": params["n_samples"],
        "has_drums_bytes": raw_drums is not None,
    }))
    model = load_model(params["model"])  # may raise HTTPException(401/500)
    drums_sr, drums_tensor = _read_wav_bytes(raw_drums)
    print(f"[predict] drums_present={drums_tensor is not None} sr={drums_sr} drums_fp={_tensor_fp(drums_tensor)}")
    # Seed every RNG so a given seed reproduces the same output.
    base_seed = int(params["seed"]) if params["seed"] is not None else (int(time.time() * 1000) & 0xFFFFFFFF)
    random.seed(base_seed); np.random.seed(base_seed); torch.manual_seed(base_seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(base_seed)
    set_gen_params(
        model,
        cfg_coef_all=float(params["cfg_coef_all"]),
        cfg_coef_txt=float(params["cfg_coef_txt"]),
        ode_rtol=float(params["ode_rtol"]),
        ode_atol=float(params["ode_atol"]),
        euler=(params["ode_solver"] == "euler"),
        euler_steps=int(params["ode_steps"])
    )
    texts = [params["text"]] * max(1, int(params["n_samples"]))
    chords_list = chords_string_to_list(params["chords_sym"])
    print(f"[predictdebug] chords_list={chords_list} texts={texts}")
    try:
        outputs = model.generate_music(
            descriptions=texts,
            chords=chords_list,
            drums_wav=drums_tensor,
            melody_salience_matrix=None,
            drums_sample_rate=drums_sr,
            progress=False
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
    # Move generated audio to CPU float32 before writing.
    outputs = outputs.detach().cpu().float()
    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as f:
        tmp_path = f.name
    audio_write(
        tmp_path,
        outputs[0],  # NOTE(review): only the first of n_samples is returned
        model.sample_rate,  # use the locally-loaded model consistently (was MODEL.sample_rate)
        strategy="loudness",
        loudness_headroom_db=16,
        loudness_compressor=True,
        add_suffix=False,
    )
    # Register the temp file so it is deleted after its lifetime instead of
    # accumulating on disk (file_cleaner was previously never used).
    file_cleaner.add(tmp_path)
    return FileResponse(
        path=tmp_path,
        media_type="audio/wav",
        filename="jasco_output.wav"
    )
if __name__ == "__main__":
    # Log in before serving so gated-model auth problems surface at startup.
    ensure_hf_login()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
# outputs = [1,2]
# outputs[1] = [name, wav]
# wav =[0.39203242, ]