Upload 30 files
- eval/README.md +100 -0
- eval/dataset/__init__.py +70 -0
- eval/dataset/musdb.py +75 -0
- eval/dataset/sam_audio_bench.py +153 -0
- eval/main.py +162 -0
- eval/metrics/__init__.py +13 -0
- eval/metrics/aes.py +49 -0
- eval/metrics/clap.py +46 -0
- eval/metrics/imagebind.py +52 -0
- eval/metrics/judge.py +44 -0
- sam_audio/__init__.py +4 -0
- sam_audio/model/__init__.py +4 -0
- sam_audio/model/align.py +50 -0
- sam_audio/model/base.py +58 -0
- sam_audio/model/codec.py +108 -0
- sam_audio/model/config.py +251 -0
- sam_audio/model/judge.py +135 -0
- sam_audio/model/model.py +362 -0
- sam_audio/model/patcher.py +164 -0
- sam_audio/model/rope.py +155 -0
- sam_audio/model/text_encoder.py +37 -0
- sam_audio/model/transformer.py +524 -0
- sam_audio/model/vision_encoder.py +113 -0
- sam_audio/processor.py +382 -0
- sam_audio/ranking/__init__.py +30 -0
- sam_audio/ranking/clap.py +84 -0
- sam_audio/ranking/imagebind.py +197 -0
- sam_audio/ranking/judge.py +42 -0
- sam_audio/ranking/ranker.py +36 -0
- sam_audio/ranking/sound_activity.py +129 -0
eval/README.md
ADDED
@@ -0,0 +1,100 @@
# Evaluation

This directory contains the evaluation code to reproduce the results from the SAM-Audio paper. The evaluation framework supports multiple datasets, prompting modes (text-only, span, visual), and metrics.

## Setup

Before running evaluation, ensure you have:

1. Installed the SAM-Audio package and its dependencies
2. Authenticated with Hugging Face to access the model checkpoints (see main [README](../README.md))

## Quick Start

Run evaluation on the default setting (`instr-pro`):

```bash
python main.py
```

You can also use multiple GPUs to speed up evaluation:

```bash
torchrun --nproc_per_node=<ngpus> main.py
```

Evaluate on a specific setting:

```bash
python main.py --setting sfx
```

Evaluate on multiple settings:

```bash
python main.py --setting sfx speech music
```

## Available Evaluation Settings

Run `python main.py --help` to see all available settings.

## Command Line Options

```bash
python main.py [OPTIONS]
```

### Options:

- `-s, --setting` - Which setting(s) to evaluate (default: `instr-pro`)
  - Choices: See available settings above
  - Can specify multiple settings: `--setting sfx speech music`

- `--cache-path` - Where to cache downloaded datasets (default: `~/.cache/sam_audio`)

- `-p, --checkpoint-path` - Model checkpoint to evaluate (default: `facebook/sam-audio-large`)
  - Can use a local path or a Hugging Face model ID

- `-b, --batch-size` - Batch size for evaluation (default: `1`)

- `-w, --num-workers` - Number of data loading workers (default: `4`)

- `-c, --candidates` - Number of reranking candidates (default: `8`)

## Evaluation Metrics

The evaluation framework computes the following metrics:

- **Judge** - SAM Audio Judge quality assessment metric
- **Aesthetic** - Aesthetic quality metric
- **CLAP** - Audio-text alignment metric (CLAP similarity)
- **ImageBind** - Audio-video alignment metric (for visual settings only)

## Output

Results are saved to the `results/` directory as JSON files, one per setting:

```
results/
├── sfx.json
├── speech.json
└── music.json
```

Each JSON file contains the averaged metric scores across all samples in that setting.

Example output:
```json
{
    "JudgeOverall": "4.386",
    "JudgeFaithfulness": "4.708",
    "JudgeRecall": "4.934",
    "JudgePrecision": "4.451",
    "ContentEnjoyment": "5.296",
    "ContentUsefulness": "6.903",
    "ProductionComplexity": "4.301",
    "ProductionQuality": "7.100",
    "CLAPSimilarity": "0.271"
}
```
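
For illustration, a minimal sketch (assuming the `results/` layout described above, where each value is stored as a formatted string) of collecting several settings into one comparison table:

```python
# Sketch: load per-setting result files written by eval/main.py into one table.
import json

import pandas as pd

settings = ["sfx", "speech", "music"]
rows = {}
for setting in settings:
    with open(f"results/{setting}.json") as fin:
        # values are stored as strings like "4.386"; cast back to float
        rows[setting] = {k: float(v) for k, v in json.load(fin).items()}

print(pd.DataFrame.from_dict(rows, orient="index").round(3))
```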
eval/dataset/__init__.py
ADDED
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from typing import Callable

from .musdb import MUSDB
from .sam_audio_bench import SAMAudioBench

SETTINGS = {
    # Text-only settings
    "sfx": (
        SAMAudioBench,
        {"span": False, "visual": False, "subset": "others-50:text-only"},
    ),
    "speech": (
        SAMAudioBench,
        {"span": False, "visual": False, "subset": "speech-clean-50:text-only"},
    ),
    "speaker": (
        SAMAudioBench,
        {"span": False, "visual": False, "subset": "spk-50:text-only"},
    ),
    "music": (
        SAMAudioBench,
        {"span": False, "visual": False, "subset": "music-clean-50:text-only"},
    ),
    "instr-wild": (
        SAMAudioBench,
        {"span": False, "visual": False, "subset": "instr-50:text-only"},
    ),
    "instr-pro": (MUSDB, {}),
    # Span settings
    "sfx-span": (
        SAMAudioBench,
        {"span": True, "visual": False, "subset": "others-50:text+span"},
    ),
    "speech-span": (
        SAMAudioBench,
        {"span": True, "visual": False, "subset": "speech-clean-50:text+span"},
    ),
    "speaker-span": (
        SAMAudioBench,
        {"span": True, "visual": False, "subset": "spk-50:text+span"},
    ),
    "music-span": (
        SAMAudioBench,
        {"span": True, "visual": False, "subset": "music-clean-50:text+span"},
    ),
    "instr-wild-span": (
        SAMAudioBench,
        {"span": True, "visual": False, "subset": "instr-50:text+span"},
    ),
    # Visual settings
    "sfx-visual": (
        SAMAudioBench,
        {"span": False, "visual": True, "subset": "others-onscreen-50:visual-only"},
    ),
    "speaker-visual": (
        SAMAudioBench,
        {"span": False, "visual": True, "subset": "spk-onscreen-50:visual-only"},
    ),
    "instr-wild-visual": (
        SAMAudioBench,
        {"span": False, "visual": True, "subset": "instr-onscreen-50:visual-only"},
    ),
}


def make_dataset(setting: str, cache_path: str, collate_fn: Callable):
    dataset, kwargs = SETTINGS[setting]
    return dataset(cache_path=cache_path, collate_fn=collate_fn, **kwargs)
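
A minimal sketch of how this registry is consumed, mirroring `eval/main.py` (the processor doubles as the collate function; the checkpoint name is that script's default):

```python
# Sketch: build the "sfx" evaluation dataset through the SETTINGS registry.
import os

from dataset import make_dataset
from sam_audio import SAMAudioProcessor

processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")
dset = make_dataset(
    "sfx",
    cache_path=os.path.expanduser("~/.cache/sam_audio"),
    collate_fn=processor,
)
print(len(dset), dset.visual)  # dataset size and whether visual prompts are used
```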
eval/dataset/musdb.py
ADDED
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import os
from subprocess import check_call

import torchaudio
from datasets import load_dataset
from torch.utils.data import Dataset
from torchcodec.decoders import AudioDecoder


def cache_file(url, outfile):
    if not os.path.exists(outfile):
        print("Downloading musdb18hq dataset...")
        os.makedirs(os.path.dirname(outfile), exist_ok=True)
        check_call(["curl", "--url", url, "--output", outfile + ".tmp"])
        os.rename(outfile + ".tmp", outfile)


class MUSDB(Dataset):
    def __init__(
        self,
        collate_fn,
        sample_rate: int = 48_000,
        cache_path: str = os.path.expanduser("~/.cache/sam_audio"),
    ):
        self.cache_path = os.path.join(cache_path, "musdb18hq")
        self.ds = self.get_dataset(cache_path)
        self.captions = ["bass", "drums", "vocals"]
        self.collate_fn = collate_fn
        self.sample_rate = sample_rate

    @property
    def visual(self):
        return False

    def get_dataset(self, cache_path):
        zip_file = os.path.join(cache_path, "musdb18hq.zip")
        url = "https://zenodo.org/records/3338373/files/musdb18hq.zip?download=1"
        cache_file(url, zip_file)
        extracted_dir = os.path.join(cache_path, "musdb18hq")
        if not os.path.exists(extracted_dir):
            check_call(["unzip", zip_file, "-d", extracted_dir + ".tmp"])
            os.rename(extracted_dir + ".tmp", extracted_dir)
        return load_dataset("facebook/sam-audio-musdb18hq-test")["test"]

    def __len__(self):
        return len(self.ds)

    def collate(self, items):
        audios, descriptions = zip(*items, strict=False)
        return self.collate_fn(
            audios=audios,
            descriptions=descriptions,
        )

    def __getitem__(self, idx):
        item = self.ds[idx]
        path = os.path.join(self.cache_path, "test", item["id"], "mixture.wav")
        assert os.path.exists(path), f"{path} does not exist!"
        decoder = AudioDecoder(path)
        data = decoder.get_samples_played_in_range(item["start_time"], item["end_time"])
        wav = data.data
        if data.sample_rate != self.sample_rate:
            wav = torchaudio.functional.resample(
                wav, data.sample_rate, self.sample_rate
            )
        wav = wav.mean(0, keepdim=True)
        return wav, item["description"]


if __name__ == "__main__":
    dataset = MUSDB(lambda **kwargs: None)
    print(len(dataset))
    print(dataset[0])
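
For batching, `eval/main.py` passes the dataset's own `collate` method to the `DataLoader`; a minimal sketch of that wiring with `MUSDB` (note that constructing `MUSDB` triggers the MUSDB18-HQ download if it is not already cached, and the checkpoint name is the default from `eval/main.py`):

```python
# Sketch: iterate MUSDB batches the same way the evaluation loop does.
from torch.utils.data import DataLoader

from dataset.musdb import MUSDB
from sam_audio import SAMAudioProcessor

processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")
dset = MUSDB(collate_fn=processor)
dl = DataLoader(dset, batch_size=1, shuffle=False, collate_fn=dset.collate)
batch = next(iter(dl))  # processed 48 kHz mono mixtures plus their descriptions
```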
eval/dataset/sam_audio_bench.py
ADDED
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import os
from dataclasses import dataclass
from io import BytesIO
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from torchcodec.decoders import AudioDecoder, VideoDecoder


@dataclass
class Item:
    anchors: list[Tuple[str, float, float]]
    masked_video_frames: torch.Tensor
    audio_samples: torch.Tensor
    description: str


class SAMAudioBench(torch.utils.data.Dataset):
    def __init__(
        self,
        cache_path,
        collate_fn,
        span: bool = True,
        visual: bool = True,
        subset: Optional[str] = None,
    ):
        self.dataset = load_dataset("facebook/sam-audio-bench")["test"]
        self.subset = subset
        self._span = span
        self._visual = visual
        if subset is not None:
            self.dataset = self.dataset.filter(lambda x: subset in x["paper_eval_sets"])

        self.cache_path = os.path.join(cache_path, "sam_audio_bench")
        self.collate_fn = collate_fn
        DATA_MSG = (
            f"`SAMAudioBench` requires the user to create a directory named {self.cache_path}; "
            "see the README.md file for how to prepare it"
        )
        assert os.path.exists(self.cache_path), DATA_MSG

    @property
    def visual(self):
        return self._visual

    def __len__(self):
        return len(self.dataset)

    def _get_path(
        self, video_id: str, source_dataset: str, start_offset: float, end_offset: float
    ) -> Tuple[str, bool]:
        path = f"{self.cache_path}/{source_dataset}/{video_id}.mp4"
        select_frames = True

        if not os.path.exists(path):
            path = f"{self.cache_path}/{source_dataset}/{video_id}_{int(start_offset * 1000)}_{int(end_offset * 1000)}.mp4"
            select_frames = False

        if not os.path.exists(path):
            path = f"{self.cache_path}/{source_dataset}/{video_id}_{int(start_offset)}_{int(end_offset)}.mp4"

        if not os.path.exists(path):
            path = f"{self.cache_path}/{source_dataset}/{video_id}.{int(start_offset * 1000):08d}_{int(end_offset * 1000):08d}.mp4"

        return path, select_frames

    def collate(self, items: list[Item]):
        has_video = any(item.masked_video_frames is not None for item in items)
        return self.collate_fn(
            descriptions=[item.description for item in items],
            audios=[item.audio_samples for item in items],
            anchors=[item.anchors for item in items] if self._span else None,
            masked_videos=[item.masked_video_frames for item in items]
            if has_video and self._visual
            else None,
        )

    def _get_masked_video(self, item, video_path, select_frames):
        if item["mask_bytes"] is None:
            return None

        mask = torch.from_numpy(np.load(BytesIO(item["mask_bytes"]))["video_masklet"])

        video_decoder = VideoDecoder(video_path)
        if select_frames:
            video_frames = video_decoder.get_frames_played_in_range(
                item["start_offset"], item["end_offset"]
            ).data
        else:
            video_frames = video_decoder[:].data

        if mask.size(0) != video_frames.size(0):
            # It's possible that the mask and the video frames differ by a small amount;
            # we interpolate the mask frames to match
            idxs = (
                torch.linspace(0, mask.size(0) - 1, video_frames.size(0)).round().long()
            )
            mask = mask[idxs]

        mask = mask.unsqueeze(1)

        if mask.shape[-2:] != video_frames.shape[-2:]:
            mask = F.interpolate(mask, size=video_frames.shape[-2:])

        import torchvision

        torchvision.io.write_video("test.mp4", video_frames.permute(0, 2, 3, 1), 30)
        torchvision.io.write_video(
            "test_mask.mp4", mask.unsqueeze(-1).expand(-1, -1, -1, 3) * 255, 30
        )

        return video_frames * mask

    def __getitem__(self, idx) -> Item:
        item = self.dataset[idx]

        video_path, select_frames = self._get_path(
            item["video_id"],
            item["source_dataset"],
            item["start_offset"],
            item["end_offset"],
        )
        assert os.path.exists(video_path), f"{video_path} does not exist!"

        audio_decoder = AudioDecoder(video_path)
        audio_samples = audio_decoder.get_samples_played_in_range(
            start_seconds=item["start_offset"] if select_frames else 0,
            stop_seconds=item["end_offset"] if select_frames else None,
        )

        if audio_samples.sample_rate != self.collate_fn.audio_sampling_rate:
            resampled_audio = torchaudio.functional.resample(
                audio_samples.data,
                audio_samples.sample_rate,
                self.collate_fn.audio_sampling_rate,
            )
        else:
            resampled_audio = audio_samples.data

        masked_video_frames = self._get_masked_video(item, video_path, select_frames)

        return Item(
            description=item["description"],
            anchors=[("+", start, end) for start, end in item["spans"]],
            masked_video_frames=masked_video_frames,
            audio_samples=resampled_audio.mean(0, keepdim=True),
        )
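
A standalone illustration of the frame-alignment step inside `_get_masked_video`: when the stored masklet has a different frame count than the decoded clip, nearest mask frames are picked with a rounded `linspace` before spatially resizing the mask to the video resolution (the shapes below are made up for the example):

```python
import torch
import torch.nn.functional as F

mask = (torch.rand(90, 64, 64) > 0.5).float()  # 90 mask frames at low resolution
video = torch.rand(120, 3, 256, 256)           # 120 decoded RGB frames

# Temporal alignment: pick the nearest mask frame for each video frame.
idxs = torch.linspace(0, mask.size(0) - 1, video.size(0)).round().long()
mask = mask[idxs].unsqueeze(1)                 # (120, 1, 64, 64)

# Spatial alignment, then zero out everything outside the mask.
mask = F.interpolate(mask, size=video.shape[-2:])
masked_video = video * mask
print(masked_video.shape)  # torch.Size([120, 3, 256, 256])
```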
eval/main.py
ADDED
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import argparse
import json
import os

import pandas as pd
import torch
import torch.distributed as dist
from dataset import SETTINGS, make_dataset
from metrics import CLAP, Aesthetic, ImageBind, Judge
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from sam_audio import SAMAudio, SAMAudioProcessor


def gather_and_average_results(results, world_size):
    if world_size == 1:
        return json.loads(results.mean().to_json())

    # 1. Gather the per-rank sums and counts to all ranks
    all_results = [None for _ in range(world_size)]
    dist.all_gather_object(
        all_results, {"sum": results.sum().to_json(), "count": len(results)}
    )

    # 2. Accumulate sums and counts across ranks
    summed = {}
    counts = 0

    for res in all_results:
        for k, v in json.loads(res["sum"]).items():
            if k not in summed:
                summed[k] = 0.0
            summed[k] += v
        counts += res["count"]

    # 3. Compute the average for keys that appeared at least once
    averaged = {k: summed[k] / counts for k in summed}

    return averaged


def main(
    settings: list[str],
    cache_path: str,
    batch_size: int,
    checkpoint_path: str,
    num_workers: int = 4,
    reranking_candidates: int = 8,
):
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if world_size > 1:
        torch.distributed.init_process_group(backend="nccl")
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)

    model = SAMAudio.from_pretrained(checkpoint_path)
    model = model.eval().to(device)
    processor = SAMAudioProcessor.from_pretrained(checkpoint_path)

    judge_metric = Judge(device=device)
    aes_metric = Aesthetic(device=device)
    clap_metric = CLAP(device=device)
    imagebind_metric = ImageBind(device=device)

    for setting in settings:
        print(f"Evaluating: {setting}")
        dset = make_dataset(setting, cache_path=cache_path, collate_fn=processor)
        sampler = None
        if world_size > 1:
            sampler = DistributedSampler(dset)

        dl = DataLoader(
            dset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=dset.collate,
            num_workers=num_workers,
            sampler=sampler,
        )

        all_metrics = [
            judge_metric,
            aes_metric,
            clap_metric,
        ]

        if dset.visual:
            all_metrics.append(imagebind_metric)

        dfs = []
        with torch.inference_mode():
            for batch in tqdm(dl, disable=rank > 0):
                batch = batch.to(device)
                result = model.separate(
                    batch, reranking_candidates=reranking_candidates
                )
                mets = {}
                for metric in all_metrics:
                    input_wavs = model.unbatch(batch.audios.squeeze(1), batch.wav_sizes)

                    mets.update(
                        metric(
                            target_wavs=result.target,
                            target_wavs_sample_rate=model.sample_rate,
                            descriptions=batch.descriptions,
                            input_wavs=input_wavs,
                            videos=batch.masked_video,
                        )
                    )

                dfs.append(pd.DataFrame.from_dict(mets))

        df = pd.concat(dfs)
        averaged_results = gather_and_average_results(df, world_size)
        if rank == 0:
            results_dict = {k: f"{v:.3f}" for k, v in averaged_results.items()}
            print(json.dumps(results_dict, indent=4))
            os.makedirs("results", exist_ok=True)
            outfile = f"results/{setting}.json"
            with open(outfile, "w") as fout:
                print(json.dumps(results_dict), file=fout)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--setting",
        "-s",
        choices=SETTINGS.keys(),
        help=f"Which setting to evaluate. Choices: {SETTINGS.keys()}",
        default=["instr-pro"],
        nargs="+",
    )
    parser.add_argument(
        "--cache-path",
        type=str,
        default=os.path.expanduser("~/.cache/sam_audio"),
        help="Where to cache downloaded datasets",
    )
    parser.add_argument(
        "--checkpoint-path", "-p", type=str, default="facebook/sam-audio-large"
    )
    parser.add_argument("--batch-size", "-b", type=int, default=1, help="Batch size")
    parser.add_argument(
        "--num-workers", "-w", type=int, default=4, help="Number of workers"
    )
    parser.add_argument("--candidates", "-c", type=int, default=8)
    opt = parser.parse_args()
    main(
        settings=opt.setting,
        cache_path=opt.cache_path,
        batch_size=opt.batch_size,
        checkpoint_path=opt.checkpoint_path,
        num_workers=opt.num_workers,
        reranking_candidates=opt.candidates,
    )
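
The `main` function can also be driven programmatically; a sketch equivalent to the CLI defaults in the argument parser above (run from the `eval/` directory):

```python
# Sketch: run the default instr-pro evaluation without the argparse front end.
import os

from main import main

main(
    settings=["instr-pro"],
    cache_path=os.path.expanduser("~/.cache/sam_audio"),
    batch_size=1,
    checkpoint_path="facebook/sam-audio-large",
    num_workers=4,
    reranking_candidates=8,
)
```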
eval/metrics/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from metrics.aes import Aesthetic
from metrics.clap import CLAP
from metrics.imagebind import ImageBind
from metrics.judge import Judge

__all__ = [
    "Aesthetic",
    "CLAP",
    "ImageBind",
    "Judge",
]
eval/metrics/aes.py
ADDED
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from typing import Optional

import torch
from audiobox_aesthetics.infer import AesPredictor

COLUMN_MAP = {
    "CE": "ContentEnjoyment",
    "CU": "ContentUsefulness",
    "PC": "ProductionComplexity",
    "PQ": "ProductionQuality",
}


class Aesthetic(torch.nn.Module):
    def __init__(
        self,
        checkpoint: Optional[str] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.model = AesPredictor(
            checkpoint_pth=checkpoint,
            data_col="wav",
        )
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    def __call__(
        self,
        target_wavs: list[torch.Tensor],
        target_wavs_sample_rate: int = 48_000,
        **kwargs,
    ) -> dict[str, list[float]]:
        result = self.model.forward(
            [
                {
                    "wav": wav[None] if wav.ndim == 1 else wav,
                    "sample_rate": target_wavs_sample_rate,
                }
                for wav in target_wavs
            ]
        )
        return {
            long_name: [x[shortname] for x in result]
            for shortname, long_name in COLUMN_MAP.items()
        }
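
A quick usage sketch for the wrapper above (the one-second noise tensor is only a placeholder input):

```python
import torch

from metrics.aes import Aesthetic

aes = Aesthetic(device=torch.device("cpu"))
wavs = [torch.randn(1, 48_000)]  # one second of mono audio at 48 kHz
scores = aes(target_wavs=wavs, target_wavs_sample_rate=48_000)
# keys: ContentEnjoyment, ContentUsefulness, ProductionComplexity, ProductionQuality
print(scores.keys())
```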
eval/metrics/clap.py
ADDED
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from tempfile import TemporaryDirectory
from typing import Optional

import torch
from torchcodec.encoders import AudioEncoder

from sam_audio.ranking.clap import get_model


class CLAP(torch.nn.Module):
    def __init__(
        self,
        checkpoint: Optional[str] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.model = get_model(device)
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    def __call__(
        self,
        target_wavs: list[torch.Tensor],
        descriptions: list[str],
        target_wavs_sample_rate: int = 48_000,
        **kwargs,
    ) -> dict[str, list[float]]:
        with TemporaryDirectory() as tdir, torch.inference_mode():
            file_list = []
            for i, wav in enumerate(target_wavs):
                file_list.append(f"{tdir}/hyp_{i}.wav")
                encoder = AudioEncoder(
                    samples=wav.cpu()[None] if wav.ndim == 1 else wav.cpu(),
                    sample_rate=target_wavs_sample_rate,
                )
                encoder.to_file(file_list[-1])
            audio_embs = self.model.get_audio_embedding_from_filelist(
                file_list, use_tensor=True
            )

            text_embs = self.model.get_text_embedding(descriptions, use_tensor=True)
            sims = audio_embs.unsqueeze(1) @ text_embs.unsqueeze(2)
            return {"CLAPSimilarity": sims.cpu()[:, 0, 0].tolist()}
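
Usage follows the same pattern as the other metrics; a sketch with a placeholder waveform and caption:

```python
import torch

from metrics.clap import CLAP

clap = CLAP(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
wavs = [torch.randn(48_000)]  # one second of audio at 48 kHz (placeholder)
out = clap(target_wavs=wavs, descriptions=["a dog barking"])
print(out["CLAPSimilarity"])  # one audio-text similarity score per clip
```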
eval/metrics/imagebind.py
ADDED
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from typing import Optional

import torch
from imagebind.models.imagebind_model import ModalityType, imagebind_huge

from sam_audio.ranking.imagebind import VideoTransform, load_and_transform_audio_data


class ImageBind(torch.nn.Module):
    def __init__(
        self,
        checkpoint: Optional[str] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()

        self.model = imagebind_huge(pretrained=checkpoint is None)
        if checkpoint is not None:
            self.model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
        self.model = self.model.eval()
        self.video_transform = VideoTransform()
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = self.model.to(self.device)

    def __call__(
        self,
        target_wavs: list[torch.Tensor],
        videos: list[torch.Tensor],
        target_wavs_sample_rate: int = 48_000,
        **kwargs,
    ) -> dict[str, list[float]]:
        audio_data = load_and_transform_audio_data(
            target_wavs, input_sample_rate=target_wavs_sample_rate
        )
        durations = [x.size(-1) / target_wavs_sample_rate for x in target_wavs]
        video_data = self.video_transform(videos, durations, audio_data.device)

        inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
        embs = self.model(inputs)
        audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
        audio_embs, video_embs = (
            audio_embs / ((audio_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
            video_embs / ((video_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
        )
        bsz = len(target_wavs)
        candidates = len(audio_embs) // bsz
        scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
        return {"ImageBind": scores.squeeze(1, 2).cpu().tolist()}
eval/metrics/judge.py
ADDED
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from typing import Optional

import torch

from sam_audio import SAMAudioJudgeModel, SAMAudioJudgeProcessor


class Judge(torch.nn.Module):
    def __init__(
        self,
        checkpoint: str = "facebook/sam-audio-judge",
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.model = SAMAudioJudgeModel.from_pretrained(checkpoint).to(device)
        self.processor = SAMAudioJudgeProcessor.from_pretrained(checkpoint)
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    def forward(
        self,
        input_wavs: list[torch.Tensor],
        target_wavs: list[torch.Tensor],
        descriptions: list[str],
        target_wavs_sample_rate: int = 48_000,
        **kwargs,
    ) -> dict[str, list[float]]:
        with torch.inference_mode():
            processed = self.processor(
                text=descriptions,
                input_audio=[x.cpu() for x in input_wavs],
                separated_audio=[x.cpu() for x in target_wavs],
                sampling_rate=target_wavs_sample_rate,
            ).to(self.device)
            result = self.model(**processed)
            return {
                "JudgeOverall": result.overall.squeeze(-1).cpu().tolist(),
                "JudgeFaithfulness": result.faithfulness.squeeze(-1).cpu().tolist(),
                "JudgeRecall": result.recall.squeeze(-1).cpu().tolist(),
                "JudgePrecision": result.precision.squeeze(-1).cpu().tolist(),
            }
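
A sketch of calling the judge on a mixture/estimate pair (random tensors stand in for real audio; the checkpoint is the default above):

```python
import torch

from metrics.judge import Judge

judge = Judge(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
mixture = [torch.randn(1, 48_000)]    # original mixture, 48 kHz (placeholder)
separated = [torch.randn(1, 48_000)]  # separated estimate, 48 kHz (placeholder)
scores = judge(
    input_wavs=mixture,
    target_wavs=separated,
    descriptions=["a dog barking"],
)
print(scores["JudgeOverall"], scores["JudgeFaithfulness"])
```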
sam_audio/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from .model import *  # noqa
from .processor import *  # noqa
sam_audio/model/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from .model import *  # noqa
from .judge import *  # noqa
sam_audio/model/align.py
ADDED
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from typing import Optional

import torch


class AlignModalities(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        normalize: bool = True,
        with_gate: bool = True,
    ):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1
        )
        self.normalize = normalize
        if self.normalize:
            self.layer_norm = torch.nn.LayerNorm(out_channels)

        self.gate = None
        if with_gate:
            self.gate = torch.nn.Parameter(torch.tensor([0.0]))

        self.out_channels = out_channels

    def forward(self, anchor: torch.Tensor, tgt: Optional[torch.Tensor] = None):
        """
        Align video features to the input audio features.

        Args:
            anchor (torch.Tensor): Input anchor tensor of shape (B, T, C), where B is the batch size, T is the sequence length, and C is the channel size.
            tgt (Optional[torch.Tensor]): Optional features tensor to be aligned to the anchor, expected shape (B, in_channels, T).
        """
        if tgt is None:
            return anchor

        post_conv = self.conv(tgt)
        post_conv = post_conv.permute(0, 2, 1)  # BCT -> BTC

        if self.normalize:
            post_conv = self.layer_norm(post_conv)

        if self.gate is None:
            return post_conv
        else:
            return anchor + self.gate.tanh() * post_conv
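
A shape-level sketch of `AlignModalities` (dimension values below are arbitrary): the 1x1 convolution maps the target features to the anchor width, and the zero-initialized gate makes the fused output start exactly at the anchor.

```python
import torch

from sam_audio.model.align import AlignModalities

align = AlignModalities(in_channels=1024, out_channels=2048)
audio = torch.randn(2, 250, 2048)  # anchor features, (B, T, C)
video = torch.randn(2, 1024, 250)  # features to align, (B, in_channels, T)

fused = align(audio, video)        # gated residual add -> (2, 250, 2048)
print(fused.shape, torch.allclose(fused, audio))  # gate starts at 0 -> True
```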
sam_audio/model/base.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import json
import os
from typing import Callable, Dict, Optional, Union

import torch
from huggingface_hub import ModelHubMixin, snapshot_download


class BaseModel(torch.nn.Module, ModelHubMixin):
    config_cls: Callable

    def device(self):
        return next(self.parameters()).device

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        cache_dir: str,
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = True,
        revision: Optional[str] = None,
        **model_kwargs,
    ):
        if os.path.isdir(model_id):
            cached_model_dir = model_id
        else:
            cached_model_dir = snapshot_download(
                repo_id=model_id,
                revision=cls.revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )

        with open(os.path.join(cached_model_dir, "config.json")) as fin:
            config = json.load(fin)

        config = cls.config_cls(**config)
        model = cls(config)
        state_dict = torch.load(
            os.path.join(cached_model_dir, "checkpoint.pt"),
            weights_only=True,
            map_location=map_location,
        )
        model.load_state_dict(state_dict, strict=strict)
        return model
sam_audio/model/codec.py
ADDED
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
from abc import ABCMeta, abstractmethod
from typing import Union

import dacvae
import torch

from sam_audio.model.config import DACVAEConfig


class Encoder(torch.nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def forward(self, waveform: torch.Tensor) -> torch.Tensor: ...


class Codec(Encoder):
    @abstractmethod
    def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def wav_idx_to_feature_idx(
        self, wav_idx: Union[torch.Tensor, int], sample_rate=None
    ) -> Union[torch.Tensor, int]: ...

    @abstractmethod
    def feature_idx_to_wav_idx(
        self, feature_idx: Union[torch.Tensor, int], sample_rate=None
    ) -> Union[torch.Tensor, int]: ...

    @staticmethod
    def cast_to_int(
        x: Union[int, torch.Tensor],
    ) -> Union[int, torch.Tensor]:
        if isinstance(x, torch.Tensor):
            return x.int()
        else:
            return int(x)


class DACVAEEncoder(Encoder):
    def __init__(self, config: DACVAEConfig) -> None:
        super().__init__()
        model = dacvae.DACVAE(
            encoder_dim=config.encoder_dim,
            encoder_rates=config.encoder_rates,
            latent_dim=config.latent_dim,
            decoder_dim=config.decoder_dim,
            decoder_rates=config.decoder_rates,
            n_codebooks=config.n_codebooks,
            codebook_size=config.codebook_size,
            codebook_dim=config.codebook_dim,
            quantizer_dropout=config.quantizer_dropout,
            sample_rate=config.sample_rate,
        ).eval()
        self._setup_model(model)
        self.hop_length = config.hop_length
        self.sample_rate = config.sample_rate

    def _setup_model(self, model):
        self.encoder = model.encoder
        self.quantizer = model.quantizer

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            z = self.encoder(self._pad(waveform))
            mean, scale = self.quantizer.in_proj(z).chunk(2, dim=1)
            encoded_frames = mean
        return encoded_frames

    def _pad(self, wavs):
        length = wavs.size(-1)
        if length % self.hop_length:
            p1d = (0, self.hop_length - (length % self.hop_length))
            return torch.nn.functional.pad(wavs, p1d, "reflect")
        else:
            return wavs


class DACVAE(DACVAEEncoder, Codec):
    def _setup_model(self, model):
        super()._setup_model(model)
        self.decoder = model.decoder

    def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor:
        emb = self.quantizer.out_proj(encoded_frames)
        return self.decoder(emb)

    def feature_idx_to_wav_idx(self, feature_idx, sample_rate=None):
        if sample_rate is None:
            sample_rate = self.sample_rate
        orig_freq = sample_rate
        new_freq = self.sample_rate
        wav_chunklen = feature_idx * self.hop_length * (orig_freq / new_freq)
        return self.cast_to_int(wav_chunklen)

    def wav_idx_to_feature_idx(self, wav_idx, sample_rate=None):
        ceil = math.ceil
        if torch.is_tensor(wav_idx):
            ceil = torch.ceil
        if sample_rate is None:
            sample_rate = self.sample_rate
        orig_freq = sample_rate
        new_freq = self.sample_rate
        target_length = ceil(new_freq * wav_idx / orig_freq)
        res = ceil(target_length / self.hop_length)
        return self.cast_to_int(res)
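
The index conversions above boil down to the codec hop length, which is the product of the encoder rates (2 * 8 * 10 * 12 = 1920 samples per latent frame at 48 kHz, per the `DACVAEConfig` defaults). A quick arithmetic check at the native sample rate:

```python
import math

hop_length = 2 * 8 * 10 * 12  # 1920, matching DACVAEConfig.hop_length defaults

wav_idx = 48_000                               # one second of 48 kHz audio
feature_idx = math.ceil(wav_idx / hop_length)  # 25 latent frames
wav_back = feature_idx * hop_length            # 48000 samples
print(feature_idx, wav_back)
```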
sam_audio/model/config.py
ADDED
@@ -0,0 +1,251 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from typing import Optional, Tuple

import numpy as np
from core.audio_visual_encoder.config import TransformerConfig as PEAVTransformerConfig
from transformers import ModernBertConfig


class DACVAEConfig:
    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: list[int] = [2, 8, 10, 12],
        latent_dim: int = 1024,
        decoder_dim: int = 1536,
        decoder_rates: list[int] = [12, 10, 8, 2],
        n_codebooks: int = 16,
        codebook_size: int = 1024,
        codebook_dim: int = 128,
        quantizer_dropout: bool = False,
        sample_rate: int = 48_000,
        mean: float = 0.0,
        std: float = 1.0,
    ):
        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.latent_dim = latent_dim
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.sample_rate = sample_rate
        self.mean = mean
        self.std = std

    @property
    def hop_length(self):
        return int(np.prod(self.encoder_rates))


class TextEncoderConfig:
    def __init__(self, dim: int = 768):
        self.dim = dim


class T5EncoderConfig(TextEncoderConfig):
    def __init__(
        self,
        name: str = "t5-base",
        max_length: Optional[int] = 512,
        pad_mode: str = "longest",
        dim: int = 768,
    ):
        super().__init__(dim=dim)
        self.name = name
        self.max_length = max_length
        self.pad_mode = pad_mode


class VisionEncoderConfig:
    def __init__(self, dim: int = 1024, batch_size: int = 300):
        self.dim = dim
        self.batch_size = batch_size


class PerceptionEncoderConfig(VisionEncoderConfig):
    def __init__(
        self,
        dim: int = 1024,
        batch_size: int = 300,
        name: str = "PE-Core-L14-336",
        normalize_feature: bool = True,
        interpolation_mode: str = "BICUBIC",
        image_size: int = 336,
    ):
        super().__init__(dim=dim, batch_size=batch_size)
        self.name = name
        self.normalize_feature = normalize_feature
        self.interpolation_mode = interpolation_mode
        self.image_size = image_size


class TransformerConfig:
    def __init__(
        self,
        dim: int = 2048,
        n_heads: int = 16,
        n_layers: int = 16,
        dropout: float = 0.1,
        norm_eps: float = 1.0e-05,
        qk_norm: bool = True,
        fc_bias: bool = False,
        ffn_exp: int = 4,
        ffn_dim_multiplier: int = 1,
        multiple_of: int = 64,
        non_linearity: str = "swiglu",
        use_rope: bool = True,
        max_positions: int = 10000,
        frequency_embedding_dim: int = 256,
        timestep_non_linearity: str = "swiglu",
        t_block_non_linearity: str = "silu",
        t_block_bias: bool = True,
        context_dim: int = 2048,
        context_non_linearity: str = "swiglu",
        context_embedder_dropout: float = 0.0,
        context_norm: bool = False,
        out_channels: int = 256,
        in_channels: Optional[int] = None,
    ):
        self.dim = dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dropout = dropout
        self.norm_eps = norm_eps
        self.qk_norm = qk_norm
        self.fc_bias = fc_bias
        self.ffn_exp = ffn_exp
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.multiple_of = multiple_of
        self.non_linearity = non_linearity
        self.use_rope = use_rope
        self.max_positions = max_positions
        self.frequency_embedding_dim = frequency_embedding_dim
        self.timestep_non_linearity = timestep_non_linearity
        self.t_block_non_linearity = t_block_non_linearity
        self.t_block_bias = t_block_bias
        self.context_dim = context_dim
        self.context_non_linearity = context_non_linearity
        self.context_embedder_dropout = context_embedder_dropout
        self.context_norm = context_norm
        self.out_channels = out_channels
        self.in_channels = in_channels


class RankerConfig:
    kind: str


class ImageBindRankerConfig(RankerConfig):
    kind: str = "imagebind"

    def __init__(self, checkpoint: Optional[str] = None):
        self.checkpoint = checkpoint


class ClapRankerConfig(RankerConfig):
    kind: str = "clap"

    def __init__(self, checkpoint: Optional[str] = None):
        self.checkpoint = checkpoint


class JudgeRankerConfig(RankerConfig):
    kind: str = "judge"

    def __init__(self, checkpoint_or_model_id: str = "facebook/sam-audio-judge"):
        self.checkpoint_or_model_id = checkpoint_or_model_id


class SoundActivityRankerConfig(RankerConfig):
    kind: str = "sound_activity"

    def __init__(
        self,
        threshold_mode: str = "rel_to_max",
        sil_threshold: float = -40,
        metric: str = "iou",
    ):
        self.threshold_mode = threshold_mode
        self.sil_threshold = sil_threshold
        self.metric = metric


class EnsembleRankerConfig(RankerConfig):
    kind: str = "ensemble"

    def __init__(self, rankers: dict[str, Tuple[RankerConfig, float]]):
        self.rankers = rankers


def parse_ranker_config(config_dict: dict):
    kind = config_dict.pop("kind")
    match kind:
        case ImageBindRankerConfig.kind:
            return ImageBindRankerConfig(**config_dict)
        case ClapRankerConfig.kind:
            return ClapRankerConfig(**config_dict)
        case JudgeRankerConfig.kind:
            return JudgeRankerConfig(**config_dict)
        case SoundActivityRankerConfig.kind:
            return SoundActivityRankerConfig(**config_dict)
        case EnsembleRankerConfig.kind:
            return EnsembleRankerConfig(
                {
                    k: (parse_ranker_config(v), w)
                    for k, (v, w) in config_dict["rankers"].items()
                }
            )


class SAMAudioConfig:
    def __init__(
        self,
        in_channels: int = 768,
        audio_codec=None,
        text_encoder=None,
        vision_encoder=None,
        transformer=None,
        num_anchors: int = 3,
        anchor_embedding_dim: int = 128,
        visual_ranker=None,
        text_ranker=None,
        span_predictor: Optional[str] = "pe-a-frame-large",
    ):
        self.in_channels = in_channels
        self.audio_codec = DACVAEConfig(**(audio_codec or {}))
        self.text_encoder = T5EncoderConfig(**(text_encoder or {}))
        self.vision_encoder = PerceptionEncoderConfig(**(vision_encoder or {}))
        self.transformer = TransformerConfig(**(transformer or {}))
        self.num_anchors = num_anchors
        self.anchor_embedding_dim = anchor_embedding_dim
        self.visual_ranker = (
            None if visual_ranker is None else parse_ranker_config(visual_ranker)
        )
        self.text_ranker = (
            None if text_ranker is None else parse_ranker_config(text_ranker)
        )
        self.span_predictor = span_predictor


class SAMAudioJudgeConfig:
    def __init__(
        self,
        audio_codec: DACVAEConfig = None,
        transformer: PEAVTransformerConfig = None,
        text_model: ModernBertConfig = None,
        finetune_transformer: PEAVTransformerConfig = None,
        nth_text_layer: int = 22,
        bottleneck_dim: int = 256,
    ):
        self.audio_codec = DACVAEConfig(**(audio_codec or {}))
        self.transformer = PEAVTransformerConfig(**(transformer or {}))
        self.text_model = ModernBertConfig(**(text_model or {}))
        self.finetune_transformer = PEAVTransformerConfig(
            **(finetune_transformer or {})
        )
        self.nth_text_layer = nth_text_layer
        self.bottleneck_dim = bottleneck_dim
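
A sketch of how `parse_ranker_config` turns a plain dictionary into ranker configs, including the nested ensemble form (the weights here are illustrative):

```python
from sam_audio.model.config import parse_ranker_config

cfg = parse_ranker_config(
    {
        "kind": "ensemble",
        "rankers": {
            "clap": ({"kind": "clap"}, 0.5),
            "sound_activity": ({"kind": "sound_activity"}, 0.5),
        },
    }
)
print(type(cfg).__name__)                                   # EnsembleRankerConfig
print({k: type(v).__name__ for k, (v, _) in cfg.rankers.items()})
```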
sam_audio/model/judge.py
ADDED
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from dataclasses import dataclass
from typing import Optional

import torch
from core.audio_visual_encoder.transformer import BaseModelOutputWithPooling
from core.audio_visual_encoder.transformer import Transformer as PEAVTransformer
from transformers import AutoModel

from .base import BaseModel
from .codec import DACVAEEncoder
from .config import SAMAudioJudgeConfig


@dataclass
class SAMAudioJudgeOutput:
    r"""
    overall (torch.Tensor, optional): Overall score tensor of shape (batch_size, 1).
    recall (torch.Tensor, optional): Recall score tensor of shape (batch_size, 1).
    precision (torch.Tensor, optional): Precision score tensor of shape (batch_size, 1).
    faithfulness (torch.Tensor, optional): Faithfulness score tensor of shape (batch_size, 1).
    text_model_output (BaseModelOutputWithPooling): Output from the text model.
    audio_model_output (BaseModelOutputWithPooling): Output from the audio model.
    """

    overall: Optional[torch.Tensor] = None
    recall: Optional[torch.Tensor] = None
    precision: Optional[torch.Tensor] = None
    faithfulness: Optional[torch.Tensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    audio_model_output: BaseModelOutputWithPooling = None


class SAMAudioJudgeModel(BaseModel):
    config_cls = SAMAudioJudgeConfig
    revision = "sam_audio"

    def __init__(self, config: SAMAudioJudgeConfig):
        super().__init__()
        self.config = config
        self.data_proj = torch.nn.Linear(
            config.audio_codec.codebook_dim, config.transformer.hidden_size
        )
        self.audio_codec = DACVAEEncoder(config.audio_codec)
        self.transformer = PEAVTransformer(config.transformer)
        self.finetune_transformer = PEAVTransformer(config.finetune_transformer)
        self.text_model = AutoModel.from_config(config.text_model)
        self.cat_audio_proj = torch.nn.Linear(
            2 * config.transformer.hidden_size, config.bottleneck_dim
        )
        self.text_proj1 = torch.nn.Linear(
            in_features=config.text_model.hidden_size,
            out_features=config.transformer.hidden_size,
            bias=False,
        )
        self.text_proj2 = torch.nn.Linear(
            in_features=config.transformer.hidden_size,
            out_features=config.bottleneck_dim,
        )
        self.layer_norm = torch.nn.LayerNorm(config.bottleneck_dim)
        self.proj_audio_and_text = torch.nn.Linear(
            2 * config.bottleneck_dim, config.bottleneck_dim
        )
        self.finetune_data_proj = torch.nn.Linear(
            config.bottleneck_dim, config.finetune_transformer.hidden_size
        )
        self.head = torch.nn.Linear(
            config.finetune_transformer.hidden_size, 4, bias=False
        )
        self.mean = torch.nn.Parameter(torch.zeros(4, requires_grad=False))
        self.std = torch.nn.Parameter(torch.ones(4, requires_grad=False))

    def _get_text_output(self, input_ids, attention_mask):
        nth_layer = self.config.nth_text_layer
        output = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=nth_layer is not None,
        )
        if nth_layer is None:
            text_model_output = output.last_hidden_state
        else:
            text_model_output = output.hidden_states[nth_layer]

        return BaseModelOutputWithPooling(
            last_hidden_state=text_model_output, pooler_output=text_model_output[:, 0]
        )

    def forward(
        self,
        input_ids: torch.Tensor,  # tokenized text
        input_values: torch.Tensor,  # input audio waveform
        separated_values: torch.Tensor,  # separated audio waveform
        attention_mask: Optional[torch.Tensor] = None,  # text attention mask
        padding_mask: Optional[torch.Tensor] = None,  # audio padding mask
    ) -> SAMAudioJudgeOutput:
        text_features = self.text_proj1(
            self._get_text_output(input_ids, attention_mask).pooler_output
        )
        stacked_audios = torch.cat([input_values, separated_values], dim=0)
        stacked_codec_features = self.audio_codec(stacked_audios)
        feature_padding_mask = None
        if padding_mask is not None:
            feature_padding_mask = padding_mask[
                :, :: self.config.audio_codec.hop_length
            ]
        stacked_features = self.transformer(
            self.data_proj(stacked_codec_features.transpose(1, 2)),
            padding_mask=feature_padding_mask,
        )
        input_features, hyp_features = stacked_features.last_hidden_state.chunk(2, 0)
        audio_features = self.cat_audio_proj(
            torch.cat([hyp_features, input_features], dim=2)
        )
        expanded_text = (
            self.layer_norm(self.text_proj2(text_features))
            .unsqueeze(1)
            .expand_as(audio_features)
        )
        audio_and_text = self.proj_audio_and_text(
            torch.cat([audio_features, expanded_text], dim=2)
        )
        finetune_transformer_output = self.finetune_transformer(
            self.finetune_data_proj(audio_and_text), padding_mask=feature_padding_mask
        )
        result = self.head(finetune_transformer_output.last_hidden_state)
        if feature_padding_mask is not None:
            feature_padding_mask = feature_padding_mask.unsqueeze(-1)
        pooled = torch.masked.mean(result, mask=feature_padding_mask, dim=1)
        de_normalized = pooled * self.std + self.mean
        return SAMAudioJudgeOutput(*de_normalized.chunk(4, dim=1))


__all__ = ["SAMAudioJudgeModel", "SAMAudioJudgeOutput"]
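The judge head emits four per-frame scores that are mean-pooled over time, de-normalized with the stored `mean`/`std` parameters, and split into the four named fields of `SAMAudioJudgeOutput`. A small, self-contained sketch of that final step using dummy tensors (shapes and the unmasked pooling are illustrative stand-ins, not the model's exact call):

```python
import torch

batch, frames = 2, 10
per_frame_scores = torch.randn(batch, frames, 4)   # stand-in for self.head(...)
mean, std = torch.zeros(4), torch.ones(4)          # stand-ins for self.mean / self.std

pooled = per_frame_scores.mean(dim=1)              # unmasked variant of the pooling above
de_normalized = pooled * std + mean
overall, recall, precision, faithfulness = de_normalized.chunk(4, dim=1)
print(overall.shape)                               # torch.Size([2, 1])
```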
sam_audio/model/model.py
ADDED
@@ -0,0 +1,362 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
import re
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
from torchdiffeq import odeint

from sam_audio.model.align import AlignModalities
from sam_audio.model.base import BaseModel
from sam_audio.model.codec import DACVAE
from sam_audio.model.config import SAMAudioConfig
from sam_audio.model.text_encoder import T5TextEncoder
from sam_audio.model.transformer import DiT
from sam_audio.model.vision_encoder import PerceptionEncoder
from sam_audio.processor import Batch
from sam_audio.ranking import create_ranker

DFLT_ODE_OPT = {"method": "midpoint", "options": {"step_size": 2 / 32}}


class SinusoidalEmbedding(torch.nn.Module):
    def __init__(self, dim, theta=10000):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        inv_freq = torch.exp(
            -math.log(theta) * torch.arange(half_dim).float() / half_dim
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, pos=None):
        if pos is None:
            seq_len, device = x.shape[1], x.device
            pos = torch.arange(seq_len, device=device)

        emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
        emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
        return emb


class EmbedAnchors(torch.nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
        super().__init__()
        self.embed = torch.nn.Embedding(
            num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
        )
        self.gate = torch.nn.Parameter(torch.tensor([0.0]))
        self.proj = torch.nn.Linear(embedding_dim, out_dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        anchor_ids: Optional[torch.Tensor] = None,
        anchor_alignment: Optional[torch.Tensor] = None,
    ):
        if anchor_ids is None:
            return x

        embs = self.embed(anchor_ids.gather(1, anchor_alignment))
        proj = self.proj(embs)
        return x + self.gate.tanh() * proj


@dataclass
class SeparationResult:
    target: torch.Tensor
    residual: torch.Tensor
    noise: torch.Tensor


class SAMAudio(BaseModel):
    config_cls = SAMAudioConfig
    revision = None

    def __init__(self, cfg: SAMAudioConfig):
        super().__init__()
        self.audio_codec = DACVAE(cfg.audio_codec)
        self.text_encoder = T5TextEncoder(cfg.text_encoder)
        self.vision_encoder = PerceptionEncoder(cfg.vision_encoder)
        self.transformer = DiT(cfg.transformer)
        self.proj = torch.nn.Linear(cfg.in_channels, cfg.transformer.dim)
        self.align_masked_video = AlignModalities(
            cfg.vision_encoder.dim, cfg.transformer.dim
        )
        self.embed_anchors = EmbedAnchors(
            cfg.num_anchors, cfg.anchor_embedding_dim, cfg.transformer.dim
        )
        self.memory_proj = torch.nn.Linear(cfg.text_encoder.dim, cfg.transformer.dim)
        self.timestep_emb = SinusoidalEmbedding(cfg.transformer.dim)
        self.visual_ranker = create_ranker(cfg.visual_ranker)
        self.text_ranker = create_ranker(cfg.text_ranker)
        if cfg.span_predictor is not None:
            self.span_predictor = PEAudioFrame.from_config(
                cfg.span_predictor, pretrained=True
            )
            self.span_predictor_transform = PEAudioFrameTransform.from_config(
                cfg.span_predictor
            )

    @property
    def sample_rate(self):
        return self.audio_codec.sample_rate

    def align_inputs(
        self,
        noisy_audio,
        audio_features: torch.Tensor,
        masked_video_features: Optional[torch.Tensor] = None,
        anchor_ids: Optional[torch.Tensor] = None,
        anchor_alignment: Optional[torch.Tensor] = None,
    ):
        x = torch.cat(
            [
                noisy_audio,
                torch.zeros_like(audio_features),
                audio_features,
            ],
            dim=2,
        )

        projected = self.proj(x)
        aligned = self.align_masked_video(projected, masked_video_features)
        aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)
        return aligned

    def forward(
        self,
        noisy_audio: torch.Tensor,
        audio_features: torch.Tensor,
        text_features: torch.Tensor,
        time: torch.Tensor,
        masked_video_features: Optional[torch.Tensor] = None,
        text_mask: Optional[torch.Tensor] = None,
        anchor_ids: Optional[torch.Tensor] = None,
        anchor_alignment: Optional[torch.Tensor] = None,
        audio_pad_mask: Optional[torch.Tensor] = None,
    ):
        """
        Forward pass for the model. Represents one function evaluation of the ODE.
        In the descriptions below, B is batch size, T is sequence length, and C is channel size.
        Note that the sizes of C and T may vary across arguments (e.g. text_features vs. audio_features);
        they only designate a channel or time/sequence-length dimension, respectively.

        Args:
            noisy_audio (torch.Tensor): Noisy audio input tensor (being denoised).
            audio_features (torch.Tensor): Clean audio features [B x T x C].
            text_features (torch.Tensor): Encoded text features tensor [B x T x C].
            time (torch.Tensor): Timestep tensor for positional encoding [B].
            masked_video_features (Optional[torch.Tensor], optional): Masked video features tensor [B x C x T].
            text_mask (Optional[torch.Tensor], optional): Padding mask for text features [B x T].
            anchor_ids (Optional[torch.Tensor], optional): Anchor IDs tensor. Defaults to None [B x T].
            anchor_alignment (Optional[torch.Tensor], optional): Anchor alignment tensor [B x T].
            audio_pad_mask (Optional[torch.Tensor], optional): Padding mask for audio input [B x T].

        Returns:
            torch.Tensor
        """
        aligned_inputs = self.align_inputs(
            noisy_audio,
            audio_features,
            masked_video_features=masked_video_features,
            anchor_ids=anchor_ids,
            anchor_alignment=anchor_alignment,
        )

        memory = timestep_emb = self.timestep_emb(time, pos=time).unsqueeze(1)
        if text_features is not None:
            memory = self.memory_proj(text_features) + timestep_emb

        return self.transformer(
            aligned_inputs,
            time,
            padding_mask=audio_pad_mask,
            memory=memory,
            memory_padding_mask=text_mask,
        )

    def _get_audio_features(self, audios: torch.Tensor):
        audio_features = self.audio_codec(audios).transpose(1, 2)
        return torch.cat([audio_features, audio_features], dim=2)

    def _get_video_features(self, video, audio_features):
        B, T, _ = audio_features.shape
        if video is None:
            return audio_features.new_zeros(B, self.vision_encoder.dim, T)
        else:
            return self.vision_encoder(video).transpose(1, 2)

    def _repeat_for_reranking(self, tensor, candidates):
        if candidates > 1:
            B = tensor.size(0)
            rest = tensor.shape[1:]
            return (
                tensor.unsqueeze(1)
                .expand(B, candidates, *rest)
                .reshape(B * candidates, *rest)
            )
        else:
            return tensor

    def _unrepeat_from_reranking(self, tensor, candidates):
        return tensor[::candidates]

    def _get_forward_args(self, batch: Batch, candidates: int = 1):
        audio_features = self._get_audio_features(batch.audios)
        text_features, text_mask = self.text_encoder(batch.descriptions)
        masked_video_features = self._get_video_features(
            batch.masked_video, audio_features
        )

        return {
            "audio_features": self._repeat_for_reranking(audio_features, candidates),
            "text_features": self._repeat_for_reranking(text_features, candidates),
            "text_mask": self._repeat_for_reranking(text_mask, candidates),
            "masked_video_features": self._repeat_for_reranking(
                masked_video_features, candidates
            ),
            "anchor_ids": self._repeat_for_reranking(batch.anchor_ids, candidates),
            "anchor_alignment": self._repeat_for_reranking(
                batch.anchor_alignment, candidates
            ),
            "audio_pad_mask": self._repeat_for_reranking(
                batch.audio_pad_mask, candidates
            ),
        }

    def predict_spans(
        self, batch: Batch, audio_features: torch.Tensor, audio_pad_mask: torch.Tensor
    ) -> Batch:
        input = self.span_predictor_transform(text=batch.descriptions).to(
            audio_features.device
        )
        output = self.span_predictor(
            input_features=audio_features[:, :, :128],
            padding_mask=audio_pad_mask,
            return_spans=True,
            **input,
        )
        anchors = [[["+"] + anchor for anchor in anchors] for anchors in output.spans]
        batch.process_anchors(anchors)
        return batch

    @torch.inference_mode()
    def separate(
        self,
        batch: Batch,
        noise: Optional[torch.Tensor] = None,
        ode_opt: Dict[str, Any] = DFLT_ODE_OPT,
        reranking_candidates: int = 1,
        predict_spans: bool = False,
    ) -> SeparationResult:
        # Encode audio
        forward_args = self._get_forward_args(batch, candidates=reranking_candidates)

        if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
            batch = self.predict_spans(
                batch=batch,
                audio_features=self._unrepeat_from_reranking(
                    forward_args["audio_features"], reranking_candidates
                ),
                audio_pad_mask=self._unrepeat_from_reranking(
                    forward_args["audio_pad_mask"], reranking_candidates
                ),
            )

        audio_features = forward_args["audio_features"]
        B, T, C = audio_features.shape
        C = C // 2  # we stack audio_features, so the actual channels is half

        if noise is None:
            noise = torch.randn_like(audio_features)

        def vector_field(t, noisy_audio):
            res = self.forward(
                noisy_audio=noisy_audio,
                time=t.expand(noisy_audio.size(0)),
                **forward_args,
            )
            return res

        states = odeint(
            vector_field,
            noise,
            torch.tensor([0.0, 1.0], device=noise.device),
            **ode_opt,
        )
        generated_features = states[-1].transpose(1, 2)
        # generated_features has shape [B, 2C, T]. Reshape to stack along the batch dimension
        wavs = self.audio_codec.decode(generated_features.reshape(2 * B, C, T)).view(
            B, 2, -1
        )

        bsz = wavs.size(0) // reranking_candidates
        sizes = self.audio_codec.feature_idx_to_wav_idx(batch.sizes)
        target_wavs = self.unbatch(
            wavs[:, 0].view(bsz, reranking_candidates, -1), sizes
        )
        residual_wavs = self.unbatch(
            wavs[:, 1].view(bsz, reranking_candidates, -1), sizes
        )

        if (
            reranking_candidates > 1
            and batch.masked_video is not None
            and self.visual_ranker is not None
        ):
            scores = self.visual_ranker(
                extracted_audio=target_wavs,
                videos=batch.masked_video,
                sample_rate=self.audio_codec.sample_rate,
            )
            idxs = scores.argmax(dim=1)
        elif reranking_candidates > 1 and self.text_ranker is not None:
            input_audio = [
                audio[:, :size].expand(reranking_candidates, -1)
                for audio, size in zip(batch.audios, sizes, strict=False)
            ]
            scores = self.text_ranker(
                extracted_audio=target_wavs,
                input_audio=input_audio,
                descriptions=batch.descriptions,
                sample_rate=self.audio_codec.sample_rate,
            )
            idxs = scores.argmax(dim=1)
        else:
            idxs = torch.zeros(bsz, dtype=torch.long, device=noise.device)

        return SeparationResult(
            target=[wav[idx] for wav, idx in zip(target_wavs, idxs, strict=False)],
            residual=[
                wavs[idx] for wavs, idx in zip(residual_wavs, idxs, strict=False)
            ],
            noise=noise,
        )

    def unbatch(self, wavs: torch.Tensor, sizes: torch.Tensor, time_dim: int = -1):
        result = []
        for row, size in zip(wavs, sizes, strict=False):
            result.append(row.narrow(dim=time_dim, start=0, length=size))
        return result

    def load_state_dict(self, state_dict, strict=True):
        if strict:
            missing_keys, unexpected_keys = super().load_state_dict(
                state_dict, strict=False
            )
            # We load this directly from HF, not in checkpoint
            skip_regex = re.compile(
                "(^text_encoder|^visual_ranker|^text_ranker|^span_predictor)"
            )
            missing_keys = [x for x in missing_keys if not re.search(skip_regex, x)]
            if len(missing_keys) > 0 or len(unexpected_keys) > 0:
                raise RuntimeError(
                    f"Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}"
                )


__all__ = ["SAMAudio"]
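For best-of-N reranking, `separate` tiles every conditioning tensor so that each input appears `reranking_candidates` times along the batch dimension, and `_unrepeat_from_reranking` later recovers one row per original input. A small stand-alone sketch of the same expand/reshape pattern with arbitrary sizes:

```python
import torch

B, T, C, candidates = 2, 5, 3, 4
x = torch.randn(B, T, C)

# Same pattern as _repeat_for_reranking: B x ... -> (B * candidates) x ...
repeated = (
    x.unsqueeze(1)
    .expand(B, candidates, T, C)
    .reshape(B * candidates, T, C)
)
assert repeated.shape == (B * candidates, T, C)

# Same pattern as _unrepeat_from_reranking: keep one row per original input.
assert torch.equal(repeated[::candidates], x)
```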
sam_audio/model/patcher.py
ADDED
@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
from typing import Tuple

import torch
import torch.nn.functional as F
from einops import rearrange


def pad1d(
    x: torch.Tensor,
    paddings: Tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
):
    # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happen.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


class Conv1d(torch.nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        kernel_size = self.kernel_size[0]
        stride = self.stride[0]
        dilation = self.dilation[0]
        kernel_size = (
            kernel_size - 1
        ) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(
            x, kernel_size, stride, padding_total
        )
        # Asymmetric padding required for odd strides
        padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        x = pad1d(x, (padding_left, padding_right + extra_padding))
        return super().forward(x)


class ConvBlock1d(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        num_groups: int = 8,
    ) -> None:
        super().__init__()

        self.groupnorm = torch.nn.GroupNorm(
            num_groups=num_groups, num_channels=in_channels
        )
        self.activation = torch.nn.SiLU()
        self.project = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        x = self.groupnorm(x)
        x = self.activation(x)
        return self.project(x)


class ResnetBlock1d(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        num_groups: int = 8,
    ) -> None:
        super().__init__()

        self.block1 = ConvBlock1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            num_groups=num_groups,
        )

        self.block2 = ConvBlock1d(
            in_channels=out_channels,
            out_channels=out_channels,
            num_groups=num_groups,
        )

        self.to_out = (
            Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            if in_channels != out_channels
            else torch.nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.block1(x)
        h = self.block2(h)
        return h + self.to_out(x)


class Patcher(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
    ):
        super().__init__()
        assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
        assert out_channels % patch_size == 0, assert_message
        self.patch_size = patch_size
        self.block = ResnetBlock1d(
            in_channels=in_channels,
            out_channels=out_channels // patch_size,
            num_groups=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block(x)
        x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
        return x
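`Patcher` trades time resolution for channels: the ResNet block maps the input to `out_channels // patch_size` channels, and the `rearrange` then folds every `patch_size` time steps into the channel dimension. A shape-only sketch with arbitrary sizes, assuming the module above is importable:

```python
import torch
from sam_audio.model.patcher import Patcher

patcher = Patcher(in_channels=8, out_channels=16, patch_size=4)
x = torch.randn(2, 8, 32)   # (batch, in_channels, length)
y = patcher(x)
print(y.shape)              # torch.Size([2, 16, 8]); length shrinks by patch_size
```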
sam_audio/model/rope.py
ADDED
@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
from typing import Tuple

import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        seq_dim (int): Sequence dimension index.

    Returns:
        torch.Tensor: Reshaped frequency tensor.
    """
    ndim = x.ndim
    assert 0 <= seq_dim < ndim
    assert freqs_cis.shape == (
        x.shape[seq_dim],
        x.shape[-3],
        2,
        2,
    ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
    shape = [
        d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
    ] + [2, 2]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)  # B S H D -> B S H D/2 1 2
    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)  # B S H D -> B S H D/2 1 2
    freqs_cis = reshape_for_broadcast(
        freqs_cis, xq_, seq_dim
    ).float()  # S D/2 2 2 -> 1 S 1 D/2 2 2
    xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
    xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RotaryEmbedding(torch.nn.Module):
    """
    RotaryEmbedding Module
    """

    def __init__(
        self,
        theta: float,
        head_dim: int,
        max_seqlen: int = 1024,
        scale_factor: int = 1,
        low_freq_factor: int = 1,
        high_freq_factor: int = 32,
        old_context_len: int = 8192,
    ):
        super().__init__()

        self.theta = theta
        self.head_dim = head_dim
        self.max_seqlen = max_seqlen
        self.scale_factor = scale_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.old_context_len = old_context_len
        if scale_factor != 1:
            self.low_freq_wavelen = old_context_len / low_freq_factor
            self.high_freq_wavelen = old_context_len / high_freq_factor
            assert self.low_freq_wavelen >= self.high_freq_wavelen

    def reset_parameters(self):
        freqs_cis = self.precompute_freqs_cis(
            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
        )
        S, D, _, _ = freqs_cis.shape
        # S D 2 2 -> 1 S 1 D 2 2
        freqs_cis = freqs_cis.view(1, S, 1, D, 2, 2)
        self.register_buffer(
            "freqs_cis",
            freqs_cis,
            persistent=False,
        )

    def apply_scaling(self, freqs):
        if self.scale_factor == 1:
            return freqs
        new_freqs = []
        for freq in freqs:
            wavelen = 2 * math.pi / freq
            if wavelen < self.high_freq_wavelen:
                new_freqs.append(freq)
            elif wavelen > self.low_freq_wavelen:
                new_freqs.append(freq / self.scale_factor)
            else:
                assert self.low_freq_wavelen != self.high_freq_wavelen
                smooth = (self.old_context_len / wavelen - self.low_freq_factor) / (
                    self.high_freq_factor - self.low_freq_factor
                )
                new_freqs.append(
                    (1 - smooth) * freq / self.scale_factor + smooth * freq
                )
        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

    def precompute_freqs_cis(
        self,
        dim: int,
        end: int,
        theta: float = 10000.0,
    ):
        """
        Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

        This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
        and the end index 'end'. The 'theta' parameter scales the frequencies.
        The returned tensor contains complex values in complex64 data type.

        Args:
            dim (int): Dimension of the frequency tensor.
            end (int): End index for precomputing frequencies.
            theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

        Returns:
            torch.Tensor: Precomputed frequency tensor with complex exponentials.
        """
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        freqs = self.apply_scaling(freqs)

        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()

        cos, sin = freqs.cos(), freqs.sin()

        return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)

    def forward(self, x: torch.Tensor, bhle: bool = False, **kwargs):
        if bhle:
            x = x.transpose(1, 2)  # (B H L E) -> (B L H E)
        seqlen = x.size(1)
        x_ = x.reshape(*x.shape[:-1], -1, 1, 2)  # B L H E -> B L H E/2 1 2
        x_out = (x_ * self.freqs_cis[:, :seqlen]).sum(5).flatten(3)
        if bhle:
            x_out = x_out.transpose(1, 2)
        return x_out.type_as(x)
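`reset_parameters` precomputes a `(1, S, 1, D/2, 2, 2)` rotation table, and `forward` applies it as a batched 2x2 rotation, preserving the input shape. A quick shape check with illustrative sizes, assuming the module above is importable:

```python
import torch
from sam_audio.model.rope import RotaryEmbedding

rope = RotaryEmbedding(theta=10000.0, head_dim=64, max_seqlen=1024)
rope.reset_parameters()         # builds the freqs_cis buffer

q = torch.randn(2, 16, 8, 64)   # (batch, seq, heads, head_dim)
q_rot = rope(q)
print(q_rot.shape)              # torch.Size([2, 16, 8, 64]); shape is preserved
```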
sam_audio/model/text_encoder.py
ADDED
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from typing import Tuple

import torch
import transformers

from sam_audio.model.config import T5EncoderConfig


class T5TextEncoder(torch.nn.Module):
    def __init__(self, cfg: T5EncoderConfig):
        super().__init__()
        self.model = transformers.T5EncoderModel.from_pretrained(cfg.name)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(cfg.name)
        self.pad_mode = cfg.pad_mode
        self.max_length = cfg.max_length

    def forward(self, texts: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        device = next(self.model.parameters()).device
        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding=self.pad_mode,
            return_tensors="pt",
        )

        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)
        res = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )["last_hidden_state"]

        return res, attention_mask.bool()
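A minimal usage sketch, assuming `T5EncoderConfig`'s defaults point to a valid T5 checkpoint that Hugging Face can download; the encoder returns padded last-layer hidden states plus a boolean attention mask:

```python
from sam_audio.model.config import T5EncoderConfig
from sam_audio.model.text_encoder import T5TextEncoder

encoder = T5TextEncoder(T5EncoderConfig())   # loads the configured T5 checkpoint
features, mask = encoder(["a dog barking", "rain on a tin roof"])
print(features.shape, mask.shape)            # (2, T, hidden_dim), (2, T)
```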
sam_audio/model/transformer.py
ADDED
@@ -0,0 +1,524 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from functools import partial
|
| 5 |
+
from typing import List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
|
| 12 |
+
from .config import TransformerConfig
|
| 13 |
+
from .patcher import Patcher
|
| 14 |
+
from .rope import RotaryEmbedding
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def gate(x, gate):
|
| 18 |
+
return x * gate
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def modulate(x, shift, scale):
|
| 22 |
+
return x * (1 + scale) + shift
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_nonlinearity(kind: str):
|
| 26 |
+
return {
|
| 27 |
+
"relu": F.relu,
|
| 28 |
+
"gelu": F.gelu,
|
| 29 |
+
"swiglu": None,
|
| 30 |
+
"approx_gelu": partial(F.gelu, approximate="tanh"),
|
| 31 |
+
"srelu": lambda x: F.relu(x) ** 2,
|
| 32 |
+
"silu": F.silu,
|
| 33 |
+
}[kind]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RMSNorm(torch.nn.Module):
|
| 37 |
+
def __init__(self, dim: int, eps: float = 1e-5):
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.eps = eps
|
| 40 |
+
self.weight = torch.nn.Parameter(torch.ones(dim))
|
| 41 |
+
|
| 42 |
+
def _norm(self, x):
|
| 43 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 44 |
+
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
output = self._norm(x.float())
|
| 47 |
+
return (output * self.weight).type_as(x)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ProjectionLayer(torch.nn.Module):
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
in_dim: int,
|
| 54 |
+
out_dim: int,
|
| 55 |
+
non_linearity: str,
|
| 56 |
+
dropout: float,
|
| 57 |
+
fc_bias: bool = False,
|
| 58 |
+
):
|
| 59 |
+
super().__init__()
|
| 60 |
+
|
| 61 |
+
self.swiglu = non_linearity == "swiglu"
|
| 62 |
+
self.dropout = dropout
|
| 63 |
+
self.w1 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
|
| 64 |
+
|
| 65 |
+
self.w2 = torch.nn.Linear(out_dim, out_dim, bias=fc_bias)
|
| 66 |
+
if self.swiglu:
|
| 67 |
+
self.w3 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
|
| 68 |
+
|
| 69 |
+
# non-linearity
|
| 70 |
+
self.non_linearity = get_nonlinearity(non_linearity)
|
| 71 |
+
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
hidden1 = self.w1(x)
|
| 74 |
+
if self.swiglu:
|
| 75 |
+
hidden3 = self.w3(x)
|
| 76 |
+
hidden = F.silu(hidden1) * hidden3
|
| 77 |
+
else:
|
| 78 |
+
hidden = self.non_linearity(hidden1)
|
| 79 |
+
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
| 80 |
+
return self.w2(hidden)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Attention(nn.Module):
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
dim: int,
|
| 87 |
+
head_dim: int,
|
| 88 |
+
n_heads: int,
|
| 89 |
+
n_kv_heads: int,
|
| 90 |
+
norm_eps: float = 1e-5,
|
| 91 |
+
use_qk_norm: bool = False,
|
| 92 |
+
fc_bias: bool = False,
|
| 93 |
+
):
|
| 94 |
+
super().__init__()
|
| 95 |
+
assert n_heads % n_kv_heads == 0
|
| 96 |
+
|
| 97 |
+
self.head_dim = head_dim
|
| 98 |
+
self.n_heads = n_heads
|
| 99 |
+
self.n_kv_heads = n_kv_heads
|
| 100 |
+
self.use_qk_norm = use_qk_norm
|
| 101 |
+
|
| 102 |
+
self.wq = torch.nn.Linear(dim, n_heads * head_dim, bias=fc_bias)
|
| 103 |
+
self.wk, self.wv = [
|
| 104 |
+
torch.nn.Linear(
|
| 105 |
+
dim,
|
| 106 |
+
n_kv_heads * head_dim,
|
| 107 |
+
bias=fc_bias,
|
| 108 |
+
)
|
| 109 |
+
for _ in range(2)
|
| 110 |
+
]
|
| 111 |
+
self.wo = torch.nn.Linear(
|
| 112 |
+
n_heads * head_dim,
|
| 113 |
+
dim,
|
| 114 |
+
bias=fc_bias,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if self.use_qk_norm is True:
|
| 118 |
+
self.q_norm = RMSNorm(head_dim, eps=norm_eps)
|
| 119 |
+
self.k_norm = RMSNorm(head_dim, eps=norm_eps)
|
| 120 |
+
|
| 121 |
+
def reshape_heads(self, x: torch.Tensor, heads: int) -> torch.Tensor:
|
| 122 |
+
B, T, C = x.shape
|
| 123 |
+
# B x T x C -> B x T x C/H x H
|
| 124 |
+
x = x.reshape(B, T, C // heads, heads)
|
| 125 |
+
# B x T x C/H x H -> B x H x T x C/H
|
| 126 |
+
return x.permute(0, 3, 1, 2)
|
| 127 |
+
|
| 128 |
+
def forward(
|
| 129 |
+
self,
|
| 130 |
+
x: torch.Tensor,
|
| 131 |
+
cross_x: Optional[torch.Tensor] = None,
|
| 132 |
+
key_padding_mask: Optional[torch.Tensor] = None,
|
| 133 |
+
rope: Optional[RotaryEmbedding] = None,
|
| 134 |
+
):
|
| 135 |
+
# x: B, T, E
|
| 136 |
+
xq = self.wq(x)
|
| 137 |
+
if cross_x is not None:
|
| 138 |
+
xk, xv = self.wk(cross_x), self.wv(cross_x)
|
| 139 |
+
else:
|
| 140 |
+
xk, xv = self.wk(x), self.wv(x)
|
| 141 |
+
|
| 142 |
+
xk = self.reshape_heads(xk, self.n_kv_heads)
|
| 143 |
+
xv = self.reshape_heads(xv, self.n_kv_heads)
|
| 144 |
+
xq = self.reshape_heads(xq, self.n_heads)
|
| 145 |
+
if self.use_qk_norm:
|
| 146 |
+
xq = self.q_norm(xq)
|
| 147 |
+
xk = self.k_norm(xk)
|
| 148 |
+
|
| 149 |
+
if rope is not None:
|
| 150 |
+
xq = rope(xq, bhle=True)
|
| 151 |
+
xk = rope(xk, bhle=True)
|
| 152 |
+
|
| 153 |
+
attn_mask = None
|
| 154 |
+
|
| 155 |
+
if key_padding_mask is not None:
|
| 156 |
+
attn_mask = key_padding_mask[:, None, None, :]
|
| 157 |
+
|
| 158 |
+
output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask)
|
| 159 |
+
|
| 160 |
+
output = rearrange(output, "b h n d -> b n (h d)")
|
| 161 |
+
return self.wo(output)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class FeedForward(torch.nn.Module):
|
| 165 |
+
def __init__(
|
| 166 |
+
self,
|
| 167 |
+
dim: int,
|
| 168 |
+
hidden_dim: int,
|
| 169 |
+
ffn_dim_multiplier: float,
|
| 170 |
+
multiple_of: int,
|
| 171 |
+
dropout: float,
|
| 172 |
+
non_linearity: str = "swiglu",
|
| 173 |
+
fc_bias: bool = False,
|
| 174 |
+
):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.dropout = dropout
|
| 177 |
+
self.swiglu = non_linearity == "swiglu"
|
| 178 |
+
# swiglu hidden dim factor multiplier (same #params as relu / gelu)
|
| 179 |
+
if self.swiglu:
|
| 180 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
| 181 |
+
|
| 182 |
+
# custom dim factor multiplier
|
| 183 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
| 184 |
+
# round hidden dimension to `multiple_of`
|
| 185 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
| 186 |
+
# layers
|
| 187 |
+
self.w1 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
|
| 188 |
+
self.w2 = torch.nn.Linear(hidden_dim, dim, bias=fc_bias)
|
| 189 |
+
if self.swiglu:
|
| 190 |
+
self.w3 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
|
| 191 |
+
|
| 192 |
+
# non-linearity
|
| 193 |
+
self.non_linearity = get_nonlinearity(non_linearity)
|
| 194 |
+
|
| 195 |
+
def forward(
|
| 196 |
+
self,
|
| 197 |
+
x,
|
| 198 |
+
):
|
| 199 |
+
hidden1 = self.w1(x)
|
| 200 |
+
if self.swiglu:
|
| 201 |
+
hidden3 = self.w3(x)
|
| 202 |
+
hidden = F.silu(hidden1) * hidden3
|
| 203 |
+
else:
|
| 204 |
+
hidden = self.non_linearity(hidden1)
|
| 205 |
+
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
| 206 |
+
return self.w2(hidden)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class TimestepEmbedder(torch.nn.Module):
|
| 210 |
+
def __init__(
|
| 211 |
+
self,
|
| 212 |
+
dim: int,
|
| 213 |
+
frequency_embedding_dim: int,
|
| 214 |
+
non_linearity: str,
|
| 215 |
+
dropout: float,
|
| 216 |
+
fc_bias: bool,
|
| 217 |
+
max_period: int = 10000,
|
| 218 |
+
):
|
| 219 |
+
super().__init__()
|
| 220 |
+
self.frequency_embedding_size = frequency_embedding_dim
|
| 221 |
+
self.projection = ProjectionLayer(
|
| 222 |
+
in_dim=frequency_embedding_dim,
|
| 223 |
+
out_dim=dim,
|
| 224 |
+
non_linearity=non_linearity,
|
| 225 |
+
dropout=dropout,
|
| 226 |
+
fc_bias=fc_bias,
|
| 227 |
+
)
|
| 228 |
+
half = frequency_embedding_dim // 2
|
| 229 |
+
freqs = torch.exp(
|
| 230 |
+
-math.log(max_period)
|
| 231 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
| 232 |
+
/ half
|
| 233 |
+
)
|
| 234 |
+
self.register_buffer("freqs", freqs, persistent=False)
|
| 235 |
+
|
| 236 |
+
def timestep_embedding(self, t, dim):
|
| 237 |
+
"""
|
| 238 |
+
Create sinusoidal timestep embeddings.
|
| 239 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 240 |
+
These may be fractional.
|
| 241 |
+
:param dim: the dimension of the output.
|
| 242 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 243 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 244 |
+
"""
|
| 245 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 246 |
+
self.freqs = self.freqs.to(device=t.device)
|
| 247 |
+
args = t[:, None].float() * self.freqs[None]
|
| 248 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 249 |
+
if dim % 2:
|
| 250 |
+
embedding = torch.cat(
|
| 251 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
| 252 |
+
)
|
| 253 |
+
return embedding.to(t)
|
| 254 |
+
|
| 255 |
+
def forward(self, t):
|
| 256 |
+
x = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 257 |
+
return self.projection(x)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class ContextEmbedder(torch.nn.Module):
|
| 261 |
+
def __init__(
|
| 262 |
+
self,
|
| 263 |
+
in_dim: int,
|
| 264 |
+
out_dim: int,
|
| 265 |
+
non_linearity: str,
|
| 266 |
+
dropout: float,
|
| 267 |
+
fc_bias: bool,
|
| 268 |
+
norm_eps: float = 1e-5,
|
| 269 |
+
context_norm: bool = False,
|
| 270 |
+
):
|
| 271 |
+
super().__init__()
|
| 272 |
+
self.context_norm = context_norm
|
| 273 |
+
if context_norm:
|
| 274 |
+
self.norm = RMSNorm(in_dim, norm_eps)
|
| 275 |
+
|
| 276 |
+
self.projection = ProjectionLayer(
|
| 277 |
+
in_dim=in_dim,
|
| 278 |
+
out_dim=out_dim,
|
| 279 |
+
non_linearity=non_linearity,
|
| 280 |
+
dropout=dropout,
|
| 281 |
+
fc_bias=fc_bias,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
def forward(self, x):
|
| 285 |
+
if self.context_norm:
|
| 286 |
+
x = self.norm(x)
|
| 287 |
+
h = self.projection(x)
|
| 288 |
+
return h
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class DiTBlock(torch.nn.Module):
|
| 292 |
+
def __init__(
|
| 293 |
+
self,
|
| 294 |
+
dim: int,
|
| 295 |
+
n_heads: int,
|
| 296 |
+
n_kv_heads: Optional[int] = None,
|
| 297 |
+
dropout: float = 0.0,
|
| 298 |
+
norm_eps: float = 1e-5,
|
| 299 |
+
qk_norm: bool = False,
|
| 300 |
+
fc_bias: bool = False,
|
| 301 |
+
ffn_exp: int = 1,
|
| 302 |
+
ffn_dim_multiplier: int = 4,
|
| 303 |
+
multiple_of: int = 64,
|
| 304 |
+
non_linearity: str = "silu",
|
| 305 |
+
no_cross_attention: bool = False,
|
| 306 |
+
):
|
| 307 |
+
super().__init__()
|
| 308 |
+
assert dim % n_heads == 0
|
| 309 |
+
self.n_heads = n_heads
|
| 310 |
+
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
|
| 311 |
+
self.dim = dim
|
| 312 |
+
self.dropout = dropout
|
| 313 |
+
self.head_dim = dim // n_heads
|
| 314 |
+
|
| 315 |
+
assert self.n_heads % self.n_kv_heads == 0
|
| 316 |
+
|
| 317 |
+
self.attention = Attention(
|
| 318 |
+
dim=dim,
|
| 319 |
+
head_dim=self.head_dim,
|
| 320 |
+
n_heads=self.n_heads,
|
| 321 |
+
n_kv_heads=self.n_kv_heads,
|
| 322 |
+
norm_eps=norm_eps,
|
| 323 |
+
use_qk_norm=qk_norm,
|
| 324 |
+
fc_bias=fc_bias,
|
| 325 |
+
)
|
| 326 |
+
self.feed_forward = FeedForward(
|
| 327 |
+
dim=dim,
|
| 328 |
+
hidden_dim=int(ffn_exp * dim),
|
| 329 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
| 330 |
+
multiple_of=multiple_of,
|
| 331 |
+
dropout=dropout,
|
| 332 |
+
non_linearity=non_linearity,
|
| 333 |
+
fc_bias=fc_bias,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
self.attention_norm, self.ffn_norm = [RMSNorm(dim, norm_eps) for _ in range(2)]
|
| 337 |
+
|
| 338 |
+
self.cross_attention = None
|
| 339 |
+
if not no_cross_attention:
|
| 340 |
+
self.cross_attention = Attention(
|
| 341 |
+
dim=dim,
|
| 342 |
+
head_dim=self.head_dim,
|
| 343 |
+
n_heads=self.n_heads,
|
| 344 |
+
n_kv_heads=self.n_heads,
|
            norm_eps=norm_eps,
            use_qk_norm=qk_norm,
            fc_bias=fc_bias,
        )

        self.scale_shift_table = nn.Parameter(
            torch.randn(6, self.dim) / self.dim**0.5,
        )

    def forward(
        self,
        x: torch.Tensor,
        cross_x: Optional[torch.Tensor],
        t: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
        memory_padding_mask: Optional[torch.Tensor],
        rope: Optional[RotaryEmbedding] = None,
    ):
        biases = self.scale_shift_table[None] + t.reshape(x.size(0), 6, -1)
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = biases.chunk(6, dim=1)

        assert self.attention is not None and self.attention_norm is not None
        h_attn = self.attention(
            modulate(self.attention_norm(x), shift_msa, scale_msa),
            key_padding_mask=padding_mask,
            rope=rope,
        )

        h = x + gate(h_attn, gate_msa)

        if self.cross_attention is not None:
            h_cross = self.cross_attention(
                x=h,
                cross_x=cross_x,
                key_padding_mask=memory_padding_mask,
            )
            h = h + h_cross  # residual
        h_ff = self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
        out = h + gate(h_ff, gate_mlp)
        return out


class DiT(torch.nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.dropout = config.dropout
        if config.in_channels is not None:
            self.data_proj = torch.nn.Linear(config.in_channels, config.dim)

        # embeddings
        self.rope_embeddings = None
        # rotary embeddings
        if config.use_rope:
            self.rope_embeddings = RotaryEmbedding(
                theta=max(10000, 2 * config.max_positions),
                head_dim=config.dim // config.n_heads,
                max_seqlen=config.max_positions,
            )
            self.rope_embeddings.reset_parameters()

        # transformer blocks
        self.layers = nn.ModuleList()
        for _ in range(config.n_layers):
            self.layers.append(
                DiTBlock(
                    dim=config.dim,
                    n_heads=config.n_heads,
                    dropout=config.dropout,
                    norm_eps=config.norm_eps,
                    qk_norm=config.qk_norm,
                    fc_bias=config.fc_bias,
                    ffn_exp=config.ffn_exp,
                    ffn_dim_multiplier=config.ffn_dim_multiplier,
                    multiple_of=config.multiple_of,
                    non_linearity=config.non_linearity,
                )
            )

        self.norm = RMSNorm(config.dim, config.norm_eps)

        # output layer
        self.output = torch.nn.Linear(
            config.dim, config.out_channels, bias=config.fc_bias
        )

        self.x_embedder = Patcher(
            in_channels=config.dim,
            out_channels=config.dim,
            patch_size=1,
        )

        self.y_embedder = ContextEmbedder(
            in_dim=config.context_dim,
            out_dim=config.dim,
            non_linearity=config.context_non_linearity,
            dropout=config.context_embedder_dropout,
            fc_bias=config.fc_bias,
            norm_eps=config.norm_eps,
            context_norm=config.context_norm,
        )

        self.t_embedder = TimestepEmbedder(
            config.dim,
            config.frequency_embedding_dim,
            non_linearity=config.timestep_non_linearity,
            dropout=config.dropout,
            fc_bias=config.fc_bias,
            max_period=10000,
        )

        self.t_block_non_linearity = get_nonlinearity(config.t_block_non_linearity)
        self.t_block = torch.nn.Linear(
            config.dim,
            config.dim * 6,
            bias=config.t_block_bias,
        )

        self.final_layer_scale_shift_table = nn.Parameter(
            torch.randn(2, config.dim) / config.dim**0.5,
        )

    def forward(
        self,
        x: torch.Tensor,
        time: torch.Tensor,
        *,
        padding_mask: Optional[torch.Tensor] = None,
        memory: Optional[torch.Tensor] = None,
        memory_padding_mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        x = rearrange(x, "b l c-> b c l")
        h = self.x_embedder(x)
        h = rearrange(h, "b c l -> b l c")
        original_N = h.shape[1]
        N = h.shape[1]

        h = F.dropout(h, p=self.dropout, training=self.training)

        t = self.t_embedder(time)  # B -> B D

        t0 = self.t_block_non_linearity(t)
        t0 = self.t_block(t0)  # B D -> B 6D

        y = self.y_embedder(memory)

        for layer in self.layers:
            h = layer(
                x=h,
                cross_x=y,
                t=t0,
                padding_mask=padding_mask,
                memory_padding_mask=memory_padding_mask,
                rope=self.rope_embeddings,
            )

        shift, scale = (self.final_layer_scale_shift_table[None] + t[:, None]).chunk(
            2, dim=1
        )

        # output layer
        if self.norm is not None:
            h = self.norm(h)

        h = modulate(h, shift, scale)

        h = F.dropout(h, p=self.dropout, training=self.training)

        output = self.output(h)

        N = output.shape[1]
        if original_N != N:
            output = output[:, -original_N:]
        return output
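
The adaLN-style conditioning in `DiTBlock.forward` is easy to misread, so here is a minimal standalone sketch (random tensors only) of how the learned `scale_shift_table` is combined with the `(B, 6D)` timestep projection and split into six modulation terms. The `modulate`/`gate` helpers are defined earlier in `transformer.py` and are not part of this hunk, so the `x * (1 + scale) + shift` line below is an assumed convention rather than a quote of the real helpers.

```python
import torch

B, L, D = 2, 10, 8  # toy batch, sequence length, model dim

scale_shift_table = torch.randn(6, D) / D**0.5  # as in DiTBlock.__init__
t0 = torch.randn(B, 6 * D)                      # t_block output: (B, 6D)

# Mirror DiTBlock.forward: broadcast the learned table against the per-sample
# conditioning and split it into six (B, 1, D) modulation terms.
biases = scale_shift_table[None] + t0.reshape(B, 6, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = biases.chunk(6, dim=1)

x = torch.randn(B, L, D)
# Assumed adaLN convention for modulate(); the actual helper lives earlier in the file.
h = x * (1 + scale_msa) + shift_msa
print(h.shape)  # torch.Size([2, 10, 8])
```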
sam_audio/model/vision_encoder.py
ADDED
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n

from abc import ABCMeta, abstractmethod

import torch
import torchvision
from core.vision_encoder import pe
from torch.nn.utils.rnn import pad_sequence

from sam_audio.model.config import (
    PerceptionEncoderConfig,
    VisionEncoderConfig,
)


class RescaleTransform(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size, interpolation):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        self.interpolation = interpolation

    def __call__(self, sample):
        # sample: [T, C, H, W]
        sample = torch.nn.functional.interpolate(
            sample.float(), size=self.output_size, mode=self.interpolation.value
        )
        return sample


class VisionEncoder(torch.nn.Module, metaclass=ABCMeta):
    def __init__(self, cfg: VisionEncoderConfig):
        super().__init__()
        self.batch_size = cfg.batch_size
        self.dim = cfg.dim
        self.transform = self.get_transform()

    @torch.no_grad()
    def forward(self, videos: list[torch.Tensor]) -> torch.Tensor:
        """
        Encodes a list of input videos. Each element of the list is a video represented
        as a tensor [T, C, H, W]
        Args:
            videos (list[torch.Tensor]): List of input image tensors to be processed.

        Returns:
            torch.Tensor: Encoded feature representations of the input tensors.
                The output is padded along the time dimension for variable length videos
        """
        result = []
        for video in videos:
            video = self.transform(video)
            if self.batch_size > 0 and video.size(0) > self.batch_size:
                res = []
                for i in range(0, video.size(0), self.batch_size):
                    res.append(self.encode(video[i : i + self.batch_size]))
                result.append(torch.cat(res, dim=0))
            else:
                result.append(self.encode(video))
        return pad_sequence(result, batch_first=True, padding_value=0.0)

    @abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pass

    @abstractmethod
    def get_transform(self):
        pass


class PerceptionEncoder(VisionEncoder):
    def __init__(self, cfg: PerceptionEncoderConfig):
        self.normalize_feature = cfg.normalize_feature
        self.interpolation_mode = cfg.interpolation_mode
        self.image_size = cfg.image_size
        super().__init__(cfg)
        self.model = pe.CLIP.from_config(cfg.name)

    def encode(self, x):
        image_features = self.model.encode_image(x, normalize=self.normalize_feature)
        return image_features

    def get_transform(self):
        T = torchvision.transforms
        try:
            interp = getattr(T.InterpolationMode, self.interpolation_mode.upper())
        except AttributeError as err:
            raise ValueError(
                f"Unsupported interpolation_mode: {self.interpolation_mode}"
            ) from err
        crop = [
            T.Resize(
                (self.image_size, self.image_size),
                interpolation=interp,
            )
        ]

        return T.Compose(
            crop
            + [
                T.Lambda(lambda x: x.float() / 255.0),
                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
            ]
        )
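
For readability, a self-contained sketch of the chunk-encode-and-pad pattern used in `VisionEncoder.forward`; the stub `encode` below stands in for `PerceptionEncoder.encode`, and the feature dimension is illustrative only.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def encode(frames: torch.Tensor) -> torch.Tensor:
    # Stand-in for PerceptionEncoder.encode: one 4-dim feature vector per frame.
    return frames.flatten(1).mean(dim=1, keepdim=True).expand(-1, 4)


batch_size = 8
videos = [torch.rand(13, 3, 16, 16), torch.rand(5, 3, 16, 16)]  # variable-length [T, C, H, W]

result = []
for video in videos:
    if batch_size > 0 and video.size(0) > batch_size:
        # Encode long videos in chunks of `batch_size` frames, then re-concatenate.
        chunks = [encode(video[i : i + batch_size]) for i in range(0, video.size(0), batch_size)]
        result.append(torch.cat(chunks, dim=0))
    else:
        result.append(encode(video))

features = pad_sequence(result, batch_first=True, padding_value=0.0)
print(features.shape)  # torch.Size([2, 13, 4]) -- padded to the longest video
```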
sam_audio/processor.py
ADDED
@@ -0,0 +1,382 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n

import json
import logging
import math
import os
from typing import Callable, List, Optional, Tuple

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from torch.nn.utils.rnn import pad_sequence
from torchcodec.decoders import AudioDecoder, VideoDecoder
from transformers import AutoTokenizer, BatchFeature

from sam_audio.model.config import SAMAudioConfig, SAMAudioJudgeConfig

logger = logging.getLogger(__name__)

Anchor = Tuple[str, float, float]


def batch_audio(
    audios: list[str | torch.Tensor], audio_sampling_rate: int = 48_000
) -> Tuple[torch.Tensor, torch.Tensor]:
    wavs = []
    for audio in audios:
        if isinstance(audio, str):
            wav, sr = torchaudio.load(audio)
            if sr != audio_sampling_rate:
                wav = torchaudio.functional.resample(wav, sr, audio_sampling_rate)
        else:
            wav = audio
        wavs.append(wav.mean(0))
    sizes = torch.tensor([wav.size(-1) for wav in wavs])
    return pad_sequence(wavs, batch_first=True).unsqueeze(1), sizes


class Batch:
    def __init__(
        self,
        audios: torch.Tensor,
        sizes: torch.Tensor,
        wav_sizes: torch.Tensor,
        descriptions: list[str],
        hop_length: int,
        audio_sampling_rate: int,
        anchors: Optional[list[list[Anchor]]] = None,
        audio_pad_mask: Optional[torch.Tensor] = None,
        masked_video: Optional[torch.Tensor] = None,
    ):
        self.audios = audios
        self.sizes = sizes
        self.wav_sizes = wav_sizes
        self.descriptions = descriptions
        self.audio_pad_mask = audio_pad_mask
        self.masked_video = masked_video
        self.hop_length = hop_length
        self.audio_sampling_rate = audio_sampling_rate
        self.process_anchors(anchors)
        assert self.audios.size(0) == len(self.descriptions)

    def _wav_to_feature_idx(self, wav_idx: int):
        return math.ceil(wav_idx / self.hop_length)

    def to(self, device: torch.device):
        self.audios = self.audios.to(device)
        self.anchor_ids = self.anchor_ids.to(device)
        self.anchor_alignment = self.anchor_alignment.to(device)
        self.sizes = self.sizes.to(device)
        self.wav_sizes = self.wav_sizes.to(device)
        if self.audio_pad_mask is not None:
            self.audio_pad_mask = self.audio_pad_mask.to(device)
        if self.masked_video is not None:
            self.masked_video = [v.to(device) for v in self.masked_video]
        return self

    def process_anchors(self, anchors: Optional[list[list[Anchor]]]):
        batch_size = len(self.audios)
        anchor_dict = {"<null>": 0, "+": 1, "-": 2, "<pad>": 3}
        if anchors is None:
            anchor_ids = torch.full(
                (batch_size, 2), anchor_dict["<null>"], dtype=torch.long
            )
            anchor_ids[:, 1] = anchor_dict["<pad>"]
            anchor_alignment = torch.full(
                (
                    batch_size,
                    self.audio_pad_mask.size(-1),
                ),
                0,
                dtype=torch.long,
            )
            anchor_alignment[~self.audio_pad_mask] = 1  # point to pad token
        else:
            anchor_alignment = torch.full(
                (
                    batch_size,
                    self.audio_pad_mask.size(-1),
                ),
                0,
                dtype=torch.long,
            )
            anchor_alignment[~self.audio_pad_mask] = 1  # point to pad token
            ids = []

            for i, anchor_list in enumerate(anchors):
                current = [anchor_dict["<null>"], anchor_dict["<pad>"]]
                for token, start_time, end_time in anchor_list:
                    start_idx = self._wav_to_feature_idx(
                        start_time * self.audio_sampling_rate
                    )
                    end_idx = self._wav_to_feature_idx(
                        end_time * self.audio_sampling_rate
                    )
                    anchor_alignment[i, start_idx:end_idx] = len(current)
                    current.append(anchor_dict[token])
                ids.append(torch.tensor(current))
            anchor_ids = pad_sequence(
                ids, batch_first=True, padding_value=anchor_dict["<pad>"]
            )
        self.anchor_ids = anchor_ids
        self.anchor_alignment = anchor_alignment
        self.anchors = anchors


def mask_from_sizes(sizes: torch.Tensor) -> torch.Tensor:
    return torch.arange(sizes.max()).expand(len(sizes), -1) < sizes.unsqueeze(1)


def load_video(
    sizes: torch.Tensor,
    videos: List[str],
    feature_idx_to_wav_idx: Callable[[torch.Tensor], torch.Tensor],
    audio_sampling_rate: int,
) -> list[torch.Tensor]:
    all_frames = []
    for size, video in zip(sizes, videos, strict=False):
        audio_timestamps = (
            feature_idx_to_wav_idx(torch.arange(size)) / audio_sampling_rate
        )
        if isinstance(video, str):
            decoder = VideoDecoder(video, dimension_order="NCHW")
            data = decoder.get_frames_in_range(0, len(decoder))
            diffs = (audio_timestamps[None] - data.pts_seconds[:, None]).abs()
            frame_idxs = diffs.argmin(dim=0)
            frames = data.data[frame_idxs]
        else:
            assert video.size(1) == 3, (
                f"Expected video tensor to be in NCHW format, but found {video.size(1)} channels"
            )
            idx = torch.linspace(0, video.size(0) - 1, int(size)).round().long()
            frames = video[idx]
        all_frames.append(frames)
    return all_frames


class Processor:
    config_cls: Callable

    def __init__(self, audio_hop_length: int, audio_sampling_rate: int):
        self.audio_hop_length = audio_hop_length
        self.audio_sampling_rate = audio_sampling_rate

    @classmethod
    def _get_config(cls, model_name_or_path: str):
        if os.path.exists(model_name_or_path):
            config_path = os.path.join(model_name_or_path, "config.json")
        else:
            config_path = hf_hub_download(
                repo_id=model_name_or_path,
                filename="config.json",
                revision=cls.revision,
            )
        with open(config_path) as fin:
            config = cls.config_cls(**json.load(fin))
        return config

    @classmethod
    def from_pretrained(cls, model_name_or_path: str) -> "Processor":
        config = cls._get_config(model_name_or_path)
        return cls(
            audio_hop_length=config.audio_codec.hop_length,
            audio_sampling_rate=config.audio_codec.sample_rate,
        )

    def feature_to_wav_idx(self, feature_idx):
        return feature_idx * self.audio_hop_length

    def wav_to_feature_idx(self, wav_idx):
        if torch.is_tensor(wav_idx):
            ceil = torch.ceil
        else:
            ceil = math.ceil
        return ceil(wav_idx / self.audio_hop_length)

    def mask_videos(
        self,
        videos: List[str | torch.Tensor],
        masks: List[str | torch.Tensor],
    ) -> list[torch.Tensor]:
        video = [VideoDecoder(v)[:] if isinstance(v, str) else v for v in videos]
        video_mask = [VideoDecoder(v)[:] if isinstance(v, str) else v for v in masks]
        return [v * m.eq(0) for v, m in zip(video, video_mask, strict=False)]


class SAMAudioProcessor(Processor):
    config_cls = SAMAudioConfig
    revision = None

    def __call__(
        self,
        descriptions: list[str],
        audios: list[str | torch.Tensor],
        anchors: Optional[list[list[Anchor]]] = None,
        masked_videos: Optional[list[str | torch.Tensor]] = None,
    ):
        """
        Processes input data for the model.

        Args:
            descriptions (list[str]): List of text descriptions corresponding to each audio sample.
            audios (list[str | torch.Tensor]): List of audio file paths or tensors.
                If a tensor:
                    - should have shape (channels, time) where channels=1 for mono and 2 for stereo.
                    - should be resampled to 48,000 Hz
            anchors (Optional[list[list[Anchor]]], optional): List of anchors for each sample,
                where each anchor is a tuple (token, start_time, end_time).
            masked_videos (Optional[list[str | torch.Tensor]], optional): List of masked video file paths or tensors.
                If a tensor, should have shape (N, C, H, W)

        Returns:
            Batch: A Batch object containing processed audio, sizes, descriptions, anchor ids, anchor alignment, audio pad mask, and optionally masked video.
        """

        assert len(descriptions) == len(audios)
        assert anchors is None or len(descriptions) == len(anchors)
        assert masked_videos is None or len(descriptions) == len(masked_videos)

        audios, wav_sizes = batch_audio(audios, self.audio_sampling_rate)

        sizes = self.wav_to_feature_idx(wav_sizes)
        audio_pad_mask = mask_from_sizes(sizes)
        masked_video = None
        if masked_videos is not None:
            masked_video = load_video(
                sizes, masked_videos, self.feature_to_wav_idx, self.audio_sampling_rate
            )

        return Batch(
            audios=audios,
            sizes=sizes,
            descriptions=descriptions,
            audio_pad_mask=audio_pad_mask,
            anchors=anchors,
            masked_video=masked_video,
            hop_length=self.audio_hop_length,
            audio_sampling_rate=self.audio_sampling_rate,
            wav_sizes=wav_sizes,
        )


class SAMAudioJudgeProcessor(Processor):
    config_cls = SAMAudioJudgeConfig
    revision = "sam_audio"

    def __init__(
        self,
        audio_hop_length: int,
        audio_sampling_rate: int,
        tokenizer: AutoTokenizer,
    ):
        super().__init__(audio_hop_length, audio_sampling_rate)
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, model_name_or_path: str) -> "SAMAudioJudgeProcessor":
        config = cls._get_config(model_name_or_path)
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        return cls(
            audio_hop_length=config.audio_codec.hop_length,
            audio_sampling_rate=config.audio_codec.sample_rate,
            tokenizer=tokenizer,
        )

    def _reflect_pad(self, wav):
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        if wav.size(-1) % self.audio_hop_length == 0:
            return wav
        p1d = (0, self.audio_hop_length - (wav.size(-1) % self.audio_hop_length))
        return torch.nn.functional.pad(wav, p1d, mode="reflect")

    def _load_audio(self, path: str):
        ad = AudioDecoder(path, sample_rate=self.audio_sampling_rate, num_channels=1)
        return ad.get_all_samples().data

    def _process_audio(
        self,
        raw_audio,
        sampling_rate: Optional[int] = None,
    ):
        from_file = False
        if isinstance(raw_audio, str):
            raw_audio = [raw_audio]

        if isinstance(raw_audio, (list, tuple)) and isinstance(raw_audio[0], str):
            loaded = []
            for audio_file in raw_audio:
                loaded.append(self._load_audio(audio_file))
            raw_audio = loaded
            from_file = True

        if sampling_rate is not None:
            if sampling_rate != self.audio_sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.audio_sampling_rate}. Please make sure that the provided audio input was sampled with"
                    f" {self.audio_sampling_rate} and not {sampling_rate}."
                )
        elif not from_file:
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if isinstance(raw_audio, list):
            raw_audio = [self._reflect_pad(x).T for x in raw_audio]
        else:
            raw_audio = self._reflect_pad(raw_audio).T

        # verify inputs are valid
        for example in raw_audio:
            if example.ndim > 2:
                raise ValueError(
                    f"Expected input shape (channels, num_samples), but got shape ({example.shape})"
                )

        lengths = torch.tensor([x.size(0) for x in raw_audio])
        input_values = pad_sequence(raw_audio, batch_first=True).transpose(1, 2)
        padding_mask = torch.arange(lengths.max())[None] < lengths[:, None]

        return BatchFeature(
            {"input_values": input_values, "padding_mask": padding_mask}
        )

    def __call__(
        self,
        text: Optional[str] = None,
        input_audio: Optional[
            str | list[str] | torch.Tensor | list[torch.Tensor]
        ] = None,
        separated_audio: Optional[
            str | list[str] | torch.Tensor | list[torch.Tensor]
        ] = None,
        sampling_rate: Optional[int] = None,
        **kwargs,
    ):
        batch = BatchFeature()
        if text is not None:
            batch.update(
                self.tokenizer(
                    text,
                    return_tensors="pt",
                    padding="longest",
                    max_length=512,
                    truncation=True,
                )
            )

        if input_audio is not None:
            batch.update(self._process_audio(input_audio, sampling_rate))

        if separated_audio is not None:
            batch["separated_values"] = self._process_audio(
                separated_audio, sampling_rate
            )["input_values"]

        return batch


__all__ = ["SAMAudioProcessor", "SAMAudioJudgeProcessor", "Batch"]
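
A hedged usage sketch of `SAMAudioProcessor`: the repository id below is a placeholder (substitute the real checkpoint), and the waveform is random noise, but the call signature, anchor tuples, and `Batch` attributes follow the code above.

```python
import torch

from sam_audio.processor import SAMAudioProcessor

# Placeholder id -- replace with the actual SAM-Audio checkpoint on the Hub.
processor = SAMAudioProcessor.from_pretrained("<org>/<sam-audio-checkpoint>")

# One mono 48 kHz waveform, a text prompt, and a positive anchor covering 1.0s-2.5s.
wav = torch.randn(1, 48_000 * 4)
batch = processor(
    descriptions=["a dog barking"],
    audios=[wav],
    anchors=[[("+", 1.0, 2.5)]],
)
batch = batch.to(torch.device("cpu"))
print(batch.audios.shape, batch.sizes, batch.anchor_ids)
```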
sam_audio/ranking/__init__.py
ADDED
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n

from sam_audio.model.config import (
    ClapRankerConfig,
    EnsembleRankerConfig,
    ImageBindRankerConfig,
    JudgeRankerConfig,
)
from sam_audio.ranking.clap import ClapRanker
from sam_audio.ranking.imagebind import ImageBindRanker
from sam_audio.ranking.judge import JudgeRanker
from sam_audio.ranking.ranker import EnsembleRanker


def create_ranker(config):
    if isinstance(config, ImageBindRankerConfig):
        return ImageBindRanker(config)
    elif isinstance(config, ClapRankerConfig):
        return ClapRanker(config)
    elif isinstance(config, JudgeRankerConfig):
        return JudgeRanker(config)
    elif isinstance(config, EnsembleRankerConfig):
        ranker_cfgs, weights = zip(*config.rankers.values(), strict=False)
        return EnsembleRanker(
            rankers=[create_ranker(cfg) for cfg in ranker_cfgs],
            weights=weights,
        )
    else:
        assert config is None
        return None
sam_audio/ranking/clap.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n

import torch
import torchaudio
from huggingface_hub import hf_hub_download

from sam_audio.model.config import ClapRankerConfig
from sam_audio.ranking.ranker import Ranker


def get_model(device="cpu"):
    import laion_clap

    model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-tiny").to(device)
    checkpoint_file = hf_hub_download(
        repo_id="lukewys/laion_clap", filename="630k-best.pt"
    )
    state_dict = torch.load(checkpoint_file, map_location=device, weights_only=False)[
        "state_dict"
    ]
    if next(iter(state_dict.items()))[0].startswith("module"):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    if "text_branch.embeddings.position_ids" in state_dict:
        del state_dict["text_branch.embeddings.position_ids"]

    model.model.load_state_dict(state_dict)
    return model.eval()


class ClapRanker(Ranker):
    def __init__(self, config: ClapRankerConfig):
        from laion_clap.training import data

        self.laion_data_module = data
        super().__init__()
        self.config = config
        self.model = get_model()

    def _prepare_audio(self, audio, sample_rate):
        audio_features = []
        for candidates in audio:
            if sample_rate != 48_000:
                candidates = torchaudio.functional.resample(
                    candidates, sample_rate, 48000
                )

            quantized = self.laion_data_module.int16_to_float32_torch(
                self.laion_data_module.float32_to_int16_torch(candidates)
            ).float()
            for sample in quantized:
                temp_dict = {}
                temp_dict = self.laion_data_module.get_audio_features(
                    temp_dict,
                    sample,
                    480000,
                    data_truncating=(
                        "fusion" if self.model.enable_fusion else "rand_trunc"
                    ),
                    data_filling="repeatpad",
                    audio_cfg=self.model.model_cfg["audio_cfg"],
                    require_grad=False,
                )
                audio_features.append(temp_dict)
        return audio_features

    @torch.inference_mode()
    def forward(
        self,
        extracted_audio: list[torch.Tensor],
        descriptions: list[str],
        sample_rate: int = 48_000,
        **kwargs,
    ):
        audio_embed = self.model.model.get_audio_embedding(
            self._prepare_audio(extracted_audio, sample_rate)
        )
        text_embed = self.model.get_text_embedding(descriptions, use_tensor=True)
        bsz = len(extracted_audio)
        candidates = len(audio_embed) // bsz
        audio_embed = audio_embed.reshape(bsz, candidates, -1)
        text_embed = text_embed.reshape(bsz, -1, 1)
        scores = audio_embed @ text_embed
        return scores.squeeze(-1)
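
The final reshape in `ClapRanker.forward` is a batched dot product between each candidate's audio embedding and its prompt's text embedding; a short sketch with random stand-in embeddings shows the shape bookkeeping:

```python
import torch

bsz, candidates, dim = 2, 4, 512

# Flattened per-candidate audio embeddings and one text embedding per prompt
# (random stand-ins for the CLAP outputs above).
audio_embed = torch.randn(bsz * candidates, dim)
text_embed = torch.randn(bsz, dim)

# Same reshape-and-matmul as ClapRanker.forward: (B, C, D) @ (B, D, 1) -> (B, C, 1)
scores = audio_embed.reshape(bsz, candidates, -1) @ text_embed.reshape(bsz, -1, 1)
print(scores.squeeze(-1).shape)  # torch.Size([2, 4]) -- one score per candidate
```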
sam_audio/ranking/imagebind.py
ADDED
@@ -0,0 +1,197 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n

import math
from typing import List, Union

import torch
import torchaudio

from sam_audio.model.config import ImageBindRankerConfig
from sam_audio.ranking.ranker import Ranker

try:
    from imagebind.data import (
        ConstantClipsPerVideoSampler,
        NormalizeVideo,
        SpatialCrop,
        get_clip_timepoints,
        load_and_transform_video_data,
        pv_transforms,
        transforms,
        waveform2melspec,
    )
    from imagebind.models.imagebind_model import ModalityType, imagebind_huge

    __imagebind_exists__ = True
except ImportError:
    __imagebind_exists__ = False


def load_and_transform_audio_data(
    audios: List[Union[str, torch.Tensor]],
    input_sample_rate=None,
    num_mel_bins=128,
    target_length=204,
    sample_rate=16000,
    clip_duration=2,
    clips_per_video=3,
    mean=-4.268,
    std=9.138,
    device=None,
):
    if audios is None:
        return None

    audio_outputs = []
    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )

    for audio in audios:
        if isinstance(audio, str):
            waveform, input_sample_rate = torchaudio.load(audio)
        else:
            assert torch.is_tensor(audio)
            assert sample_rate is not None
            # Preprocessing needs to be done in full precision
            waveform = audio.float()
        if waveform.ndim == 1:
            waveform = waveform[None]
        if sample_rate != input_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=input_sample_rate, new_freq=sample_rate
            )
        all_clips_timepoints = get_clip_timepoints(
            clip_sampler, waveform.size(1) / sample_rate
        )
        all_clips = []
        for clip_timepoints in all_clips_timepoints:
            waveform_clip = waveform[
                :,
                int(clip_timepoints[0] * sample_rate) : int(
                    clip_timepoints[1] * sample_rate
                ),
            ]
            waveform_melspec = waveform2melspec(
                waveform_clip, sample_rate, num_mel_bins, target_length
            )
            all_clips.append(waveform_melspec)

        normalize = transforms.Normalize(mean=mean, std=std)
        all_clips = [normalize(ac).to(device) for ac in all_clips]

        all_clips = torch.stack(all_clips, dim=0)
        audio_outputs.append(all_clips)

    return torch.stack(audio_outputs, dim=0)


class VideoTransform:
    def __init__(self, clip_duration=2, clips_per_video=5):
        self.clip_duration = clip_duration
        self.clips_per_video = clips_per_video
        self.clip_sampler = ConstantClipsPerVideoSampler(
            clip_duration=clip_duration, clips_per_video=clips_per_video
        )
        self.video_transform = transforms.Compose(
            [
                pv_transforms.ShortSideScale(224),
                NormalizeVideo(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        self.spatial_crop = SpatialCrop(224, num_crops=3)

    def load_video_fast(self, videos, durations, **kwargs):
        result = []
        for video, duration in zip(videos, durations, strict=False):
            nframes = video.size(0)
            fps = video.size(0) / duration
            timepoints = get_clip_timepoints(
                self.clip_sampler,
                duration,
            )
            # Instead of loading 5 2s clips and then sub-sampling frames, we figure
            # out the indices of the final clips we want and only decode those.
            all_idxs = []
            for start_time, end_time in timepoints:
                idxs = torch.arange(
                    min(int(math.ceil(fps * start_time)), nframes - 1),
                    min(int(math.ceil(fps * end_time)), nframes),
                )
                ts = (
                    torch.linspace(0, idxs.size(0) - 1, self.clip_duration)
                    .clamp(max=idxs.size(0) - 1)
                    .long()
                )
                all_idxs.append(idxs[ts])
            all_idxs = torch.cat(all_idxs)
            fast_frames = video[all_idxs].transpose(0, 1)
            result.append(fast_frames.chunk(self.clips_per_video, dim=1))
        return result

    def transform_video(self, batch, device=None):
        device = device or torch.device("cpu")
        video_outputs = []
        for all_video in batch:
            all_video = [
                self.video_transform(clip.to(device) / 255.0) for clip in all_video
            ]
            all_video = self.spatial_crop(all_video)
            all_video = torch.stack(all_video, dim=0)
            video_outputs.append(all_video)
        return torch.stack(video_outputs, dim=0)

    def __call__(self, videos, durations, device=None):
        return self.transform_video(
            self.load_video_fast(videos, durations), device=device
        )


class ImageBindRanker(Ranker):
    def __init__(self, cfg: ImageBindRankerConfig):
        super().__init__()
        assert __imagebind_exists__, (
            "Install ImageBind in order to use this ranker: https://github.com/facebookresearch/ImageBind/tree/main"
        )

        self.model = imagebind_huge(pretrained=cfg.checkpoint is None)
        if cfg.checkpoint is not None:
            self.model.load_state_dict(torch.load(cfg.checkpoint, map_location="cpu"))
        self.model = self.model.eval()
        self.video_transform = VideoTransform()

    @torch.inference_mode()
    def forward(
        self,
        extracted_audio: list[torch.Tensor],
        videos: list[torch.Tensor | str],
        sample_rate: int = 48_000,
        **kwargs,
    ):
        audio_data = torch.cat(
            [
                load_and_transform_audio_data(x, input_sample_rate=sample_rate)
                for x in extracted_audio
            ],
            dim=0,
        )
        if isinstance(videos[0], str):
            video_data = load_and_transform_video_data(videos)
        else:
            durations = [x.size(-1) / sample_rate for x in extracted_audio]
            video_data = self.video_transform(videos, durations, audio_data.device)

        inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
        embs = self.model(inputs)
        audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
        audio_embs, video_embs = (
            audio_embs / ((audio_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
            video_embs / ((video_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
        )
        bsz = len(extracted_audio)
        candidates = len(audio_embs) // bsz
        scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
        return scores
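
`VideoTransform.load_video_fast` avoids decoding whole clips by precomputing the frame indices each clip window needs. A standalone sketch of that index arithmetic, with hard-coded timepoints standing in for `get_clip_timepoints`:

```python
import math

import torch

nframes, duration, clip_duration = 120, 8.0, 2  # toy video: 120 frames over 8 s
fps = nframes / duration
timepoints = [(0.0, 2.0), (3.0, 5.0), (6.0, 8.0)]  # stand-in clip windows (seconds)

all_idxs = []
for start_time, end_time in timepoints:
    # Frame range covered by this clip window.
    idxs = torch.arange(
        min(int(math.ceil(fps * start_time)), nframes - 1),
        min(int(math.ceil(fps * end_time)), nframes),
    )
    # Keep only `clip_duration` evenly spaced frames from that range.
    ts = torch.linspace(0, idxs.size(0) - 1, clip_duration).clamp(max=idxs.size(0) - 1).long()
    all_idxs.append(idxs[ts])

print(torch.cat(all_idxs))  # tensor([  0,  29,  45,  74,  90, 119]) -- the only frames to decode
```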
sam_audio/ranking/judge.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n

import torch

from ..model.config import JudgeRankerConfig
from ..model.judge import SAMAudioJudgeModel
from ..processor import SAMAudioJudgeProcessor
from .ranker import Ranker


class JudgeRanker(Ranker):
    def __init__(self, config: JudgeRankerConfig):
        super().__init__()
        self.config = config
        self.model = SAMAudioJudgeModel.from_pretrained(config.checkpoint_or_model_id)
        self.processor = SAMAudioJudgeProcessor.from_pretrained(
            config.checkpoint_or_model_id
        )

    @torch.inference_mode()
    def forward(
        self,
        input_audio: list[torch.Tensor],
        extracted_audio: list[torch.Tensor],
        descriptions: list[str],
        sample_rate: int = 48_000,
        **kwargs,
    ):
        bsz, ncandidates = len(input_audio), len(input_audio[0])
        input_seqs = [x[None] for candidates in input_audio for x in candidates]
        extracted_seqs = [x[None] for candidates in extracted_audio for x in candidates]
        repeated_descriptions = [x for x in descriptions for _ in range(ncandidates)]
        processed = self.processor(
            text=repeated_descriptions,
            input_audio=input_seqs,
            separated_audio=extracted_seqs,
            return_tensors="pt",
            padding=True,
            sampling_rate=sample_rate,
        )
        res = self.model(**processed.to(input_audio[0].device))
        return res.overall.view(bsz, ncandidates)
sam_audio/ranking/ranker.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n

from abc import ABCMeta, abstractmethod
from typing import List

import torch


class Ranker(torch.nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def forward(self, audio: list[torch.Tensor], **kwargs) -> torch.Tensor:
        """
        Args:
            audio: (list[torch.Tensor]) where each element in the list corresponds to
                the candidates for the i'th generation (num_candidates, num_frames)
        Returns:
            (torch.Tensor) of shape (batch_size, num_candidates) corresponding to the ranking scores
        """
        pass


class EnsembleRanker(Ranker):
    def __init__(self, rankers: List[Ranker], weights: List[float]):
        super().__init__()
        assert len(rankers) == len(weights)
        self.rankers = torch.nn.ModuleList(rankers)
        self.weights = weights

    def forward(self, **kwargs) -> torch.Tensor:
        result = None
        for weight, ranker in zip(self.weights, self.rankers, strict=False):
            if result is None:
                result = weight * ranker(**kwargs)
            else:
                result += weight * ranker(**kwargs)
        return result
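
A minimal sketch of composing rankers with `EnsembleRanker`; `ConstantRanker` is a toy subclass invented here purely to show the weighted-sum behaviour and the keyword-only calling convention.

```python
import torch

from sam_audio.ranking.ranker import EnsembleRanker, Ranker


class ConstantRanker(Ranker):
    """Toy ranker that ignores its inputs and returns a fixed score for every candidate."""

    def __init__(self, value: float):
        super().__init__()
        self.value = value

    def forward(self, extracted_audio, **kwargs) -> torch.Tensor:
        return torch.full((len(extracted_audio), extracted_audio[0].size(0)), self.value)


ensemble = EnsembleRanker(
    rankers=[ConstantRanker(1.0), ConstantRanker(3.0)],
    weights=[0.75, 0.25],
)
candidates = [torch.randn(4, 48_000), torch.randn(4, 48_000)]  # 2 items x 4 candidates each
print(ensemble(extracted_audio=candidates))  # every entry is 0.75 * 1.0 + 0.25 * 3.0 = 1.5
```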
sam_audio/ranking/sound_activity.py
ADDED
@@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n

from io import BytesIO
from typing import Tuple, Union

import torch
from torchcodec.encoders import AudioEncoder

from ..model.config import SoundActivityRankerConfig
from .ranker import Ranker

try:
    import pydub
except ImportError:
    pydub = None


def get_peak_rms(audio, win_ms=250, hop_ms=100):
    """
    win_length and hop_length are in ms
    """
    last_slice_start = len(audio) - win_ms
    slice_starts = range(0, last_slice_start + 1, hop_ms)
    peak_rms = -1
    for i in slice_starts:
        audio_slice = audio[i : i + win_ms]
        peak_rms = max(peak_rms, audio_slice.rms / audio.max_possible_amplitude)
    # Ensure peak_rms is positive
    peak_rms = max(peak_rms, 0)
    return peak_rms


def torch_tensor_to_pydub(wav: torch.Tensor, sample_rate: int):
    bytesio = BytesIO()
    encoder = AudioEncoder(wav, sample_rate=sample_rate)
    encoder.to_file_like(bytesio, format="wav")
    bytesio.seek(0)
    audio = pydub.AudioSegment.from_file(bytesio, format="wav")
    return audio


def detect_nonsilent(
    path: Union[str, Tuple[torch.Tensor, int]],  # either a file path or pair wav & sr
    min_sil_ms=250,
    sil_threshold=-40,
    threshold_mode="rel_to_max",
):
    TH_MODES = {"abs", "rel_to_max"}
    SAMPLE_RATE = 24_000
    assert threshold_mode in TH_MODES, f"{threshold_mode=} not in {TH_MODES}"
    if isinstance(path, str):
        audio = pydub.AudioSegment.from_file(path)
    else:  # tuple of (tensor, sr)
        audio = torch_tensor_to_pydub(path[0], path[1])
    audio = audio.set_frame_rate(SAMPLE_RATE)
    if threshold_mode == "rel_to_max":
        peak_rms = get_peak_rms(audio)
        sil_threshold = sil_threshold + pydub.utils.ratio_to_db(
            peak_rms
        )  # convert to absolute db threshold
    elif threshold_mode == "abs":
        pass
    else:
        raise NotImplementedError(f"Unknown threshold_mode '{threshold_mode}'")
    spans = pydub.silence.detect_nonsilent(
        audio, min_silence_len=min_sil_ms, silence_thresh=sil_threshold, seek_step=10
    )
    spans = [(round(start / 1000, 3), round(end / 1000, 3)) for start, end in spans]
    return spans


def compute_iou_recall_precision(hyp_spans, ref_spans):
    def span_length(span):
        return span[1] - span[0]

    def intersection_length(span1, span2):
        return max(0, min(span1[1], span2[1]) - max(span1[0], span2[0]))

    total_hyp_length = sum(span_length(span) for span in hyp_spans)
    total_ref_length = sum(span_length(span) for span in ref_spans)
    total_intersection = 0
    for hyp_span in hyp_spans:
        for ref_span in ref_spans:
            total_intersection += intersection_length(hyp_span, ref_span)

    union_spans = hyp_spans + ref_spans  # Combine both lists to compute union
    union_length = sum(span_length(span) for span in union_spans) - total_intersection

    iou = total_intersection / union_length if union_length > 0 else 0
    recall = total_intersection / total_ref_length if total_ref_length > 0 else 0
    precision = total_intersection / total_hyp_length if total_hyp_length > 0 else 0

    return {"iou": iou, "recall": recall, "precision": precision}


class SoundActivityRanker(Ranker):
    def __init__(self, config: SoundActivityRankerConfig):
        if pydub is None:
            raise ImportError(
                'Install reranking dependencies: `pip install "sam-audio[reranking]"`'
            )
        super().__init__()
        self.config = config

    @torch.inference_mode()
    def forward(
        self,
        extracted_audio: list[torch.Tensor],
        spans: list[list[list[float]]],
        sample_rate: int = 48_000,
        **kwargs,
    ):
        device = extracted_audio[0].device
        scores = []
        for wav, current_spans in zip(extracted_audio, spans, strict=True):
            wav = wav.to(torch.float32).cpu()
            # get non-silent spans
            hyp_spans = detect_nonsilent(
                (wav, sample_rate),
                sil_threshold=self.config.sil_threshold,
                threshold_mode=self.config.threshold_mode,
            )
            timestamps = [[span[1], span[2]] for span in current_spans]
            result = compute_iou_recall_precision(hyp_spans, timestamps)
            scores.append(result[self.config.metric])

        # convert to tensor
        scores = torch.tensor(scores, device=device)
        return scores
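
A worked example of `compute_iou_recall_precision` (importing the module assumes the package and its `torchcodec` dependency are installed): a 2 s hypothesis span overlapping a 2 s reference span by 1 s gives a union of 3 s, so IoU is 1/3 while recall and precision are both 0.5.

```python
from sam_audio.ranking.sound_activity import compute_iou_recall_precision

hyp = [(0.0, 2.0)]  # detected non-silent span, in seconds
ref = [(1.0, 3.0)]  # annotated target span

metrics = compute_iou_recall_precision(hyp, ref)
# intersection = 1.0 s, union = 2.0 + 2.0 - 1.0 = 3.0 s
print(metrics)  # {'iou': 0.333..., 'recall': 0.5, 'precision': 0.5}
```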