Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| import torch | |
| import spaces | |
| import numpy as np | |
| import gradio as gr | |
| from chatterbox.tts_turbo import ChatterboxTurboTTS | |
| from safetensors.torch import load_file | |
| from huggingface_hub import hf_hub_download | |
| # ========================= | |
| # CONFIG | |
| # ========================= | |
# Hugging Face repo and file name holding the fine-tuned weights.
MODEL_REPO = "rahul7star/vaani-lora-lata"
LORA_FILE = "adapter_model.safetensors"

# Use the GPU whenever torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"🚀 Running on device: {DEVICE}")

# Process-wide model cache; stays None until get_or_load_model() fills it.
MODEL = None
| # ========================= | |
| # MANUAL LORA MERGE | |
| # ========================= | |
def merge_lora_into_t3(t3_model, lora_state):
    """Fold LoRA adapter weights into the base T3 model in place (no PEFT).

    Keys of ``lora_state`` are grouped by the text that precedes ``.lora_``.
    For every group that has both a ``lora_A`` and a ``lora_B`` tensor, the
    ``weight`` of the sub-module found at that dotted path under ``t3_model``
    is updated as ``W += (B @ A) * (alpha / r)``, where ``r`` is the first
    dimension of ``A``. ``alpha`` comes from an optional ``<path>.lora_alpha``
    entry and defaults to ``r`` (i.e. a scale of 1.0). Groups missing either
    half of the pair are skipped.

    Args:
        t3_model: ``torch.nn.Module`` whose sub-module weights are updated.
        lora_state: Mapping of adapter tensor names to tensors.

    Returns:
        The same ``t3_model`` object, with the adapter merged.

    Raises:
        AttributeError: If a complete LoRA pair names a path that does not
            exist under ``t3_model``.
    """
    print("🔧 Merging Hindi LoRA into T3 weights...")

    # Group the A / B / alpha entries that belong to the same target layer.
    alpha_suffix = ".lora_alpha"
    grouped = {}
    for key, tensor in lora_state.items():
        if ".lora_A." in key or ".lora_B." in key:
            layer_path, _, tail = key.partition(".lora_")
            # tail looks like "A.weight" / "B.weight"; keep only "A" or "B".
            grouped.setdefault(layer_path, {})[tail.split(".")[0]] = tensor
        elif key.endswith(alpha_suffix):
            layer_path = key[: -len(alpha_suffix)]
            grouped.setdefault(layer_path, {})["alpha"] = tensor.item()

    # Weights are mutated directly; autograd must not record these updates.
    with torch.no_grad():
        for layer_path, parts in grouped.items():
            if "A" not in parts or "B" not in parts:
                continue

            weight = t3_model.get_submodule(layer_path).weight

            # Build the delta in float32 so half-precision adapters neither
            # lose accuracy nor hit unsupported matmul kernels, then cast to
            # whatever dtype the base weight uses.
            lora_a = parts["A"].to(device=weight.device, dtype=torch.float32)
            lora_b = parts["B"].to(device=weight.device, dtype=torch.float32)

            rank = lora_a.shape[0]
            scale = parts.get("alpha", rank) / rank
            delta = (lora_b @ lora_a) * scale
            weight.add_(delta.to(weight.dtype))

    print("✅ LoRA merge complete")
    return t3_model
