Spaces:
Runtime error
Runtime error
File size: 6,620 Bytes
85283fb e698a98 7661963 9139806 85283fb ec73267 f87a760 9139806 85283fb 9139806 85283fb f87a760 bb7173b 9139806 85283fb f87a760 85283fb 02511ad f87a760 85283fb f87a760 85283fb f87a760 914290d 9139806 914290d 9139806 59d6306 9139806 914290d f87a760 84b867e 903d68e 85283fb 914290d 85283fb 914290d 85283fb 914290d 85283fb 9139806 914290d 9139806 85283fb 9139806 85283fb 914290d f87a760 85283fb e698a98 85283fb f87a760 85283fb f87a760 e698a98 85283fb 9139806 e698a98 85283fb f87a760 85283fb 9139806 f87a760 85283fb 9139806 85283fb 9139806 85283fb f87a760 85283fb 9139806 f87a760 85283fb e698a98 bb7173b f87a760 9139806 e698a98 85283fb e698a98 f87a760 9139806 85283fb f87a760 e698a98 f87a760 9139806 6899c06 f87a760 6899c06 e698a98 9139806 85283fb 9139806 e698a98 85283fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
import os
import random
import torch
import spaces
import numpy as np
import gradio as gr
from chatterbox.tts_turbo import ChatterboxTurboTTS
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
# =========================
# CONFIG
# =========================
# Prefer GPU when one is visible to this process; fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face repo that hosts the Hindi checkpoint (full T3 or LoRA adapter).
MODEL_REPO = "rahul7star/vaani-lora-lata"
# Checkpoint filename inside MODEL_REPO.
LORA_FILE = "adapter_model.safetensors"
print(f"🚀 Running on device: {DEVICE}")
# Lazily-initialized singleton TTS model; populated by get_or_load_model().
MODEL = None
# =========================
# MANUAL LORA MERGE
# =========================
def merge_lora_into_t3(t3_model, lora_state):
    """Fold LoRA low-rank updates into the base T3 weights in place (PEFT-free).

    For every targeted layer the update is ``W += (B @ A) * (alpha / r)``,
    the standard LoRA merge formula, where ``r`` is the adapter rank
    (rows of A) and ``alpha`` defaults to ``r`` (scale 1.0) when the
    checkpoint does not store ``lora_alpha``.

    Args:
        t3_model: Base model whose ``<path>.weight`` tensors are modified
            in place.
        lora_state: Flat state dict with keys like
            ``<path>.lora_A.weight`` / ``<path>.lora_B.weight`` and
            optionally ``<path>.lora_alpha``.

    Returns:
        The same ``t3_model`` instance, for call-chaining convenience.
    """
    print("🔧 Merging LoRA weights into T3...")
    # Group A/B/alpha tensors by the base-layer path they target.
    loras = {}
    for key, tensor in lora_state.items():
        if ".lora_A." in key or ".lora_B." in key:
            prefix = key.split(".lora_")[0]
            part = key.split(".lora_")[1].split(".")[0]  # "A" or "B"
            loras.setdefault(prefix, {})[part] = tensor
        elif key.endswith(".lora_alpha"):
            prefix = key.replace(".lora_alpha", "")
            loras.setdefault(prefix, {})["alpha"] = tensor.item()
    for layer_path, parts in loras.items():
        if "A" not in parts or "B" not in parts:
            # Incomplete A/B pair: skip rather than corrupt the weight.
            continue
        # Walk the attribute path; getattr also resolves numeric
        # nn.Sequential/ModuleList children such as "layers.0".
        module = t3_model
        for attr in layer_path.split("."):
            module = getattr(module, attr)
        W = module.weight.data
        # Match the base weight's device AND dtype — an in-place += of a
        # higher-precision delta into e.g. an fp16 weight raises in torch.
        A = parts["A"].to(device=W.device, dtype=W.dtype)
        B = parts["B"].to(device=W.device, dtype=W.dtype)
        r = A.shape[0]  # LoRA rank
        alpha = parts.get("alpha", r)  # default alpha == r -> scale 1.0
        W += (B @ A) * (alpha / r)
    print("✅ LoRA merged successfully")
    return t3_model
