#!/usr/bin/env python3
""" 
FastAPI JASCO server (Postman-ready) with robust Hugging Face auth and
safe handling of HF fast-transfer (hf_transfer).

- If HF_HUB_ENABLE_HF_TRANSFER=1 but 'hf_transfer' isn't installed, we fall back
  to standard downloads and log a warning instead of failing.
- POST /predict supports multipart (drums_file) and JSON (drums_b64).
- GET /hf-status shows auth and model access.

Run:
  export HUGGINGFACE_HUB_TOKEN=hf_xxx   # or HF_TOKEN/HFTOKEN/HUGGINGFACEHUB_API_TOKEN
  uvicorn main:app --host 0.0.0.0 --port 7860
"""

# -----------------------------
# Environment (HF Spaces-friendly)
# -----------------------------
import os
from pathlib import Path
from pydantic import BaseModel, Field
import numpy as np
from scipy.io import wavfile
from fastapi.responses import FileResponse


def _pick_cache_dir() -> Path:
    for c in [Path("/data/cache"), Path("/tmp/audiocraft_cache"), Path.cwd() / "cache"]:
        try:
            c.mkdir(parents=True, exist_ok=True)
            (c / ".w").touch(); (c / ".w").unlink()
            return c
        except Exception:
            pass
    return Path.cwd()

CACHE_DIR = _pick_cache_dir()
for sub in ["models", "huggingface", "transformers", "drum_cache", "cache"]:
    (CACHE_DIR / sub).mkdir(parents=True, exist_ok=True)

os.environ["AUDIOCRAFT_CACHE_DIR"] = str(CACHE_DIR)
os.environ["XDG_CACHE_HOME"] = str(CACHE_DIR)
os.environ["TORCH_HOME"] = str(CACHE_DIR / "cache")
os.environ["HF_HOME"] = str(CACHE_DIR / "huggingface")
os.environ["TRANSFORMERS_CACHE"] = str(CACHE_DIR / "transformers")
os.environ["NUMBA_DISABLE_JIT"] = "1"
# Do NOT force-enable fast transfer; handle it dynamically below.
# os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # (removed)

# -----------------------------
# Imports
# -----------------------------
import io
import re
import json
import wave
import time
import base64
import random
import hashlib
import zipfile
from tempfile import NamedTemporaryFile
from typing import Optional, List, Tuple, Union

import torch
from fastapi import FastAPI, Request, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse

# Hugging Face auth helpers
from huggingface_hub import login as hf_login, HfApi, HfFolder
from huggingface_hub.utils import HfHubHTTPError

# JASCO / AudioCraft
from audiocraft.data.audio_utils import f32_pcm, normalize_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import JASCO

# -----------------------------
# App boilerplate
# -----------------------------
app = FastAPI(title="JASCO /predict (HF auth)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
)


# -----------------------------
# Hugging Face auth utilities
# -----------------------------
def _get_hf_token() -> Optional[str]:
    for k in ("HUGGINGFACE_HUB_TOKEN", "HUGGINGFACEHUB_API_TOKEN", "HF_TOKEN", "HFTOKEN"):
        v = os.getenv(k)
        if v:
            return v.strip()
    return None

HF_TOKEN = _get_hf_token()

def ensure_hf_login():
    """Login once; provide clear logging. No-op if no token (but gated models will fail)."""
    global HF_TOKEN
    if not HF_TOKEN:
        print("[HF] No token found in env (HUGGINGFACE_HUB_TOKEN / HUGGINGFACEHUB_API_TOKEN / HF_TOKEN / HFTOKEN).")
        return
    try:
        hf_login(token=HF_TOKEN, add_to_git_credential=False)
        HfFolder.save_token(HF_TOKEN)  # persist under HF_HOME
        who = HfApi().whoami(token=HF_TOKEN)
        print(f"[HF] Logged in as: {who.get('name') or who.get('email') or who.get('username')}")
    except Exception as e:
        print(f"[HF] Login failed: {e}")

@app.get("/hf-status")
def hf_status():
    token = _get_hf_token()
    out = {"token_present": bool(token)}
    try:
        if token:
            who = HfApi().whoami(token=token)
            out["whoami"] = who
        else:
            out["whoami"] = None
    except Exception as e:
        out["whoami_error"] = str(e)
    model_id = "facebook/jasco-chords-drums-melody-400M"
    try:
        api = HfApi()
        info = api.model_info(model_id, token=token) if token else api.model_info(model_id)
        out["model_access"] = True
        out["model_private"] = getattr(info, "private", None)
        out["gated"] = bool(getattr(info, "gated", False))
    except Exception as e:
        out["model_access"] = False
        out["error"] = str(e)
    return out
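
# Quick manual check (URL assumed from the run instructions in the module docstring):
#   curl http://localhost:7860/hf-status
# This reports whether a token was found, who it belongs to, and whether the JASCO
# repo is accessible and/or gated for that token.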

# -----------------------------
# Chords helpers
# -----------------------------
def _default_chord_map():
    chords = [
        "N","C","Cm","C7","Cmaj7","Cm7","D","Dm","D7","Dmaj7","Dm7",
        "E","Em","E7","Emaj7","Em7","F","Fm","F7","Fmaj7","Fm7",
    ]
    return {ch:i for i,ch in enumerate(chords)}

def _validate_chord(ch: str, mapping: dict) -> str:
    return ch if ch in mapping else "UNK"

def chords_string_to_list(chords: str):
    if not chords or chords.strip() == "":
        return []
    try:
        clean = chords.replace("[", "").replace("]", "").replace(" ", "")
        pairs = re.findall(r"\(([^,]+),([^)]+)\)", clean)
        mapping = _default_chord_map()
        return [(_validate_chord(ch.strip(), mapping), float(t.strip())) for ch, t in pairs]
    except Exception:
        return []
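
# For reference, with the limited default mapping above,
#   chords_string_to_list("(C, 0.0), (F, 4.0), (G, 8.0)")
# returns [("C", 0.0), ("F", 4.0), ("UNK", 8.0)]; "G" is not in _default_chord_map,
# so it falls back to "UNK".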

# -----------------------------
# Audio decoding (WAV stdlib)
# -----------------------------
def _read_wav_bytes(raw: Optional[bytes]) -> Tuple[int, Optional[torch.Tensor]]:
    if not raw:
        return 32000, None
    try:
        with wave.open(io.BytesIO(raw), "rb") as wf:
            sr = wf.getframerate()
            ch = wf.getnchannels()
            sw = wf.getsampwidth()
            frames = wf.getnframes()
            buf = wf.readframes(frames)

        if   sw == 2: data = np.frombuffer(buf, dtype=np.int16).astype(np.float32) / 32768.0
        elif sw == 1: data = (np.frombuffer(buf, dtype=np.uint8).astype(np.float32) - 128) / 128.0
        elif sw == 4: data = np.frombuffer(buf, dtype=np.float32)
        else:         return 32000, None

        if ch > 1: data = data.reshape(-1, ch).T
        else:      data = data[None, :]

        # audiocraft expects channels-first [C, T]; `data` is already (C, T) here,
        # so no transpose is needed before normalization
        drums = f32_pcm(torch.from_numpy(data))
        if drums.dim() == 1:
            drums = drums[None]
        drums = normalize_audio(drums, "loudness", loudness_headroom_db=16, sample_rate=sr)
        return sr, drums
    except Exception as e:
        print(f"[audio] WAV decode failed: {e}")
        return 32000, None

def _read_uploadfile_to_bytes(file: Optional[UploadFile]) -> Optional[bytes]:
    if file is None:
        return None
    try:
        return file.file.read()
    except Exception:
        return None

def _read_b64_to_bytes(b64str: Optional[str]) -> Optional[bytes]:
    if not b64str:
        return None
    try:
        s = b64str.strip()
        if s.startswith("data:"):
            s = s.split(",", 1)[1]
        return base64.b64decode(s)
    except Exception:
        return None
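
# Client-side sketch for producing drums_b64 (the file name is a placeholder):
#
#   import base64
#   with open("drums.wav", "rb") as fh:
#       drums_b64 = base64.b64encode(fh.read()).decode("ascii")
#
# A data-URI prefix such as "data:audio/wav;base64," is also accepted and stripped.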

# -----------------------------
# Model
# -----------------------------
MODEL = None

def _ensure_mapping_file() -> Path:
    import pickle
    mapping_file = CACHE_DIR / "chord_to_index_mapping.pkl"
    if not mapping_file.exists():
        with open(mapping_file, "wb") as f:
            pickle.dump(_default_chord_map(), f)
    return mapping_file

def load_model(name: str):
    """
    Load JASCO, ensuring HF auth for gated repos.
    Falls back if hf_transfer is unavailable.
    """
    global MODEL
    if MODEL is not None and getattr(MODEL, "name", None) == name:
        return MODEL


    # Ensure HF login
    ensure_hf_login()

    # Preflight access for clearer errors
    try:
        api = HfApi()
        token = _get_hf_token()
        _ = api.model_info(name, token=token) if token else api.model_info(name)
    except HfHubHTTPError as e:
        msg = (
            f"Cannot access model '{name}'. This repo may be gated or private.\n"
            f"- Ensure your token has access and terms are accepted.\n"
            f"- Provide token via HUGGINGFACE_HUB_TOKEN (or HF_TOKEN/HFTOKEN/HUGGINGFACEHUB_API_TOKEN).\n"
            f"Hugging Face error: {e}"
        )
        raise HTTPException(status_code=401, detail=msg)

    cache_path = CACHE_DIR / name.replace("/", "_")
    cache_path.mkdir(parents=True, exist_ok=True)
    os.environ["AUDIOCRAFT_CACHE_DIR"] = str(cache_path)
    os.environ["TRANSFORMERS_CACHE"] = str(cache_path / "transformers")

    mapping_file = _ensure_mapping_file()

    try:
        model = JASCO.get_pretrained(name, device="cpu", chords_mapping_path=str(mapping_file))
        model.name = name
        import pickle
        if not hasattr(model, "chord_to_index"):
            with open(mapping_file, "rb") as f:
                model.chord_to_index = pickle.load(f)
    except HfHubHTTPError as e:
        raise HTTPException(status_code=401, detail=f"Model load failed due to HF auth/access: {e}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model load failed: {e}")

    MODEL = model
    return MODEL

def set_gen_params(model, **kwargs):
    valid = None
    if hasattr(model, "get_generation_params"):
        try:
            valid = set(model.get_generation_params().keys())
        except Exception:
            pass
    filtered, unknown = {}, []
    for k, v in kwargs.items():
        if valid is not None and k not in valid:
            unknown.append(k)
        else:
            filtered[k] = v
    print(f"[gen] request={kwargs}")
    if valid is not None:
        print(f"[gen] applied={filtered} unknown={unknown}")
    model.set_generation_params(**filtered)
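
# Example: set_gen_params(model, cfg_coef_all=1.25, euler=True, euler_steps=10).
# When model.get_generation_params() is available, keys it does not report are
# logged as unknown and dropped; otherwise all keyword arguments are passed through.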

def _tensor_fp(t: Optional[torch.Tensor]) -> str:
    if t is None:
        return "NONE"
    try:
        x = t.detach().cpu().contiguous().float()
        return hashlib.sha1(x.numpy().tobytes()).hexdigest()[:10]
    except Exception:
        return "ERR"

# -----------------------------
# Endpoints
# -----------------------------
@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "cache_dir": str(CACHE_DIR),
    }

_TEXT_KEYS = ["text", "prompt", "description", "query", "message", "input", "content"]

class PredictRequest(BaseModel):
    model: str = Field(default="facebook/jasco-chords-drums-melody-400M")   
    text: str = ""
    chords_sym: str = ""
    n_samples: int = Field(default=2)
    seed: Optional[int] = Field(default=None)
    cfg_coef_all: float = Field(default=1.25)
    cfg_coef_txt: float = 2.5
    ode_rtol: float = 1e-4
    ode_atol: float = 1e-4
    ode_solver: str = "euler"
    ode_steps: int = 10
    drums_b64: Optional[str] = None
    drums_upload: Optional[UploadFile] = None

class PredictResponse(BaseModel):
    status: str = "success"
    message: Optional[str] = None
    data: Optional[dict] = None

def tensor_to_wav_scipy(tensor, sample_rate, filename):
    # Convert to numpy and ensure correct format
    audio_data = tensor.detach().cpu().numpy()
    
    # Normalize to 16-bit range
    audio_data = np.clip(audio_data, -1.0, 1.0)
    audio_data = (audio_data * 32767).astype(np.int16)
    
    # Save as WAV
    wavfile.write(filename, sample_rate, audio_data)

class FileCleaner:
    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        self.files = []

    def add(self, path: Union[str, Path]):
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        now = time.time()
        for time_added, path in list(self.files):
            if now - time_added > self.file_lifetime:
                if path.exists():
                    path.unlink()
                self.files.pop(0)
            else:
                break

file_cleaner = FileCleaner()


@app.post("/predict")
async def predict(request: Request):
    """
    Returns a ZIP with jasco_1.wav, jasco_2.wav, ...
    Accepts:
      - multipart/form-data (fields + optional drums_file)
      - application/json (fields + optional drums_b64)
    """
    ct = (request.headers.get("content-type") or "application/json").lower()

    params = {
        "model": "facebook/jasco-chords-drums-melody-400M",
        "text": "",
        "chords_sym": "",
        "n_samples": 1,
        "seed": None,
        "cfg_coef_all": 1.25,
        "cfg_coef_txt": 2.5,
        "ode_rtol": 1e-4,
        "ode_atol": 1e-4,
        "ode_solver": "euler",
        "ode_steps": 10,
        "drums_b64": None
    } 
    drums_upload: Optional[UploadFile] = None

    try:
        if ct.startswith("application/json"):
            data = await request.json()
            if not isinstance(data, dict):
                raise HTTPException(status_code=400, detail="JSON body must be an object")
            for k in _TEXT_KEYS:
                if k in data and data[k]:
                    params["text"] = data[k]; break
            params["model"]        = data.get("model", params["model"])
            params["chords_sym"]   = data.get("chords_sym", params["chords_sym"])
            params["n_samples"]    = int(data.get("n_samples", params["n_samples"]))
            params["seed"]         = data.get("seed", None)
            params["cfg_coef_all"] = float(data.get("cfg_coef_all", params["cfg_coef_all"]))
            params["cfg_coef_txt"] = float(data.get("cfg_coef_txt", params["cfg_coef_txt"]))
            params["ode_rtol"]     = float(data.get("ode_rtol", params["ode_rtol"]))
            params["ode_atol"]     = float(data.get("ode_atol", params["ode_atol"]))
            params["ode_solver"]   = str(data.get("ode_solver", params["ode_solver"])).lower()
            params["ode_steps"]    = int(data.get("ode_steps", params["ode_steps"]))
            params["drums_b64"]    = data.get("drums_b64", None)
            raw_drums = _read_b64_to_bytes(params["drums_b64"])
        else:
            form = await request.form()
            fd = {k: (form.get(k)) for k in form.keys() if form.get(k)}
            for k in _TEXT_KEYS:
                if k in fd and fd[k]:
                    params["text"] = fd[k]; break
            params["model"]        = fd.get("model", params["model"])
            params["chords_sym"]   = fd.get("chords_sym", params["chords_sym"])
            params["n_samples"]    = int(fd.get("n_samples", params["n_samples"]))
            params["seed"]         = fd.get("seed", None)
            params["cfg_coef_all"] = float(fd.get("cfg_coef_all", params["cfg_coef_all"]))
            params["cfg_coef_txt"] = float(fd.get("cfg_coef_txt", params["cfg_coef_txt"]))
            params["ode_rtol"]     = float(fd.get("ode_rtol", params["ode_rtol"]))
            params["ode_atol"]     = float(fd.get("ode_atol", params["ode_atol"]))
            params["ode_solver"]   = str(fd.get("ode_solver", params["ode_solver"])).lower()
            params["ode_steps"]    = int(fd.get("ode_steps", params["ode_steps"]))
            params["drums_b64"]    = fd.get("drums_b64", None)
            drums_upload = form.get("drums_file")
            raw_drums = _read_uploadfile_to_bytes(drums_upload) if drums_upload else _read_b64_to_bytes(params["drums_b64"])
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad request: {e}")

    print(json.dumps({
        "ct": ct,
        "text_len": len(params["text"] or ""),
        "text_preview": (params["text"] or "")[:120],
        "model": params["model"],
        "n_samples": params["n_samples"],
        "has_drums_bytes": raw_drums is not None,
    }))

    model = load_model(params["model"])  # may raise HTTPException(401/500)

    drums_sr, drums_tensor = _read_wav_bytes(raw_drums)
    print(f"[predict] drums_present={drums_tensor is not None} sr={drums_sr} drums_fp={_tensor_fp(drums_tensor)}")

    base_seed = int(params["seed"]) if params["seed"] is not None else (int(time.time() * 1000) & 0xFFFFFFFF)
    random.seed(base_seed); np.random.seed(base_seed); torch.manual_seed(base_seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(base_seed)

    set_gen_params(
        model,
        cfg_coef_all=float(params["cfg_coef_all"]),
        cfg_coef_txt=float(params["cfg_coef_txt"]),
        ode_rtol=float(params["ode_rtol"]),
        ode_atol=float(params["ode_atol"]),
        euler=(params["ode_solver"] == "euler"),
        euler_steps=int(params["ode_steps"])
    )

    texts = [params["text"]] * max(1, int(params["n_samples"]))
    chords_list = chords_string_to_list(params["chords_sym"])
    print(f"[predictdebug] chords_list={chords_list}")
    print(f"[predictdebug] drums_tensor={drums_tensor}")
    print(f"[predictdebug] drums_sr={drums_sr}")
    print(f"[predictdebug] model={model}")
    print(f"[predictdebug] texts={texts}")

    try:
        outputs = model.generate_music(
            descriptions=texts,
            chords=chords_list,
            drums_wav=drums_tensor,
            melody_salience_matrix=None,
            drums_sample_rate=drums_sr,
            progress=False
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")

    # Alternative (currently unused): package every generated sample into a ZIP,
    # e.g. with zipfile.ZipFile(...) as zf:
    #     for i, wav in enumerate(outputs):
    #         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as f:
    #             tensor_to_wav_scipy(wav, model.sample_rate, f.name)
    #             zf.write(f.name, arcname=f"jasco_{i+1}.wav")
    print(f"[predictdebug] outputs={outputs}")  # Log the raw model outputs
    
    # Convert model outputs from GPU tensor to CPU float tensor for processing
    outputs = outputs.detach().cpu().float()  
    print(f"[predictdebug] outputs converted to cpu={outputs}")  # Log the converted outputs


    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as f:
        tmp_path = f.name

    audio_write(
        tmp_path,
        outputs[0],
        model.sample_rate,  # use the locally loaded model handle for consistency
        strategy="loudness",
        loudness_headroom_db=16,
        loudness_compressor=True,
        add_suffix=False,
    )
    file_cleaner.add(tmp_path)  # register the temp WAV so FileCleaner can reap it later
    return FileResponse(
        path=tmp_path,
        media_type="audio/wav",
        filename="jasco_output.wav"
    )


if __name__ == "__main__":
    ensure_hf_login()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

