import os
import random
import torch
import spaces
import numpy as np
import gradio as gr
from chatterbox.tts_turbo import ChatterboxTurboTTS
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
# =========================
# CONFIG
# =========================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_REPO = "rahul7star/vaani-lora-lata"
LORA_FILE = "adapter_model.safetensors"
print(f"🚀 Running on device: {DEVICE}")
MODEL = None
# =========================
# MANUAL LORA MERGE
# =========================
def merge_lora_into_t3(t3_model, lora_state):
    """
    Manually merges LoRA weights into the base T3 model (no PEFT dependency).
    """
    print("🔧 Merging Hindi LoRA into T3 weights...")
    # Group LoRA tensors by the layer they target
    loras = {}
    for k, v in lora_state.items():
        if ".lora_A." in k or ".lora_B." in k:
            prefix = k.split(".lora_")[0]
            loras.setdefault(prefix, {})[k.split(".lora_")[1].split(".")[0]] = v
        elif k.endswith(".lora_alpha"):
            prefix = k.replace(".lora_alpha", "")
            loras.setdefault(prefix, {})["alpha"] = v.item()
    for layer_name, parts in loras.items():
        if "A" not in parts or "B" not in parts:
            continue
        # Locate the base weight by walking the attribute path
        module = t3_model
        for attr in layer_name.split("."):
            module = getattr(module, attr)
        W = module.weight.data
        # Cast to the base weight's device AND dtype so the in-place add cannot fail
        A = parts["A"].to(device=W.device, dtype=W.dtype)
        B = parts["B"].to(device=W.device, dtype=W.dtype)
        r = A.shape[0]
        alpha = parts.get("alpha", r)
        scale = alpha / r
        # Merge in place: W <- W + (alpha / r) * (B @ A)
        W += (B @ A) * scale
    print("✅ LoRA merge complete")
    return t3_model
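# Shape sketch (illustrative only, not executed): for a Linear layer with weight
# W of shape (out_features, in_features), the adapter stores A with shape
# (r, in_features) and B with shape (out_features, r). B @ A therefore matches
# W exactly, and the merged weight is W + (alpha / r) * (B @ A).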
# =========================
# MODEL LOADING
# =========================
def get_or_load_model():
    """Loads ChatterboxTurboTTS once, then merges the Hindi LoRA or loads a full T3 checkpoint."""
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        try:
            MODEL = ChatterboxTurboTTS.from_pretrained(DEVICE)
            checkpoint_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=LORA_FILE,
                token=os.environ["HF_TOKEN"],  # must be set (e.g., as a Space secret) for private repos
            )
            state = load_file(checkpoint_path, device="cpu")
            # --------------------------------------------------
            # CASE 1: FULL T3 CHECKPOINT (your repo currently)
            # --------------------------------------------------
            if any(k.startswith("tfmr.") for k in state.keys()):
                print("Detected FULL T3 checkpoint – loading directly")
                MODEL.t3.load_state_dict(state, strict=True)
            # --------------------------------------------------
            # CASE 2: REAL LoRA ADAPTER → MANUAL MERGE
            # --------------------------------------------------
            else:
                print("Detected LoRA adapter – merging weights")
                merge_lora_into_t3(MODEL.t3, state)
            MODEL.to(DEVICE)
            print(f"Model loaded successfully on {DEVICE}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    return MODEL


# Load on startup
get_or_load_model()
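# Calling get_or_load_model() at import time warms the global MODEL cache,
# so the first Gradio request does not pay the download/merge cost.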
# =========================
# SEED
# =========================
def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed_all(seed)
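# Note: this seeds the Python, NumPy, and CUDA RNGs; fully deterministic sampling
# would additionally require flags such as torch.backends.cudnn.deterministic,
# which this demo does not set.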
# =========================
# TTS
# =========================
@spaces.GPU  # request a ZeroGPU slot for generation (the `spaces` import is otherwise unused)
def generate_tts_audio(
    text_input,
    audio_prompt_path_input,
    exaggeration_input,
    temperature_input,
    seed_num_input,
    cfgw_input,
):
    """Generates speech for up to 3000 characters of text, optionally cloning a reference voice."""
    model = get_or_load_model()
    if seed_num_input != 0:
        set_seed(int(seed_num_input))
    wav = model.generate(
        text_input[:3000],
        audio_prompt_path=audio_prompt_path_input,
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    return model.sr, wav.squeeze(0).numpy()
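# Example call outside Gradio (hypothetical argument values):
#   sr, wav = generate_tts_audio("नमस्ते दुनिया", None, 0.5, 0.6, 0, 0.3)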
# =========================
# UI
# =========================
with gr.Blocks() as demo:
    gr.Markdown("## 🇮🇳 Hindi TTS – Chatterbox (LoRA merged, no PEFT)")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी",
                label="Text",
                max_lines=5,
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio (optional)",
            )
            exaggeration = gr.Slider(0.25, 2.0, step=0.05, value=0.5, label="Exaggeration")
            cfg_weight = gr.Slider(0.2, 1.0, step=0.05, value=0.3, label="CFG weight")
            with gr.Accordion("Advanced", open=False):
                seed_num = gr.Number(value=0, label="Seed (0=random)")
                temp = gr.Slider(0.05, 5.0, step=0.05, value=0.6, label="Temperature")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            audio_output = gr.Audio(label="Output")
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[text, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output],
    )

demo.launch()