| # ========================= | |
| # MODEL LOADING | |
| # ========================= | |
def get_or_load_model():
    """Return the shared TTS model, building it on the first call.

    On first use this loads ``ChatterboxTurboTTS`` on ``DEVICE``, downloads
    ``LORA_FILE`` from ``MODEL_REPO`` and applies it to ``model.t3``:

    * if any key starts with ``tfmr.`` the file is treated as a full T3
      checkpoint and loaded with ``load_state_dict(strict=True)``;
    * otherwise it is treated as a LoRA adapter and merged manually with
      ``merge_lora_into_t3``.

    The global ``MODEL`` is only assigned once every step has succeeded, so a
    failed attempt never leaves a half-initialised (un-fine-tuned) model
    cached for later callers.

    Returns:
        The cached model instance.

    Raises:
        Exception: Whatever the loading steps raise; the error is printed
            and re-raised unchanged.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    try:
        model = ChatterboxTurboTTS.from_pretrained(DEVICE)

        checkpoint_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=LORA_FILE,
            # The token is optional: public repos download fine without it,
            # so an unset/empty HF_TOKEN must not abort startup.
            token=os.environ.get("HF_TOKEN") or None,
        )
        state = load_file(checkpoint_path, device="cpu")

        if any(key.startswith("tfmr.") for key in state):
            # CASE 1: full T3 checkpoint -> load weights directly.
            print("Detected FULL T3 checkpoint – loading directly")
            model.t3.load_state_dict(state, strict=True)
        else:
            # CASE 2: real LoRA adapter -> manual merge.
            print("Detected LoRA adapter – merging weights")
            merge_lora_into_t3(model.t3, state)

        # NOTE(review): from_pretrained(DEVICE) already targets DEVICE; the
        # wrapper may not expose .to() — confirm against the chatterbox API.
        if hasattr(model, "to"):
            model.to(DEVICE)
        print(f"Model loaded successfully on {DEVICE}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    # Publish only a fully prepared model.
    MODEL = model
    return MODEL
# Load on startup: build the model once at import time so it is already
# cached when requests arrive. get_or_load_model() re-raises any loading
# error, so a failure here aborts startup.
get_or_load_model()
| # ========================= | |
| # SEED | |
| # ========================= | |
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer (or integer-like) seed value.
    """
    seed = int(seed)
    random.seed(seed)
    # NumPy's legacy seeder only accepts 0 <= seed < 2**32 and raises
    # ValueError otherwise; wrap the value so any integer seed is usable.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    # Documented as silently ignored when CUDA is unavailable, so no device
    # check is needed here.
    torch.cuda.manual_seed_all(seed)
| # ========================= | |
| # TTS | |
| # ========================= | |
def generate_tts_audio(
    text_input,
    audio_prompt_path_input,
    exaggeration_input,
    temperature_input,
    seed_num_input,
    cfgw_input,
):
    """Synthesize speech for ``text_input`` with the shared model.

    Args:
        text_input: Text to speak; only the first 3000 characters are used.
        audio_prompt_path_input: File path of a reference audio clip, or
            ``None``; forwarded as ``audio_prompt_path``.
        exaggeration_input: Forwarded as ``exaggeration``.
        temperature_input: Forwarded as ``temperature``.
        seed_num_input: RNG seed; ``0`` (or an empty field) leaves the
            generators unseeded.
        cfgw_input: Forwarded as ``cfg_weight``.

    Returns:
        ``(sample_rate, waveform)`` where ``sample_rate`` is ``model.sr`` and
        ``waveform`` is a NumPy array.
    """
    model = get_or_load_model()

    # An emptied numeric field presumably arrives as None; treat it like 0
    # ("random") instead of crashing in int().
    seed = int(seed_num_input or 0)
    if seed != 0:
        set_seed(seed)

    wav = model.generate(
        (text_input or "")[:3000],
        audio_prompt_path=audio_prompt_path_input,
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    # .numpy() only works on CPU tensors that do not require grad, so make
    # that explicit regardless of where generate() leaves its output.
    return model.sr, wav.squeeze(0).detach().cpu().numpy()
| # ========================= | |
| # UI | |
| # ========================= | |
# Gradio front-end: text + optional reference audio and sampling controls on
# the left, generated audio on the right.
with gr.Blocks() as demo:
    gr.Markdown("## 🇮🇳 Hindi TTS – Chatterbox (LoRA merged, no PEFT)")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी",
                label="Text",
                max_lines=5,
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio (optional)",
            )
            # Every control carries a label so users can tell the sliders
            # apart; ranges and defaults are unchanged.
            exaggeration = gr.Slider(
                0.25, 2.0, step=0.05, value=0.5, label="Exaggeration"
            )
            cfg_weight = gr.Slider(
                0.2, 1.0, step=0.05, value=0.3, label="CFG weight"
            )
            with gr.Accordion("Advanced", open=False):
                seed_num = gr.Number(value=0, label="Seed (0=random)")
                temp = gr.Slider(
                    0.05, 5.0, step=0.05, value=0.6, label="Temperature"
                )
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            audio_output = gr.Audio(label="Output")

    # Input order must match generate_tts_audio's positional parameters.
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[text, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output],
    )

demo.launch()