import os
import random

import torch
import spaces  # Hugging Face Spaces runtime (ZeroGPU)
import numpy as np
import gradio as gr
from chatterbox.tts_turbo import ChatterboxTurboTTS
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# =========================
# CONFIG
# =========================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_REPO = "rahul7star/vaani-lora-lata"
LORA_FILE = "adapter_model.safetensors"

print(f"🚀 Running on device: {DEVICE}")

MODEL = None


# =========================
# MANUAL LORA MERGE
# =========================
def merge_lora_into_t3(t3_model, lora_state):
    """
    Manually merges LoRA weights into the base T3 model (PEFT-free).
    """
    print("🔧 Merging LoRA weights into T3...")

    # Group LoRA tensors by the base-layer path they belong to
    loras = {}
    for k, v in lora_state.items():
        if ".lora_A." in k or ".lora_B." in k:
            prefix = k.split(".lora_")[0]
            loras.setdefault(prefix, {})[
                k.split(".lora_")[1].split(".")[0]
            ] = v
        elif k.endswith(".lora_alpha"):
            prefix = k.replace(".lora_alpha", "")
            loras.setdefault(prefix, {})["alpha"] = v.item()

    for layer_path, parts in loras.items():
        if "A" not in parts or "B" not in parts:
            continue

        # Locate the base weight by walking the attribute path
        module = t3_model
        for attr in layer_path.split("."):
            module = getattr(module, attr)

        W = module.weight.data
        A = parts["A"].to(W.device)
        B = parts["B"].to(W.device)

        r = A.shape[0]
        alpha = parts.get("alpha", r)
        scale = alpha / r

        # Merge: W' = W + (alpha / r) * B @ A
        W += (B @ A) * scale

    print("✅ LoRA merged successfully")
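
# ---------------------------------------------------------
# Optional self-test for the merge above: a minimal sketch,
# not called anywhere in the app. The toy module and the key
# names ("proj.lora_A.weight", ...) are illustrative; real
# adapter files use longer paths, but the parsing is the same.
# ---------------------------------------------------------
def _sanity_check_lora_merge():
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 4, bias=False)

    toy = Toy()
    r, alpha = 2, 4.0
    A, B = torch.randn(r, 8), torch.randn(4, r)
    W0 = toy.proj.weight.data.clone()

    merge_lora_into_t3(toy, {
        "proj.lora_A.weight": A,
        "proj.lora_B.weight": B,
        "proj.lora_alpha": torch.tensor(alpha),
    })

    # The merged weight must equal W + (alpha / r) * B @ A
    expected = W0 + (alpha / r) * (B @ A)
    assert torch.allclose(toy.proj.weight.data, expected, atol=1e-5)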

# =========================
# MODEL LOADING
# =========================
def get_or_load_model():
    """Loads ChatterboxTurboTTS and merges the Hindi LoRA, or loads a full T3 checkpoint safely."""
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        try:
            MODEL = ChatterboxTurboTTS.from_pretrained(DEVICE)

            checkpoint_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=LORA_FILE,
                token=os.environ.get("HF_TOKEN"),  # None is fine for public repos
            )
            state = load_file(checkpoint_path, device="cpu")

            # --------------------------------------------------
            # CASE 1: FULL T3 CHECKPOINT (your repo currently)
            # --------------------------------------------------
            if any(k.startswith("tfmr.") for k in state.keys()):
                print("Detected FULL T3 checkpoint – loading directly")
                MODEL.t3.load_state_dict(state, strict=True)

            # --------------------------------------------------
            # CASE 2: REAL LoRA ADAPTER → MANUAL MERGE
            # --------------------------------------------------
            else:
                print("Detected LoRA adapter – merging weights")
                merge_lora_into_t3(MODEL.t3, state)

            MODEL.to(DEVICE)
            print(f"Model loaded successfully on {DEVICE}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    return MODEL


# Load on startup
get_or_load_model()


# =========================
# SEED
# =========================
def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed_all(seed)


# =========================
# TTS
# =========================
def generate_tts_audio(
    text_input,
    audio_prompt_path_input,
    exaggeration_input,
    temperature_input,
    seed_num_input,
    cfgw_input,
):
    model = get_or_load_model()
    if seed_num_input != 0:
        set_seed(int(seed_num_input))
    wav = model.generate(
        text_input[:3000],  # cap input length
        audio_prompt_path=audio_prompt_path_input,
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    return model.sr, wav.squeeze(0).numpy()


# =========================
# UI
# =========================
with gr.Blocks() as demo:
    gr.Markdown("## 🇮🇳 Hindi TTS – Chatterbox (LoRA merged, no PEFT)")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी",
                label="Text",
                max_lines=5,
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio (optional)",
            )
            exaggeration = gr.Slider(0.25, 2.0, step=0.05, value=0.5, label="Exaggeration")
            cfg_weight = gr.Slider(0.2, 1.0, step=0.05, value=0.3, label="CFG Weight")
            with gr.Accordion("Advanced", open=False):
                seed_num = gr.Number(value=0, label="Seed (0=random)")
                temp = gr.Slider(0.05, 5.0, step=0.05, value=0.6, label="Temperature")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            audio_output = gr.Audio(label="Output")

    run_btn.click(
        fn=generate_tts_audio,
        inputs=[text, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output],
    )

demo.launch()
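
# ---------------------------------------------------------
# Direct (UI-free) usage, for reference. demo.launch() above
# blocks, so this sketch is left commented out; "ref.wav" and
# the sample text are placeholders.
# ---------------------------------------------------------
# sr, audio = generate_tts_audio(
#     "नमस्ते दुनिया",
#     audio_prompt_path_input="ref.wav",
#     exaggeration_input=0.5,
#     temperature_input=0.6,
#     seed_num_input=0,
#     cfgw_input=0.3,
# )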