import asyncio
import copy
import datetime
import json
import os
import threading
import traceback
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, cast

import numpy as np
import torch
from fastapi import FastAPI, WebSocket
from fastapi.responses import FileResponse
from starlette.websockets import WebSocketDisconnect, WebSocketState

from vibevoice.modular.modeling_vibevoice_streaming_inference import (
    VibeVoiceStreamingForConditionalGenerationInference,
)
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_streaming_processor import (
    VibeVoiceStreamingProcessor,
)

BASE = Path(__file__).parent
SAMPLE_RATE = 24_000
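# The streaming decoder produces 24 kHz mono float32 audio; chunk_to_pcm16()
# below converts each chunk to 16-bit PCM before it is sent over the websocket.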


def get_timestamp() -> str:
    # Render the current time in UTC+8 with millisecond precision for logs.
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    local = now_utc.astimezone(datetime.timezone(datetime.timedelta(hours=8)))
    return local.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]


class StreamingTTSService:
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        inference_steps: int = 5,
    ) -> None:
        self.model_path = Path(model_path)
        self.inference_steps = inference_steps
        self.sample_rate = SAMPLE_RATE

        self.processor: Optional[VibeVoiceStreamingProcessor] = None
        self.model: Optional[VibeVoiceStreamingForConditionalGenerationInference] = None
        self.voice_presets: Dict[str, Path] = {}
        self.default_voice_key: Optional[str] = None
        self._voice_cache: Dict[str, object] = {}

        # Normalize the requested device: "mpx" is treated as a typo for
        # "mps", and MPS falls back to CPU when unavailable.
        if device == "mpx":
            print("Note: device 'mpx' detected, treating it as 'mps'.")
            device = "mps"
        if device == "mps" and not torch.backends.mps.is_available():
            print("Warning: MPS not available. Falling back to CPU.")
            device = "cpu"
        self.device = device
        self._torch_device = torch.device(device)
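
    # A hedged offline sketch (hypothetical paths): the service can also be
    # exercised without FastAPI, e.g. for smoke tests:
    #   svc = StreamingTTSService("/models/vibevoice-streaming", device="cpu")
    #   svc.load()
    #   pcm = b"".join(svc.chunk_to_pcm16(c) for c in svc.stream("Hi there."))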

    def load(self) -> None:
        print(f"[startup] Loading processor from {self.model_path}")
        self.processor = VibeVoiceStreamingProcessor.from_pretrained(str(self.model_path))

        if self.device == "mps":
            load_dtype = torch.float32
            device_map = None
            attn_impl_primary = "sdpa"
        elif self.device == "cuda":
            load_dtype = torch.bfloat16
            device_map = "cuda"
            attn_impl_primary = "flash_attention_2"
        else:
            load_dtype = torch.float32
            device_map = "cpu"
            attn_impl_primary = "sdpa"
        print(
            f"Using device: {device_map}, torch_dtype: {load_dtype}, "
            f"attn_implementation: {attn_impl_primary}"
        )

        try:
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                str(self.model_path),
                torch_dtype=load_dtype,
                device_map=device_map,
                attn_implementation=attn_impl_primary,
            )
            if self.device == "mps":
                self.model.to("mps")
        except Exception:
            if attn_impl_primary != "flash_attention_2":
                raise
            print(
                "Error loading the model with flash_attention_2; retrying with "
                "SDPA. Note that only flash_attention_2 has been fully tested, "
                "and SDPA may result in lower audio quality."
            )
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                str(self.model_path),
                torch_dtype=load_dtype,
                device_map=self.device,
                attn_implementation="sdpa",
            )
            print("Loaded model with SDPA successfully.")

        self.model.eval()

        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type="sde-dpmsolver++",
            beta_schedule="squaredcos_cap_v2",
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)

        self.voice_presets = self._load_voice_presets()
        preset_name = os.environ.get("VOICE_PRESET")
        self.default_voice_key = self._determine_voice_key(preset_name)
        self._ensure_voice_cached(self.default_voice_key)
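
    # load() eagerly caches the default voice so the first request does not
    # pay the torch.load() cost for its preset.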

    def _load_voice_presets(self) -> Dict[str, Path]:
        voices_dir = BASE.parent / "voices" / "streaming_model"
        if not voices_dir.exists():
            raise RuntimeError(f"Voices directory not found: {voices_dir}")

        presets: Dict[str, Path] = {}
        for pt_path in voices_dir.glob("*.pt"):
            presets[pt_path.stem] = pt_path

        if not presets:
            raise RuntimeError(f"No voice preset (.pt) files found in {voices_dir}")

        print(f"[startup] Found {len(presets)} voice presets")
        return dict(sorted(presets.items()))

    def _determine_voice_key(self, name: Optional[str]) -> str:
        if name and name in self.voice_presets:
            return name

        default_key = "en-WHTest_man"
        if default_key in self.voice_presets:
            return default_key

        first_key = next(iter(self.voice_presets))
        print(f"[startup] Using fallback voice preset: {first_key}")
        return first_key

    def _ensure_voice_cached(self, key: str) -> object:
        if key not in self.voice_presets:
            raise RuntimeError(f"Voice preset {key!r} not found")

        if key not in self._voice_cache:
            preset_path = self.voice_presets[key]
            print(f"[startup] Loading prefilled prompt for voice preset {key} from {preset_path}")
            prefilled_outputs = torch.load(
                preset_path,
                map_location=self._torch_device,
                weights_only=False,
            )
            self._voice_cache[key] = prefilled_outputs

        return self._voice_cache[key]
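
    # Note: weights_only=False is needed because presets store full prefilled
    # prompt state, not bare tensors; torch.load() will then unpickle arbitrary
    # objects, so only load voice preset files from trusted sources.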

    def _get_voice_resources(self, requested_key: Optional[str]) -> Tuple[str, object]:
        key = requested_key if requested_key and requested_key in self.voice_presets else self.default_voice_key
        if key is None:
            key = next(iter(self.voice_presets))
            self.default_voice_key = key

        prefilled_outputs = self._ensure_voice_cached(key)
        return key, prefilled_outputs

    def _prepare_inputs(self, text: str, prefilled_outputs: object):
        if not self.processor or not self.model:
            raise RuntimeError("StreamingTTSService not initialized")

        processor_kwargs = {
            "text": text.strip(),
            "cached_prompt": prefilled_outputs,
            "padding": True,
            "return_tensors": "pt",
            "return_attention_mask": True,
        }

        processed = self.processor.process_input_with_cached_prompt(**processor_kwargs)

        return {
            key: value.to(self._torch_device) if hasattr(value, "to") else value
            for key, value in processed.items()
        }
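
    # Anything the processor returns with a .to() method is moved onto the
    # service's device up front, so generation never mixes CPU and GPU tensors.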

    def _run_generation(
        self,
        inputs: Dict[str, Any],
        audio_streamer: AudioStreamer,
        errors: list,
        cfg_scale: float,
        do_sample: bool,
        temperature: float,
        top_p: float,
        refresh_negative: bool,
        prefilled_outputs: object,
        stop_event: threading.Event,
    ) -> None:
        try:
            self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=cfg_scale,
                tokenizer=self.processor.tokenizer,
                generation_config={
                    "do_sample": do_sample,
                    "temperature": temperature if do_sample else 1.0,
                    "top_p": top_p if do_sample else 1.0,
                },
                audio_streamer=audio_streamer,
                stop_check_fn=stop_event.is_set,
                verbose=False,
                refresh_negative=refresh_negative,
                all_prefilled_outputs=copy.deepcopy(prefilled_outputs),
            )
        except Exception as exc:
            errors.append(exc)
            traceback.print_exc()
            audio_streamer.end()
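
    # _run_generation executes on a worker thread; calling audio_streamer.end()
    # on failure matters because the consumer in stream() iterates the streamer
    # with no timeout configured and could otherwise wait indefinitely.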

    def stream(
        self,
        text: str,
        cfg_scale: float = 1.5,
        do_sample: bool = False,
        temperature: float = 0.9,
        top_p: float = 0.9,
        refresh_negative: bool = True,
        inference_steps: Optional[int] = None,
        voice_key: Optional[str] = None,
        log_callback: Optional[Callable[..., None]] = None,
        stop_event: Optional[threading.Event] = None,
    ) -> Iterator[np.ndarray]:
        if not text.strip():
            return
        text = text.replace("’", "'")
        selected_voice, prefilled_outputs = self._get_voice_resources(voice_key)

        def emit(event: str, **payload: Any) -> None:
            if log_callback:
                try:
                    log_callback(event, **payload)
                except Exception as exc:
                    print(f"[log_callback] Error while emitting {event}: {exc}")

        steps_to_use = self.inference_steps
        if inference_steps is not None:
            try:
                parsed_steps = int(inference_steps)
                if parsed_steps > 0:
                    steps_to_use = parsed_steps
            except (TypeError, ValueError):
                pass
        if self.model:
            self.model.set_ddpm_inference_steps(num_steps=steps_to_use)
            self.inference_steps = steps_to_use

        inputs = self._prepare_inputs(text, prefilled_outputs)
        audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
        errors: list = []
        stop_signal = stop_event or threading.Event()

        thread = threading.Thread(
            target=self._run_generation,
            kwargs={
                "inputs": inputs,
                "audio_streamer": audio_streamer,
                "errors": errors,
                "cfg_scale": cfg_scale,
                "do_sample": do_sample,
                "temperature": temperature,
                "top_p": top_p,
                "refresh_negative": refresh_negative,
                "prefilled_outputs": prefilled_outputs,
                "stop_event": stop_signal,
            },
            daemon=True,
        )
        thread.start()

        generated_samples = 0

        try:
            stream = audio_streamer.get_stream(0)
            for audio_chunk in stream:
                if torch.is_tensor(audio_chunk):
                    audio_chunk = audio_chunk.detach().cpu().to(torch.float32).numpy()
                else:
                    audio_chunk = np.asarray(audio_chunk, dtype=np.float32)

                if audio_chunk.ndim > 1:
                    audio_chunk = audio_chunk.reshape(-1)

                # Normalize only when the chunk clips, to avoid pumping quiet audio.
                peak = np.max(np.abs(audio_chunk)) if audio_chunk.size else 0.0
                if peak > 1.0:
                    audio_chunk = audio_chunk / peak

                generated_samples += int(audio_chunk.size)
                emit(
                    "model_progress",
                    generated_sec=generated_samples / self.sample_rate,
                    chunk_sec=audio_chunk.size / self.sample_rate,
                )

                yield audio_chunk.astype(np.float32, copy=False)
        finally:
            stop_signal.set()
            audio_streamer.end()
            thread.join()
            if errors:
                emit("generation_error", message=str(errors[0]))
                raise errors[0]
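
    # A hedged usage sketch (hypothetical text), collecting audio offline:
    #   chunks = list(service.stream("Hello there."))
    #   audio = np.concatenate(chunks) if chunks else np.zeros(0, np.float32)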

    def chunk_to_pcm16(self, chunk: np.ndarray) -> bytes:
        chunk = np.clip(chunk, -1.0, 1.0)
        pcm = (chunk * 32767.0).astype(np.int16)
        return pcm.tobytes()
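
    # Scaling by 32767 maps full-scale floats onto int16 symmetrically:
    # 1.0 -> 32767 and -1.0 -> -32767 (never -32768), after clipping.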


app = FastAPI()


@app.on_event("startup")
async def _startup() -> None:
    model_path = os.environ.get("MODEL_PATH")
    if not model_path:
        raise RuntimeError("MODEL_PATH not set in environment")

    device = os.environ.get("MODEL_DEVICE", "cuda")

    service = StreamingTTSService(
        model_path=model_path,
        device=device,
    )
    service.load()

    app.state.tts_service = service
    app.state.model_path = model_path
    app.state.device = device
    app.state.websocket_lock = asyncio.Lock()
    print("[startup] Model ready.")


def streaming_tts(text: str, **kwargs: Any) -> Iterator[np.ndarray]:
    service: StreamingTTSService = app.state.tts_service
    yield from service.stream(text, **kwargs)


@app.websocket("/stream")
async def websocket_stream(ws: WebSocket) -> None:
    await ws.accept()
    text = ws.query_params.get("text", "")
    print(f"Client connected, text={text!r}")
    cfg_param = ws.query_params.get("cfg")
    steps_param = ws.query_params.get("steps")
    voice_param = ws.query_params.get("voice")

    try:
        cfg_scale = float(cfg_param) if cfg_param is not None else 1.5
    except ValueError:
        cfg_scale = 1.5
    if cfg_scale <= 0:
        cfg_scale = 1.5
    try:
        inference_steps = int(steps_param) if steps_param is not None else None
        if inference_steps is not None and inference_steps <= 0:
            inference_steps = None
    except ValueError:
        inference_steps = None

    service: StreamingTTSService = app.state.tts_service
    lock: asyncio.Lock = app.state.websocket_lock

    if lock.locked():
        busy_message = {
            "type": "log",
            "event": "backend_busy",
            "data": {"message": "Please wait for the other requests to complete."},
            "timestamp": get_timestamp(),
        }
        print("Please wait for the other requests to complete.")
        try:
            await ws.send_text(json.dumps(busy_message))
        except Exception:
            pass
        await ws.close(code=1013, reason="Service busy")
        return

    acquired = False
    try:
        await lock.acquire()
        acquired = True

        log_queue: "Queue[Dict[str, Any]]" = Queue()

        def enqueue_log(event: str, **data: Any) -> None:
            log_queue.put({"event": event, "data": data})

        async def flush_logs() -> None:
            while True:
                try:
                    entry = log_queue.get_nowait()
                except Empty:
                    break
                message = {
                    "type": "log",
                    "event": entry.get("event"),
                    "data": entry.get("data", {}),
                    "timestamp": get_timestamp(),
                }
                try:
                    await ws.send_text(json.dumps(message))
                except Exception:
                    break

        enqueue_log(
            "backend_request_received",
            text_length=len(text or ""),
            cfg_scale=cfg_scale,
            inference_steps=inference_steps,
            voice=voice_param,
        )

        stop_signal = threading.Event()

        iterator = streaming_tts(
            text,
            cfg_scale=cfg_scale,
            inference_steps=inference_steps,
            voice_key=voice_param,
            log_callback=enqueue_log,
            stop_event=stop_signal,
        )
        sentinel = object()
        first_ws_send_logged = False

        await flush_logs()

        try:
            while ws.client_state == WebSocketState.CONNECTED:
                await flush_logs()
                # next() runs on a worker thread so pulling a chunk from the
                # blocking generator does not stall the event loop.
                chunk = await asyncio.to_thread(next, iterator, sentinel)
                if chunk is sentinel:
                    break
                chunk = cast(np.ndarray, chunk)
                await ws.send_bytes(service.chunk_to_pcm16(chunk))
                if not first_ws_send_logged:
                    first_ws_send_logged = True
                    enqueue_log("backend_first_chunk_sent")
                await flush_logs()
        except WebSocketDisconnect:
            print("Client disconnected (WebSocketDisconnect)")
            enqueue_log("client_disconnected")
            stop_signal.set()
        finally:
            stop_signal.set()
            enqueue_log("backend_stream_complete")
            await flush_logs()
            try:
                iterator_close = getattr(iterator, "close", None)
                if callable(iterator_close):
                    iterator_close()
            except Exception:
                pass

            # Drain any log entries that could not be delivered.
            while not log_queue.empty():
                try:
                    log_queue.get_nowait()
                except Empty:
                    break
            if ws.client_state == WebSocketState.CONNECTED:
                await ws.close()
            print("WS handler exit")
    finally:
        if acquired:
            lock.release()
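

# Wire protocol for /stream: binary frames carry 16-bit PCM (numpy-native byte
# order, little-endian on typical hosts) at SAMPLE_RATE; text frames carry JSON
# logs shaped {"type": "log", "event": ..., "data": ..., "timestamp": ...}.
# Concurrent clients are refused with close code 1013 ("try again later").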


@app.get("/")
def index():
    return FileResponse(BASE / "index.html")


@app.get("/config")
def get_config():
    service: StreamingTTSService = app.state.tts_service
    voices = sorted(service.voice_presets.keys())
    return {
        "voices": voices,
        "default_voice": service.default_voice_key,
    }
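

if __name__ == "__main__":
    # Convenience launcher; a minimal sketch assuming uvicorn is installed.
    # The original deployment may instead run `uvicorn <module>:app`; host and
    # port here are illustrative defaults, not values from this file.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)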