import gc
import os.path

import diffusers
import librosa
import torch.cuda
import transformers

# Module-level state: the loaded AudioLDM pipeline, the CLAP model/processor used
# to rank candidate waveforms against the prompt, and the device everything was loaded onto.
model: diffusers.AudioLDMPipeline = None
loaded = False
clap_model: transformers.ClapModel = None
processor: transformers.ClapProcessor = None
device: str = None

# AudioLDM checkpoints selectable in the UI.
models = ['cvssp/audioldm', 'cvssp/audioldm-s-full-v2', 'cvssp/audioldm-m-full', 'cvssp/audioldm-l-full']

def create_model(pretrained='cvssp/audioldm-m-full', map_device='cuda' if torch.cuda.is_available() else 'cpu'):
    if is_loaded():
        delete_model()
    global model, loaded, clap_model, processor, device
    try:
        cache_dir = os.path.join('data', 'models', 'audioldm')
        model = diffusers.AudioLDMPipeline.from_pretrained(pretrained, cache_dir=cache_dir).to(map_device)
        # CLAP is only used to rank the generated waveforms when more than one is requested per prompt.
        clap_model = transformers.ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full", cache_dir=cache_dir).to(map_device)
        processor = transformers.AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full", cache_dir=cache_dir)
        device = map_device
        loaded = True
    except Exception:
        # Loading failed (e.g. download error or out of memory); leave the module in its unloaded state.
        pass

def delete_model():
    global model, loaded, clap_model, processor, device
    try:
        del model, clap_model, processor
        gc.collect()
        torch.cuda.empty_cache()
        loaded = False
        device = None
    except Exception:
        pass

def is_loaded():
    return loaded

def score_waveforms(text, waveforms):
    """Rank candidate waveforms with CLAP and return the one that best matches the prompt."""
    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
    inputs = {key: inputs[key].to(device) for key in inputs}
    with torch.no_grad():
        logits_per_text = clap_model(**inputs).logits_per_text  # audio-text similarity scores
        probs = logits_per_text.softmax(dim=-1)  # softmax over the candidates gives label probabilities
        most_probable = torch.argmax(probs)  # index of the most likely waveform
    waveform = waveforms[most_probable]
    return waveform

def generate(prompt='', negative_prompt='', steps=10, duration=5.0, cfg=2.5, seed=-1, wav_best_count=1, enhance=False, callback=None):
    if is_loaded():
        try:
            sample_rate = 16000  # AudioLDM generates 16 kHz audio
            seed = seed if seed >= 0 else torch.seed()
            torch.manual_seed(seed)
            output = model(prompt, negative_prompt=negative_prompt if negative_prompt else None,
                           audio_length_in_s=duration, num_inference_steps=steps, guidance_scale=cfg,
                           num_waveforms_per_prompt=wav_best_count, callback=callback)
            waveforms = output.audios
            if waveforms.shape[0] > 1:
                # Several candidates were generated; keep the one CLAP ranks closest to the prompt.
                waveform = score_waveforms(prompt, waveforms)
            else:
                waveform = waveforms[0]
            if enhance:  # https://github.com/gitmylo/audio-webui/issues/36#issuecomment-1627380868
                # Upsample to 44.1 kHz and layer in a copy pitched up one octave to brighten the result.
                sample_rate = 44100
                audio_resampled = librosa.resample(waveform, orig_sr=16000, target_sr=sample_rate)
                waveform = audio_resampled + librosa.effects.pitch_shift(audio_resampled, sr=sample_rate, n_steps=12, res_type="soxr_vhq")
            return seed, (sample_rate, waveform)
        except Exception as e:
            return f'An exception occurred: {str(e)}'
    return 'No model loaded! Please load a model first.'
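

# Minimal usage sketch (not part of the web UI): load the default checkpoint, generate a short
# clip, and write it to disk. Assumes the checkpoints above are downloadable, that the
# `soundfile` package is installed, and an arbitrary output path ('audioldm_demo.wav').
if __name__ == '__main__':
    import soundfile as sf

    create_model()  # downloads/loads 'cvssp/audioldm-m-full' on first run
    if is_loaded():
        result = generate(prompt='birds singing in a forest', steps=10, duration=5.0, wav_best_count=2)
        if isinstance(result, tuple):
            used_seed, (rate, audio) = result
            sf.write('audioldm_demo.wav', audio, rate)
            print(f'Wrote audioldm_demo.wav (seed {used_seed})')
        else:
            print(result)  # generate() returned an error message string
    else:
        print('Model failed to load.')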