import os
import random

import numpy as np
import torch
import gradio as gr
import spaces
from chatterbox.tts_turbo import ChatterboxTurboTTS

# The TTS model is loaded a single time, at import, onto the CUDA device and is
# shared by every call to generate() below.
MODEL = ChatterboxTurboTTS.from_pretrained("cuda")


# Bracketed event tags; the UI renders one button per entry, and clicking a
# button inserts the tag text into the main textbox.
EVENT_TAGS = [
    "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
    "[sniff]", "[gasp]", "[chuckle]", "[laugh]",
]


# Stylesheet for the tag-button strip: ".tag-container" is the wrapping row and
# ".tag-btn" is applied to each individual event-tag button.
CUSTOM_CSS = """
.tag-container {
    display: flex !important;
    flex-wrap: wrap !important;
    gap: 8px !important;
    margin-top: 5px !important;
    margin-bottom: 10px !important;
    border: none !important;
    background: transparent !important;
}

.tag-btn {
    min-width: fit-content !important;
    width: auto !important;
    height: 32px !important;
    font-size: 13px !important;
    background: #eef2ff !important;
    border: 1px solid #c7d2fe !important;
    color: #3730a3 !important;
    border-radius: 6px !important;
    padding: 0 10px !important;
    margin: 0 !important;
    box-shadow: none !important;
}

.tag-btn:hover {
    background: #c7d2fe !important;
    transform: translateY(-1px);
}
"""


# Client-side handler run when a tag button is clicked. It receives the button
# value and the current textbox value and returns the new textbox value with
# the tag spliced in at the caret/selection of the "#main_textbox" textarea.
# A single space is added on each side unless one is already present (or the
# tag lands at the very start of the text). If the textarea cannot be found,
# the tag is simply appended.
INSERT_TAG_JS = """
(tag_val, current_text) => {
    const textarea = document.querySelector('#main_textbox textarea');
    if (!textarea) return current_text + " " + tag_val;

    const start = textarea.selectionStart;
    const end = textarea.selectionEnd;

    let prefix = " ";
    let suffix = " ";

    if (start === 0) prefix = "";
    else if (current_text[start - 1] === ' ') prefix = "";

    if (end < current_text.length && current_text[end] === ' ') suffix = "";

    return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
}
"""


def set_seed(seed: int) -> None:
    """Seed every random number generator this app relies on.

    Seeds PyTorch (CPU generator plus the current and all CUDA devices),
    Python's ``random`` module and NumPy's global generator with the same
    value, so that a generation run can be repeated.

    Args:
        seed: Integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


@spaces.GPU
def generate(
    text,
    audio_prompt_path,
    temperature,
    seed_num,
    min_p,
    top_p,
    top_k,
    repetition_penalty,
    norm_loudness
):
    """Synthesize speech for ``text`` using the module-level ``MODEL``.

    Args:
        text: Text to synthesize; may contain bracketed event tags such as
            ``[chuckle]``.
        audio_prompt_path: Path of the reference audio file, forwarded to
            ``MODEL.generate`` as ``audio_prompt_path``.
        temperature: Sampling temperature, forwarded unchanged.
        seed_num: Random seed. ``0`` (or an empty field, i.e. ``None``) leaves
            the generators unseeded; any other value is truncated to ``int``
            and passed to ``set_seed``.
        min_p: Min-p sampling value, forwarded unchanged.
        top_p: Top-p sampling value, forwarded unchanged.
        top_k: Top-k sampling value; converted to ``int`` before forwarding.
        repetition_penalty: Repetition penalty, forwarded unchanged.
        norm_loudness: Loudness-normalization flag, forwarded unchanged.

    Returns:
        A ``(sample_rate, samples)`` tuple where ``sample_rate`` is
        ``MODEL.sr`` and ``samples`` is the generated waveform moved to the
        CPU and converted to a NumPy array.
    """
    # A cleared numeric input can arrive as None; treat it like 0 ("random")
    # instead of crashing on int(None).
    if seed_num is not None and seed_num != 0:
        set_seed(int(seed_num))

    wav = MODEL.generate(
        text,
        audio_prompt_path=audio_prompt_path,
        temperature=temperature,
        min_p=min_p,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        norm_loudness=norm_loudness,
    )

    # squeeze(0) drops the leading dimension of the returned tensor
    # (presumably a batch/channel axis of size 1 -- TODO confirm).
    return (MODEL.sr, wav.squeeze(0).cpu().numpy())


# UI definition: a two-column layout with the inputs on the left and the
# generated audio plus advanced sampling options on the right.
with gr.Blocks(title="Chatterbox Turbo") as demo:
    gr.Markdown("# ⚡ Chatterbox Turbo")

    with gr.Row():
        with gr.Column():
            # elem_id must stay "main_textbox": INSERT_TAG_JS looks the
            # textarea up through the "#main_textbox textarea" selector.
            text = gr.Textbox(
                value="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?",
                label="Text to synthesize (max chars 300)",
                max_lines=5,
                elem_id="main_textbox"
            )

            # One button per event tag; clicks are handled purely client-side
            # (fn=None) by INSERT_TAG_JS, which receives the button value and
            # the current text and returns the updated text.
            with gr.Row(elem_classes=["tag-container"]):
                for tag in EVENT_TAGS:
                    btn = gr.Button(tag, elem_classes=["tag-btn"])
                    btn.click(
                        fn=None,
                        inputs=[btn, text],
                        outputs=text,
                        js=INSERT_TAG_JS
                    )

            # type="filepath" means generate() receives a path string.
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/Ethan.wav",
            )

            run_btn = gr.Button("Generate ⚡", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

            with gr.Accordion("Advanced Options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 2.0, step=.05, label="Temperature", value=0.8)
                top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95)
                top_k = gr.Slider(0, 1000, step=10, label="Top K", value=1000)
                repetition_penalty = gr.Slider(1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2)
                min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
                norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")

    # The order of `inputs` must match the positional parameters of generate().
    run_btn.click(
        fn=generate,
        inputs=[
            text,
            ref_wav,
            temp,
            seed_num,
            min_p,
            top_p,
            top_k,
            repetition_penalty,
            norm_loudness,
        ],
        outputs=audio_output,
    )


# Script entry point: enable request queuing and start the app.
if __name__ == "__main__":
    # NOTE(review): `css` is passed to launch() here rather than to
    # gr.Blocks(); confirm the pinned Gradio version accepts it there.
    demo.queue().launch(
        mcp_server=True,
        css=CUSTOM_CSS,
        ssr_mode=False
    )