# =========================
# MODEL LOADING
# =========================
def get_or_load_model():
    """Load ChatterboxTurboTTS once and apply the Hindi checkpoint.

    The checkpoint downloaded from MODEL_REPO may be either a full T3
    state dict (keys prefixed "tfmr.") or a LoRA adapter; full weights
    are loaded directly with load_state_dict, adapters are folded into
    the base weights via merge_lora_into_t3().

    Returns:
        The cached ChatterboxTurboTTS instance, moved to DEVICE.

    Raises:
        Re-raises any exception from download/loading after printing it.
    """
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        try:
            MODEL = ChatterboxTurboTTS.from_pretrained(DEVICE)
            checkpoint_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=LORA_FILE,
                # .get() falls back to the cached CLI login (token=None)
                # instead of raising KeyError when HF_TOKEN is not set.
                token=os.environ.get("HF_TOKEN"),
            )
            # Load on CPU first; tensors are moved when merged/loaded.
            state = load_file(checkpoint_path, device="cpu")
            if any(k.startswith("tfmr.") for k in state.keys()):
                # CASE 1: full T3 checkpoint — replace the whole state dict.
                print("Detected FULL T3 checkpoint – loading directly")
                MODEL.t3.load_state_dict(state, strict=True)
            else:
                # CASE 2: real LoRA adapter — manual merge, no PEFT.
                print("Detected LoRA adapter – merging weights")
                merge_lora_into_t3(MODEL.t3, state)
            MODEL.to(DEVICE)
            print(f"Model loaded successfully on {DEVICE}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    return MODEL
# Load once at startup so the first request does not pay the cold-start cost.
get_or_load_model()
# =========================
# SEED
# =========================
def set_seed(seed):
    """Seed every RNG source (torch, random, numpy, CUDA) for reproducibility."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)
    # The CUDA generators are separate from the CPU one; seed them too.
    if DEVICE == "cuda":
        torch.cuda.manual_seed_all(seed)
# =========================
# TTS
# =========================
def generate_tts_audio(
    text_input,
    audio_prompt_path_input,
    exaggeration_input,
    temperature_input,
    seed_num_input,
    cfgw_input,
):
    """Synthesize speech for the given text and return (sample_rate, waveform).

    Args:
        text_input: Text to speak; truncated to 3000 characters.
        audio_prompt_path_input: Optional path to a reference-voice clip.
        exaggeration_input: Expressiveness control passed to the model.
        temperature_input: Sampling temperature.
        seed_num_input: RNG seed; 0 means "random" (RNG state untouched).
        cfgw_input: Classifier-free-guidance weight.

    Returns:
        Tuple of (sample rate, 1-D numpy waveform) for gr.Audio.
    """
    model = get_or_load_model()
    if seed_num_input != 0:
        set_seed(int(seed_num_input))
    wav = model.generate(
        text_input[:3000],
        audio_prompt_path=audio_prompt_path_input,
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    # Detach and move to CPU before .numpy(): calling .numpy() directly on a
    # CUDA (or grad-tracking) tensor raises — the model runs on DEVICE.
    return model.sr, wav.squeeze(0).detach().cpu().numpy()
# =========================
# UI
# =========================
with gr.Blocks() as demo:
    gr.Markdown("## 🇮🇳 Hindi TTS – Chatterbox (LoRA merged, no PEFT)")
    with gr.Row():
        # Left column: all synthesis inputs.
        with gr.Column():
            prompt_box = gr.Textbox(
                value="राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी",
                label="Text",
                max_lines=5,
            )
            reference_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio (optional)",
            )
            exaggeration_slider = gr.Slider(0.25, 2.0, step=0.05, value=0.5)
            cfg_slider = gr.Slider(0.2, 1.0, step=0.05, value=0.3)
            # Less commonly tweaked knobs live in a collapsed accordion.
            with gr.Accordion("Advanced", open=False):
                seed_box = gr.Number(value=0, label="Seed (0=random)")
                temperature_slider = gr.Slider(0.05, 5.0, step=0.05, value=0.6)
            generate_button = gr.Button("Generate", variant="primary")
        # Right column: the synthesized audio player.
        with gr.Column():
            synthesized_audio = gr.Audio(label="Output")
    generate_button.click(
        fn=generate_tts_audio,
        inputs=[
            prompt_box,
            reference_audio,
            exaggeration_slider,
            temperature_slider,
            seed_box,
            cfg_slider,
        ],
        outputs=[synthesized_audio],
    )
demo.launch()
|