akhaliq (HF Staff) committed (verified)
Commit cff486f · 1 Parent(s): 3250861

Upload 30 files
eval/README.md ADDED
@@ -0,0 +1,100 @@
# Evaluation

This directory contains the evaluation code to reproduce the results from the SAM-Audio paper. The evaluation framework supports multiple datasets, prompting modes (text-only, span, visual), and metrics.

## Setup

Before running evaluation, ensure you have:

1. Installed the SAM-Audio package and its dependencies
2. Authenticated with Hugging Face to access the model checkpoints (see the main [README](../README.md))

## Quick Start

Run evaluation on the default setting (instr-pro):

```bash
python main.py
```

You can also use multiple GPUs to speed up evaluation:

```bash
torchrun --nproc_per_node=<ngpus> main.py
```

Evaluate on a specific setting:

```bash
python main.py --setting sfx
```

Evaluate on multiple settings:

```bash
python main.py --setting sfx speech music
```

## Available Evaluation Settings

Run `python main.py --help` to see all available settings.

## Command Line Options

```bash
python main.py [OPTIONS]
```

### Options:

- `-s, --setting` - Which setting(s) to evaluate (default: `instr-pro`)
  - Choices: see the available settings above
  - Can specify multiple settings: `--setting sfx speech music`

- `--cache-path` - Where to cache downloaded datasets (default: `~/.cache/sam_audio`)

- `-p, --checkpoint-path` - Model checkpoint to evaluate (default: `facebook/sam-audio-1b`)
  - Can use a local path or a Hugging Face model ID

- `-b, --batch-size` - Batch size for evaluation (default: `1`)

- `-w, --num-workers` - Number of data loading workers (default: `4`)

- `-c, --candidates` - Number of reranking candidates (default: `8`)

## Evaluation Metrics

The evaluation framework computes the following metrics:

- **Judge** - SAM Audio Judge quality assessment metric
- **Aesthetic** - Aesthetic quality metric
- **CLAP** - Audio-text alignment metric (CLAP similarity)
- **ImageBind** - Audio-video alignment metric (for visual settings only)

## Output

Results are saved to the `results/` directory as JSON files, one per setting:

```
results/
├── sfx.json
├── speech.json
└── music.json
```

Each JSON file contains the averaged metric scores across all samples in that setting.

Example output:
```json
{
    "JudgeOverall": "4.386",
    "JudgeFaithfulness": "4.708",
    "JudgeRecall": "4.934",
    "JudgePrecision": "4.451",
    "ContentEnjoyment": "5.296",
    "ContentUsefulness": "6.903",
    "ProductionComplexity": "4.301",
    "ProductionQuality": "7.100",
    "CLAPSimilarity": "0.271"
}
```
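As a quick illustration (not part of the uploaded evaluation code), the per-setting JSON files can be collected into a single comparison table along these lines; it assumes the `results/` layout shown above and that the averaged scores are stored as strings, as in the example output:

```python
# Minimal sketch: gather results/<setting>.json files into one table.
import json
from pathlib import Path

import pandas as pd  # already required by eval/main.py

rows = {}
for path in sorted(Path("results").glob("*.json")):
    with open(path) as fin:
        # main.py formats each averaged metric as a string such as "4.386"
        rows[path.stem] = {k: float(v) for k, v in json.load(fin).items()}

# One row per setting, one column per metric
print(pd.DataFrame(rows).T)
```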
eval/dataset/__init__.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Callable
4
+
5
+ from .musdb import MUSDB
6
+ from .sam_audio_bench import SAMAudioBench
7
+
8
+ SETTINGS = {
9
+ # Text-only settings
10
+ "sfx": (
11
+ SAMAudioBench,
12
+ {"span": False, "visual": False, "subset": "others-50:text-only"},
13
+ ),
14
+ "speech": (
15
+ SAMAudioBench,
16
+ {"span": False, "visual": False, "subset": "speech-clean-50:text-only"},
17
+ ),
18
+ "speaker": (
19
+ SAMAudioBench,
20
+ {"span": False, "visual": False, "subset": "spk-50:text-only"},
21
+ ),
22
+ "music": (
23
+ SAMAudioBench,
24
+ {"span": False, "visual": False, "subset": "music-clean-50:text-only"},
25
+ ),
26
+ "instr-wild": (
27
+ SAMAudioBench,
28
+ {"span": False, "visual": False, "subset": "instr-50:text-only"},
29
+ ),
30
+ "instr-pro": (MUSDB, {}),
31
+ # Span settings
32
+ "sfx-span": (
33
+ SAMAudioBench,
34
+ {"span": True, "visual": False, "subset": "others-50:text+span"},
35
+ ),
36
+ "speech-span": (
37
+ SAMAudioBench,
38
+ {"span": True, "visual": False, "subset": "speech-clean-50:text+span"},
39
+ ),
40
+ "speaker-span": (
41
+ SAMAudioBench,
42
+ {"span": True, "visual": False, "subset": "spk-50:text+span"},
43
+ ),
44
+ "music-span": (
45
+ SAMAudioBench,
46
+ {"span": True, "visual": False, "subset": "music-clean-50:text+span"},
47
+ ),
48
+ "instr-wild-span": (
49
+ SAMAudioBench,
50
+ {"span": True, "visual": False, "subset": "instr-50:text+span"},
51
+ ),
52
+ # Visual settings
53
+ "sfx-visual": (
54
+ SAMAudioBench,
55
+ {"span": False, "visual": True, "subset": "others-onscreen-50:visual-only"},
56
+ ),
57
+ "speaker-visual": (
58
+ SAMAudioBench,
59
+ {"span": False, "visual": True, "subset": "spk-onscreen-50:visual-only"},
60
+ ),
61
+ "instr-wild-visual": (
62
+ SAMAudioBench,
63
+ {"span": False, "visual": True, "subset": "instr-onscreen-50:visual-only"},
64
+ ),
65
+ }
66
+
67
+
68
+ def make_dataset(setting: str, cache_path: str, collate_fn: Callable):
69
+ dataset, kwargs = SETTINGS[setting]
70
+ return dataset(cache_path=cache_path, collate_fn=collate_fn, **kwargs)
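A rough usage sketch of `make_dataset`, mirroring how `eval/main.py` (below) wires things together; it assumes the SAM-Audio package is installed, the checkpoint ID is the default used by `main.py`, and the benchmark media has been prepared under the cache path:

```python
# Rough usage sketch (run from eval/): build one evaluation setting and wrap it
# in a DataLoader that uses the dataset's own collate method.
import os

from torch.utils.data import DataLoader

from dataset import SETTINGS, make_dataset
from sam_audio import SAMAudioProcessor

processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")
dset = make_dataset(
    "sfx",  # any key of SETTINGS
    cache_path=os.path.expanduser("~/.cache/sam_audio"),
    collate_fn=processor,
)
dl = DataLoader(dset, batch_size=1, collate_fn=dset.collate)
print(f"{len(SETTINGS)} settings available, {len(dset)} items in 'sfx'")
```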
eval/dataset/musdb.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import os
4
+ from subprocess import check_call
5
+
6
+ import torchaudio
7
+ from datasets import load_dataset
8
+ from torch.utils.data import Dataset
9
+ from torchcodec.decoders import AudioDecoder
10
+
11
+
12
+ def cache_file(url, outfile):
13
+ if not os.path.exists(outfile):
14
+ print("Downloading musdb18hq dataset...")
15
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
16
+ check_call(["curl", "--url", url, "--output", outfile + ".tmp"])
17
+ os.rename(outfile + ".tmp", outfile)
18
+
19
+
20
+ class MUSDB(Dataset):
21
+ def __init__(
22
+ self,
23
+ collate_fn,
24
+ sample_rate: int = 48_000,
25
+ cache_path: str = os.path.expanduser("~/.cache/sam_audio"),
26
+ ):
27
+ self.cache_path = os.path.join(cache_path, "musdb18hq")
28
+ self.ds = self.get_dataset(cache_path)
29
+ self.captions = ["bass", "drums", "vocals"]
30
+ self.collate_fn = collate_fn
31
+ self.sample_rate = sample_rate
32
+
33
+ @property
34
+ def visual(self):
35
+ return False
36
+
37
+ def get_dataset(self, cache_path):
38
+ zip_file = os.path.join(cache_path, "musdb18hq.zip")
39
+ url = "https://zenodo.org/records/3338373/files/musdb18hq.zip?download=1"
40
+ cache_file(url, zip_file)
41
+ extracted_dir = os.path.join(cache_path, "musdb18hq")
42
+ if not os.path.exists(extracted_dir):
43
+ check_call(["unzip", zip_file, "-d", extracted_dir + ".tmp"])
44
+ os.rename(extracted_dir + ".tmp", extracted_dir)
45
+ return load_dataset("facebook/sam-audio-musdb18hq-test")["test"]
46
+
47
+ def __len__(self):
48
+ return len(self.ds)
49
+
50
+ def collate(self, items):
51
+ audios, descriptions = zip(*items, strict=False)
52
+ return self.collate_fn(
53
+ audios=audios,
54
+ descriptions=descriptions,
55
+ )
56
+
57
+ def __getitem__(self, idx):
58
+ item = self.ds[idx]
59
+ path = os.path.join(self.cache_path, "test", item["id"], "mixture.wav")
60
+ assert os.path.exists(path), f"{path} does not exist!"
61
+ decoder = AudioDecoder(path)
62
+ data = decoder.get_samples_played_in_range(item["start_time"], item["end_time"])
63
+ wav = data.data
64
+ if data.sample_rate != self.sample_rate:
65
+ wav = torchaudio.functional.resample(
66
+ wav, data.sample_rate, self.sample_rate
67
+ )
68
+ wav = wav.mean(0, keepdim=True)
69
+ return wav, item["description"]
70
+
71
+
72
+ if __name__ == "__main__":
73
+ dataset = MUSDB(lambda **kwargs: None)
74
+ print(len(dataset))
75
+ print(dataset[0])
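The loader above downloads and extracts musdb18hq on first use and then reads `<cache_path>/musdb18hq/test/<track>/mixture.wav` in `__getitem__`. A purely illustrative check of that layout:

```python
# Illustrative sanity check of the layout MUSDB.__getitem__ reads from:
#   <cache_path>/musdb18hq/test/<track id>/mixture.wav
import os

cache_path = os.path.expanduser("~/.cache/sam_audio")
test_dir = os.path.join(cache_path, "musdb18hq", "test")
if os.path.isdir(test_dir):
    tracks = sorted(os.listdir(test_dir))
    missing = [
        t for t in tracks
        if not os.path.exists(os.path.join(test_dir, t, "mixture.wav"))
    ]
    print(f"{len(tracks)} tracks found, {len(missing)} missing mixture.wav")
else:
    print("musdb18hq not extracted yet; MUSDB() will download it on first use")
```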
eval/dataset/sam_audio_bench.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from io import BytesIO
6
+ from typing import Optional, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchaudio
12
+ from datasets import load_dataset
13
+ from torchcodec.decoders import AudioDecoder, VideoDecoder
14
+
15
+
16
+ @dataclass
17
+ class Item:
18
+ anchors: list[Tuple[str, float, float]]
19
+ masked_video_frames: torch.Tensor
20
+ audio_samples: torch.Tensor
21
+ description: str
22
+
23
+
24
+ class SAMAudioBench(torch.utils.data.Dataset):
25
+ def __init__(
26
+ self,
27
+ cache_path,
28
+ collate_fn,
29
+ span: bool = True,
30
+ visual: bool = True,
31
+ subset: Optional[str] = None,
32
+ ):
33
+ self.dataset = load_dataset("facebook/sam-audio-bench")["test"]
34
+ self.subset = subset
35
+ self._span = span
36
+ self._visual = visual
37
+ if subset is not None:
38
+ self.dataset = self.dataset.filter(lambda x: subset in x["paper_eval_sets"])
39
+
40
+ self.cache_path = os.path.join(cache_path, "sam_audio_bench")
41
+ self.collate_fn = collate_fn
42
+ DATA_MSG = (
43
+ f"`SAMAudioBench` requires the user to create a directory named {self.cache_path}; "
44
+ "see the README.md file for how to prepare it"
45
+ )
46
+ assert os.path.exists(self.cache_path), DATA_MSG
47
+
48
+ @property
49
+ def visual(self):
50
+ return self._visual
51
+
52
+ def __len__(self):
53
+ return len(self.dataset)
54
+
55
+ def _get_path(
56
+ self, video_id: str, source_dataset: str, start_offset: float, end_offset: float
57
+ ) -> str:
58
+ path = f"{self.cache_path}/{source_dataset}/{video_id}.mp4"
59
+ select_frames = True
60
+
61
+ if not os.path.exists(path):
62
+ path = f"{self.cache_path}/{source_dataset}/{video_id}_{int(start_offset * 1000)}_{int(end_offset * 1000)}.mp4"
63
+ select_frames = False
64
+
65
+ if not os.path.exists(path):
66
+ path = f"{self.cache_path}/{source_dataset}/{video_id}_{int(start_offset)}_{int(end_offset)}.mp4"
67
+
68
+ if not os.path.exists(path):
69
+ path = f"{self.cache_path}/{source_dataset}/{video_id}.{int(start_offset * 1000):08d}_{int(end_offset * 1000):08d}.mp4"
70
+
71
+ return path, select_frames
72
+
73
+ def collate(self, items: list[Item]):
74
+ has_video = any(item.masked_video_frames is not None for item in items)
75
+ return self.collate_fn(
76
+ descriptions=[item.description for item in items],
77
+ audios=[item.audio_samples for item in items],
78
+ anchors=[item.anchors for item in items] if self._span else None,
79
+ masked_videos=[item.masked_video_frames for item in items]
80
+ if has_video and self._visual
81
+ else None,
82
+ )
83
+
84
+ def _get_masked_video(self, item, video_path, select_frames):
85
+ if item["mask_bytes"] is None:
86
+ return None
87
+
88
+ mask = torch.from_numpy(np.load(BytesIO(item["mask_bytes"]))["video_masklet"])
89
+
90
+ video_decoder = VideoDecoder(video_path)
91
+ if select_frames:
92
+ video_frames = video_decoder.get_frames_played_in_range(
93
+ item["start_offset"], item["end_offset"]
94
+ ).data
95
+ else:
96
+ video_frames = video_decoder[:].data
97
+
98
+ if mask.size(0) != video_frames.size(0):
99
+ # It's possible that the mask and the video frames differ by a small amount
100
+ # we interpolate the mask frame to match
101
+ idxs = (
102
+ torch.linspace(0, mask.size(0) - 1, video_frames.size(0)).round().long()
103
+ )
104
+ mask = mask[idxs]
105
+
106
+ mask = mask.unsqueeze(1)
107
+
108
+ if mask.shape[-2:] != video_frames.shape[-2:]:
109
+ mask = F.interpolate(mask, size=video_frames.shape[-2:])
110
+
111
+ import torchvision
112
+
113
+ torchvision.io.write_video("test.mp4", video_frames.permute(0, 2, 3, 1), 30)
114
+ torchvision.io.write_video(
115
+ "test_mask.mp4", mask.unsqueeze(-1).expand(-1, -1, -1, 3) * 255, 30
116
+ )
117
+
118
+ return video_frames * mask
119
+
120
+ def __getitem__(self, idx) -> Item:
121
+ item = self.dataset[idx]
122
+
123
+ video_path, select_frames = self._get_path(
124
+ item["video_id"],
125
+ item["source_dataset"],
126
+ item["start_offset"],
127
+ item["end_offset"],
128
+ )
129
+ assert os.path.exists(video_path), f"{video_path} does not exist!"
130
+
131
+ audio_decoder = AudioDecoder(video_path)
132
+ audio_samples = audio_decoder.get_samples_played_in_range(
133
+ start_seconds=item["start_offset"] if select_frames else 0,
134
+ stop_seconds=item["end_offset"] if select_frames else None,
135
+ )
136
+
137
+ if audio_samples.sample_rate != self.collate_fn.audio_sampling_rate:
138
+ resampled_audio = torchaudio.functional.resample(
139
+ audio_samples.data,
140
+ audio_samples.sample_rate,
141
+ self.collate_fn.audio_sampling_rate,
142
+ )
143
+ else:
144
+ resampled_audio = audio_samples.data
145
+
146
+ masked_video_frames = self._get_masked_video(item, video_path, select_frames)
147
+
148
+ return Item(
149
+ description=item["description"],
150
+ anchors=[("+", start, end) for start, end in item["spans"]],
151
+ masked_video_frames=masked_video_frames,
152
+ audio_samples=resampled_audio.mean(0, keepdim=True),
153
+ )
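`_get_path` resolves each clip by probing a few filename conventions inside the local media cache, which is useful to know when preparing the `sam_audio_bench` directory mentioned in `DATA_MSG`. The snippet below (illustrative only, with made-up values) spells out those patterns:

```python
# Illustrative only: filename patterns SAMAudioBench._get_path probes, in order,
# under <cache_path>/sam_audio_bench/<source_dataset>/ (offsets are in seconds).
video_id, start, end = "abc123", 1.5, 11.5  # hypothetical values
patterns = [
    f"{video_id}.mp4",                                                # full video; clip selected by offsets
    f"{video_id}_{int(start * 1000)}_{int(end * 1000)}.mp4",          # pre-clipped, millisecond offsets
    f"{video_id}_{int(start)}_{int(end)}.mp4",                        # pre-clipped, whole-second offsets
    f"{video_id}.{int(start * 1000):08d}_{int(end * 1000):08d}.mp4",  # pre-clipped, zero-padded milliseconds
]
print("\n".join(patterns))
```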
eval/main.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+
7
+ import pandas as pd
8
+ import torch
9
+ import torch.distributed as dist
10
+ from dataset import SETTINGS, make_dataset
11
+ from metrics import CLAP, Aesthetic, ImageBind, Judge
12
+ from torch.utils.data import DataLoader
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ from tqdm import tqdm
15
+
16
+ from sam_audio import SAMAudio, SAMAudioProcessor
17
+
18
+
19
+ def gather_and_average_results(results, world_size):
20
+ if world_size == 1:
21
+ return json.loads(results.mean().to_json())
22
+
23
+ # 1. Gather all dictionaries to all ranks
24
+ all_results = [None for _ in range(world_size)]
25
+ dist.all_gather_object(
26
+ all_results, {"sum": results.sum().to_json(), "count": len(results)}
27
+ )
28
+
29
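+ # 2. Sum the per-rank sums and counts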
+ summed = {}
30
+ counts = 0
31
+
32
+ for res in all_results:
33
+ for k, v in json.loads(res["sum"]).items():
34
+ if k not in summed:
35
+ summed[k] = 0.0
36
+ summed[k] += v
37
+ counts += res["count"]
38
+
39
+ # 3. Compute average for keys that appeared at least once
40
+ averaged = {k: summed[k] / counts for k in summed}
41
+
42
+ return averaged
43
+
44
+
45
+ def main(
46
+ settings: list[str],
47
+ cache_path: str,
48
+ batch_size: int,
49
+ checkpoint_path: str,
50
+ num_workers: int = 4,
51
+ reranking_candidates: int = 8,
52
+ ):
53
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
54
+ rank = int(os.environ.get("RANK", 0))
55
+
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ if world_size > 1:
58
+ torch.distributed.init_process_group(backend="nccl")
59
+ device = torch.device(f"cuda:{rank}")
60
+ torch.cuda.set_device(device)
61
+
62
+ model = SAMAudio.from_pretrained(checkpoint_path)
63
+ model = model.eval().to(device)
64
+ processor = SAMAudioProcessor.from_pretrained(checkpoint_path)
65
+
66
+ judge_metric = Judge(device=device)
67
+ aes_metric = Aesthetic(device=device)
68
+ clap_metric = CLAP(device=device)
69
+ imagebind_metric = ImageBind(device=device)
70
+
71
+ for setting in settings:
72
+ print(f"Evaluating: {setting}")
73
+ dset = make_dataset(setting, cache_path=cache_path, collate_fn=processor)
74
+ sampler = None
75
+ if world_size > 1:
76
+ sampler = DistributedSampler(dset)
77
+
78
+ dl = DataLoader(
79
+ dset,
80
+ batch_size=batch_size,
81
+ shuffle=False,
82
+ collate_fn=dset.collate,
83
+ num_workers=num_workers,
84
+ sampler=sampler,
85
+ )
86
+
87
+ all_metrics = [
88
+ judge_metric,
89
+ aes_metric,
90
+ clap_metric,
91
+ ]
92
+
93
+ if dset.visual:
94
+ all_metrics.append(imagebind_metric)
95
+
96
+ dfs = []
97
+ with torch.inference_mode():
98
+ for batch in tqdm(dl, disable=rank > 0):
99
+ batch = batch.to(device)
100
+ result = model.separate(
101
+ batch, reranking_candidates=reranking_candidates
102
+ )
103
+ mets = {}
104
+ for metric in all_metrics:
105
+ input_wavs = model.unbatch(batch.audios.squeeze(1), batch.wav_sizes)
106
+
107
+ mets.update(
108
+ metric(
109
+ target_wavs=result.target,
110
+ target_wavs_sample_rate=model.sample_rate,
111
+ descriptions=batch.descriptions,
112
+ input_wavs=input_wavs,
113
+ videos=batch.masked_video,
114
+ )
115
+ )
116
+
117
+ dfs.append(pd.DataFrame.from_dict(mets))
118
+
119
+ df = pd.concat(dfs)
120
+ averaged_results = gather_and_average_results(df, world_size)
121
+ if rank == 0:
122
+ results_dict = {k: f"{v:.3f}" for k, v in averaged_results.items()}
123
+ print(json.dumps(results_dict, indent=4))
124
+ os.makedirs("results", exist_ok=True)
125
+ outfile = f"results/{setting}.json"
126
+ with open(outfile, "w") as fout:
127
+ print(json.dumps(results_dict), file=fout)
128
+
129
+
130
+ if __name__ == "__main__":
131
+ parser = argparse.ArgumentParser()
132
+ parser.add_argument(
133
+ "--setting",
134
+ "-s",
135
+ choices=SETTINGS.keys(),
136
+ help=f"Which setting to evaluate. Choices: {SETTINGS.keys()}",
137
+ default=["instr-pro"],
138
+ nargs="+",
139
+ )
140
+ parser.add_argument(
141
+ "--cache-path",
142
+ type=str,
143
+ default=os.path.expanduser("~/.cache/sam_audio"),
144
+ help="Where to cache downloaded datasets",
145
+ )
146
+ parser.add_argument(
147
+ "--checkpoint-path", "-p", type=str, default="facebook/sam-audio-large"
148
+ )
149
+ parser.add_argument("--batch-size", "-b", type=int, default=1, help="Batch size")
150
+ parser.add_argument(
151
+ "--num-workers", "-w", type=int, default=4, help="Number of workers"
152
+ )
153
+ parser.add_argument("--candidates", "-c", type=int, default=8)
154
+ opt = parser.parse_args()
155
+ main(
156
+ settings=opt.setting,
157
+ cache_path=opt.cache_path,
158
+ batch_size=opt.batch_size,
159
+ checkpoint_path=opt.checkpoint_path,
160
+ num_workers=opt.num_workers,
161
+ reranking_candidates=opt.candidates,
162
+ )
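In the single-process case, `gather_and_average_results` reduces to a plain column-wise mean over the per-sample metric DataFrame. A toy illustration of that same expression:

```python
# Toy illustration: with world_size == 1, gather_and_average_results(df, 1)
# evaluates json.loads(df.mean().to_json()), i.e. a column-wise mean.
import json

import pandas as pd

df = pd.DataFrame({"JudgeOverall": [4.0, 5.0], "CLAPSimilarity": [0.2, 0.3]})
print(json.loads(df.mean().to_json()))
# {'JudgeOverall': 4.5, 'CLAPSimilarity': 0.25}
```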
eval/metrics/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from metrics.aes import Aesthetic
4
+ from metrics.clap import CLAP
5
+ from metrics.imagebind import ImageBind
6
+ from metrics.judge import Judge
7
+
8
+ __all__ = [
9
+ "Aesthetic",
10
+ "CLAP",
11
+ "ImageBind",
12
+ "Judge",
13
+ ]
eval/metrics/aes.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from audiobox_aesthetics.infer import AesPredictor
7
+
8
+ COLUMN_MAP = {
9
+ "CE": "ContentEnjoyment",
10
+ "CU": "ContentUsefulness",
11
+ "PC": "ProductionComplexity",
12
+ "PQ": "ProductionQuality",
13
+ }
14
+
15
+
16
+ class Aesthetic(torch.nn.Module):
17
+ def __init__(
18
+ self,
19
+ checkpoint: Optional[str] = None,
20
+ device: Optional[torch.device] = None,
21
+ ):
22
+ super().__init__()
23
+ self.model = AesPredictor(
24
+ checkpoint_pth=checkpoint,
25
+ data_col="wav",
26
+ )
27
+ self.device = device or torch.device(
28
+ "cuda" if torch.cuda.is_available() else "cpu"
29
+ )
30
+
31
+ def __call__(
32
+ self,
33
+ target_wavs: list[torch.Tensor],
34
+ target_wavs_sample_rate: int = 48_000,
35
+ **kwargs,
36
+ ) -> dict[str, list[float]]:
37
+ result = self.model.forward(
38
+ [
39
+ {
40
+ "wav": wav[None] if wav.ndim == 1 else wav,
41
+ "sample_rate": target_wavs_sample_rate,
42
+ }
43
+ for wav in target_wavs
44
+ ]
45
+ )
46
+ return {
47
+ long_name: [x[shortname] for x in result]
48
+ for shortname, long_name in COLUMN_MAP.items()
49
+ }
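A hedged usage sketch of the Aesthetic metric, run from `eval/` and assuming the `audiobox_aesthetics` package can resolve its default checkpoint when none is given; random noise is used purely to exercise the plumbing:

```python
# Hedged sketch: score two mono 48 kHz waveforms and print the four axis names.
import torch

from metrics import Aesthetic

metric = Aesthetic()
scores = metric(target_wavs=[torch.randn(1, 48_000), torch.randn(1, 48_000)])
print(sorted(scores))
# ['ContentEnjoyment', 'ContentUsefulness', 'ProductionComplexity', 'ProductionQuality']
```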
eval/metrics/clap.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from tempfile import TemporaryDirectory
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torchcodec.encoders import AudioEncoder
8
+
9
+ from sam_audio.ranking.clap import get_model
10
+
11
+
12
+ class CLAP(torch.nn.Module):
13
+ def __init__(
14
+ self,
15
+ checkpoint: Optional[str] = None,
16
+ device: Optional[torch.device] = None,
17
+ ):
18
+ super().__init__()
19
+ self.model = get_model(device)
20
+ self.device = device or torch.device(
21
+ "cuda" if torch.cuda.is_available() else "cpu"
22
+ )
23
+
24
+ def __call__(
25
+ self,
26
+ target_wavs: list[torch.Tensor],
27
+ descriptions: list[str],
28
+ target_wavs_sample_rate: int = 48_000,
29
+ **kwargs,
30
+ ) -> dict[str, list[float]]:
31
+ with TemporaryDirectory() as tdir, torch.inference_mode():
32
+ file_list = []
33
+ for i, wav in enumerate(target_wavs):
34
+ file_list.append(f"{tdir}/hyp_{i}.wav")
35
+ encoder = AudioEncoder(
36
+ samples=wav.cpu()[None] if wav.ndim == 1 else wav.cpu(),
37
+ sample_rate=target_wavs_sample_rate,
38
+ )
39
+ encoder.to_file(file_list[-1])
40
+ audio_embs = self.model.get_audio_embedding_from_filelist(
41
+ file_list, use_tensor=True
42
+ )
43
+
44
+ text_embs = self.model.get_text_embedding(descriptions, use_tensor=True)
45
+ sims = audio_embs.unsqueeze(1) @ text_embs.unsqueeze(2)
46
+ return {"CLAPSimilarity": sims.cpu()[:, 0, 0].tolist()}
eval/metrics/imagebind.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from imagebind.models.imagebind_model import ModalityType, imagebind_huge
7
+
8
+ from sam_audio.ranking.imagebind import VideoTransform, load_and_transform_audio_data
9
+
10
+
11
+ class ImageBind(torch.nn.Module):
12
+ def __init__(
13
+ self,
14
+ checkpoint: Optional[str] = None,
15
+ device: Optional[torch.device] = None,
16
+ ):
17
+ super().__init__()
18
+
19
+ self.model = imagebind_huge(pretrained=checkpoint is None)
20
+ if checkpoint is not None:
21
+ self.model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
22
+ self.model = self.model.eval()
23
+ self.video_transform = VideoTransform()
24
+ self.device = device or torch.device(
25
+ "cuda" if torch.cuda.is_available() else "cpu"
26
+ )
27
+ self.model = self.model.to(self.device)
28
+
29
+ def __call__(
30
+ self,
31
+ target_wavs: list[torch.Tensor],
32
+ videos: list[torch.Tensor],
33
+ target_wavs_sample_rate: int = 48_000,
34
+ **kwargs,
35
+ ) -> dict[str, list[float]]:
36
+ audio_data = load_and_transform_audio_data(
37
+ target_wavs, input_sample_rate=target_wavs_sample_rate
38
+ )
39
+ durations = [x.size(-1) / target_wavs_sample_rate for x in target_wavs]
40
+ video_data = self.video_transform(videos, durations, audio_data.device)
41
+
42
+ inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
43
+ embs = self.model(inputs)
44
+ audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
45
+ audio_embs, video_embs = (
46
+ audio_embs / ((audio_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
47
+ video_embs / ((video_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
48
+ )
49
+ bsz = len(target_wavs)
50
+ candidates = len(audio_embs) // bsz
51
+ scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
52
+ return {"ImageBind": scores.squeeze(1, 2).cpu().tolist()}
eval/metrics/judge.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ from sam_audio import SAMAudioJudgeModel, SAMAudioJudgeProcessor
8
+
9
+
10
+ class Judge(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ checkpoint: str = "facebook/sam-audio-judge",
14
+ device: Optional[torch.device] = None,
15
+ ):
16
+ super().__init__()
17
+ self.model = SAMAudioJudgeModel.from_pretrained(checkpoint).to(device)
18
+ self.processor = SAMAudioJudgeProcessor.from_pretrained(checkpoint)
19
+ self.device = device or torch.device(
20
+ "cuda" if torch.cuda.is_available() else "cpu"
21
+ )
22
+
23
+ def forward(
24
+ self,
25
+ input_wavs: list[torch.Tensor],
26
+ target_wavs: list[torch.Tensor],
27
+ descriptions: list[str],
28
+ target_wavs_sample_rate: int = 48_000,
29
+ **kwargs,
30
+ ) -> dict[str, list[float]]:
31
+ with torch.inference_mode():
32
+ processed = self.processor(
33
+ text=descriptions,
34
+ input_audio=[x.cpu() for x in input_wavs],
35
+ separated_audio=[x.cpu() for x in target_wavs],
36
+ sampling_rate=target_wavs_sample_rate,
37
+ ).to(self.device)
38
+ result = self.model(**processed)
39
+ return {
40
+ "JudgeOverall": result.overall.squeeze(-1).cpu().tolist(),
41
+ "JudgeFaithfulness": result.faithfulness.squeeze(-1).cpu().tolist(),
42
+ "JudgeRecall": result.recall.squeeze(-1).cpu().tolist(),
43
+ "JudgePrecision": result.precision.squeeze(-1).cpu().tolist(),
44
+ }
sam_audio/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from .model import * # noqa
4
+ from .processor import * # noqa
sam_audio/model/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from .model import * # noqa
4
+ from .judge import * # noqa
sam_audio/model/align.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
+ class AlignModalities(torch.nn.Module):
9
+ def __init__(
10
+ self,
11
+ in_channels: int,
12
+ out_channels: int,
13
+ normalize: bool = True,
14
+ with_gate: bool = True,
15
+ ):
16
+ super().__init__()
17
+ self.conv = torch.nn.Conv1d(
18
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1
19
+ )
20
+ self.normalize = normalize
21
+ if self.normalize:
22
+ self.layer_norm = torch.nn.LayerNorm(out_channels)
23
+
24
+ self.gate = None
25
+ if with_gate:
26
+ self.gate = torch.nn.Parameter(torch.tensor([0.0]))
27
+
28
+ self.out_channels = out_channels
29
+
30
+ def forward(self, anchor: torch.Tensor, tgt: Optional[torch.Tensor] = None):
31
+ """
32
+ Align video features to the input audio features
33
+
34
+ Args:
35
+ anchor (torch.Tensor): Input anchor tensor of shape (B, T, C), where B is batch size, C is channel size, and T is sequence length.
36
+ tgt (Optional[torch.Tensor]): Optional features tensor to be aligned to anchor, expected shape (B, in_channels, T).
37
+ """
38
+ if tgt is None:
39
+ return anchor
40
+
41
+ post_conv = self.conv(tgt)
42
+ post_conv = post_conv.permute(0, 2, 1) # BCT -> BTC
43
+
44
+ if self.normalize:
45
+ post_conv = self.layer_norm(post_conv)
46
+
47
+ if self.gate is None:
48
+ return post_conv
49
+ else:
50
+ return anchor + self.gate.tanh() * post_conv
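A small shape check for `AlignModalities`, following the docstring above (`anchor` is (B, T, C_out), `tgt` is (B, C_in, T)); because the gate is initialised to zero, the fused output starts out equal to the anchor:

```python
# Shape check for AlignModalities; gate.tanh() == 0 at init, so out == anchor.
import torch

from sam_audio.model.align import AlignModalities

align = AlignModalities(in_channels=1024, out_channels=2048)
anchor = torch.randn(2, 50, 2048)  # e.g. projected audio tokens
video = torch.randn(2, 1024, 50)   # e.g. per-frame visual features
out = align(anchor, video)
print(out.shape)                    # torch.Size([2, 50, 2048])
print(torch.allclose(out, anchor))  # True at initialization
```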
sam_audio/model/base.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import json
4
+ import os
5
+ from typing import Callable, Dict, Optional, Union
6
+
7
+ import torch
8
+ from huggingface_hub import ModelHubMixin, snapshot_download
9
+
10
+
11
+ class BaseModel(torch.nn.Module, ModelHubMixin):
12
+ config_cls: Callable
13
+
14
+ def device(self):
15
+ return next(self.parameters()).device
16
+
17
+ @classmethod
18
+ def _from_pretrained(
19
+ cls,
20
+ *,
21
+ model_id: str,
22
+ cache_dir: str,
23
+ force_download: bool,
24
+ proxies: Optional[Dict],
25
+ resume_download: bool,
26
+ local_files_only: bool,
27
+ token: Union[str, bool, None],
28
+ map_location: str = "cpu",
29
+ strict: bool = True,
30
+ revision: Optional[str] = None,
31
+ **model_kwargs,
32
+ ):
33
+ if os.path.isdir(model_id):
34
+ cached_model_dir = model_id
35
+ else:
36
+ cached_model_dir = snapshot_download(
37
+ repo_id=model_id,
38
+ revision=cls.revision,
39
+ cache_dir=cache_dir,
40
+ force_download=force_download,
41
+ proxies=proxies,
42
+ resume_download=resume_download,
43
+ token=token,
44
+ local_files_only=local_files_only,
45
+ )
46
+
47
+ with open(os.path.join(cached_model_dir, "config.json")) as fin:
48
+ config = json.load(fin)
49
+
50
+ config = cls.config_cls(**config)
51
+ model = cls(config)
52
+ state_dict = torch.load(
53
+ os.path.join(cached_model_dir, "checkpoint.pt"),
54
+ weights_only=True,
55
+ map_location=map_location,
56
+ )
57
+ model.load_state_dict(state_dict, strict=strict)
58
+ return model
sam_audio/model/codec.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ from abc import ABCMeta, abstractmethod
5
+ from typing import Union
6
+
7
+ import dacvae
8
+ import torch
9
+
10
+ from sam_audio.model.config import DACVAEConfig
11
+
12
+
13
+ class Encoder(torch.nn.Module, metaclass=ABCMeta):
14
+ @abstractmethod
15
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor: ...
16
+
17
+
18
+ class Codec(Encoder):
19
+ @abstractmethod
20
+ def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor: ...
21
+
22
+ @abstractmethod
23
+ def wav_idx_to_feature_idx(
24
+ self, wav_idx: Union[torch.Tensor, int], sample_rate=None
25
+ ) -> Union[torch.Tensor, int]: ...
26
+
27
+ @abstractmethod
28
+ def feature_idx_to_wav_idx(
29
+ self, feature_idx: Union[torch.Tensor, int], sample_rate=None
30
+ ) -> Union[torch.Tensor, int]: ...
31
+
32
+ @staticmethod
33
+ def cast_to_int(
34
+ x: Union[int, torch.Tensor],
35
+ ) -> Union[int, torch.Tensor]:
36
+ if isinstance(x, torch.Tensor):
37
+ return x.int()
38
+ else:
39
+ return int(x)
40
+
41
+
42
+ class DACVAEEncoder(Encoder):
43
+ def __init__(self, config: DACVAEConfig) -> None:
44
+ super().__init__()
45
+ model = dacvae.DACVAE(
46
+ encoder_dim=config.encoder_dim,
47
+ encoder_rates=config.encoder_rates,
48
+ latent_dim=config.latent_dim,
49
+ decoder_dim=config.decoder_dim,
50
+ decoder_rates=config.decoder_rates,
51
+ n_codebooks=config.n_codebooks,
52
+ codebook_size=config.codebook_size,
53
+ codebook_dim=config.codebook_dim,
54
+ quantizer_dropout=config.quantizer_dropout,
55
+ sample_rate=config.sample_rate,
56
+ ).eval()
57
+ self._setup_model(model)
58
+ self.hop_length = config.hop_length
59
+ self.sample_rate = config.sample_rate
60
+
61
+ def _setup_model(self, model):
62
+ self.encoder = model.encoder
63
+ self.quantizer = model.quantizer
64
+
65
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
66
+ with torch.no_grad():
67
+ z = self.encoder(self._pad(waveform))
68
+ mean, scale = self.quantizer.in_proj(z).chunk(2, dim=1)
69
+ encoded_frames = mean
70
+ return encoded_frames
71
+
72
+ def _pad(self, wavs):
73
+ length = wavs.size(-1)
74
+ if length % self.hop_length:
75
+ p1d = (0, self.hop_length - (length % self.hop_length))
76
+ return torch.nn.functional.pad(wavs, p1d, "reflect")
77
+ else:
78
+ return wavs
79
+
80
+
81
+ class DACVAE(DACVAEEncoder, Codec):
82
+ def _setup_model(self, model):
83
+ super()._setup_model(model)
84
+ self.decoder = model.decoder
85
+
86
+ def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor:
87
+ emb = self.quantizer.out_proj(encoded_frames)
88
+ return self.decoder(emb)
89
+
90
+ def feature_idx_to_wav_idx(self, feature_idx, sample_rate=None):
91
+ if sample_rate is None:
92
+ sample_rate = self.sample_rate
93
+ orig_freq = sample_rate
94
+ new_freq = self.sample_rate
95
+ wav_chunklen = feature_idx * self.hop_length * (orig_freq / new_freq)
96
+ return self.cast_to_int(wav_chunklen)
97
+
98
+ def wav_idx_to_feature_idx(self, wav_idx, sample_rate=None):
99
+ ceil = math.ceil
100
+ if torch.is_tensor(wav_idx):
101
+ ceil = torch.ceil
102
+ if sample_rate is None:
103
+ sample_rate = self.sample_rate
104
+ orig_freq = sample_rate
105
+ new_freq = self.sample_rate
106
+ target_length = ceil(new_freq * wav_idx / orig_freq)
107
+ res = ceil(target_length / self.hop_length)
108
+ return self.cast_to_int(res)
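A worked example of the index conversions above, assuming the default `DACVAEConfig` from `sam_audio/model/config.py` below (encoder_rates `[2, 8, 10, 12]`, i.e. a hop length of 1920 samples at the codec's 48 kHz sample rate, or 25 latent frames per second):

```python
# Worked example of wav_idx_to_feature_idx / feature_idx_to_wav_idx arithmetic.
import math

hop_length, codec_sr = 2 * 8 * 10 * 12, 48_000  # 1920 samples per latent frame

# wav_idx_to_feature_idx: 1 second of 16 kHz input audio -> 25 latent frames
wav_idx, input_sr = 16_000, 16_000
target_length = math.ceil(codec_sr * wav_idx / input_sr)  # 48_000 samples after resampling
print(math.ceil(target_length / hop_length))              # 25

# feature_idx_to_wav_idx: 25 latent frames -> 16_000 samples back at 16 kHz
print(int(25 * hop_length * (input_sr / codec_sr)))       # 16000
```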
sam_audio/model/config.py ADDED
@@ -0,0 +1,251 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+ from core.audio_visual_encoder.config import TransformerConfig as PEAVTransformerConfig
7
+ from transformers import ModernBertConfig
8
+
9
+
10
+ class DACVAEConfig:
11
+ def __init__(
12
+ self,
13
+ encoder_dim: int = 64,
14
+ encoder_rates: list[int] = [2, 8, 10, 12],
15
+ latent_dim: int = 1024,
16
+ decoder_dim: int = 1536,
17
+ decoder_rates: list[int] = [12, 10, 8, 2],
18
+ n_codebooks: int = 16,
19
+ codebook_size: int = 1024,
20
+ codebook_dim: int = 128,
21
+ quantizer_dropout: bool = False,
22
+ sample_rate: int = 48_000,
23
+ mean: float = 0.0,
24
+ std: float = 1.0,
25
+ ):
26
+ self.encoder_dim = encoder_dim
27
+ self.encoder_rates = encoder_rates
28
+ self.latent_dim = latent_dim
29
+ self.decoder_dim = decoder_dim
30
+ self.decoder_rates = decoder_rates
31
+ self.n_codebooks = n_codebooks
32
+ self.codebook_size = codebook_size
33
+ self.codebook_dim = codebook_dim
34
+ self.quantizer_dropout = quantizer_dropout
35
+ self.sample_rate = sample_rate
36
+ self.mean = mean
37
+ self.std = std
38
+
39
+ @property
40
+ def hop_length(self):
41
+ return int(np.prod(self.encoder_rates))
42
+
43
+
44
+ class TextEncoderConfig:
45
+ def __init__(self, dim: int = 768):
46
+ self.dim = dim
47
+
48
+
49
+ class T5EncoderConfig(TextEncoderConfig):
50
+ def __init__(
51
+ self,
52
+ name: str = "t5-base",
53
+ max_length: Optional[int] = 512,
54
+ pad_mode: str = "longest",
55
+ dim: int = 768,
56
+ ):
57
+ super().__init__(dim=dim)
58
+ self.name = name
59
+ self.max_length = max_length
60
+ self.pad_mode = pad_mode
61
+
62
+
63
+ class VisionEncoderConfig:
64
+ def __init__(self, dim: int = 1024, batch_size: int = 300):
65
+ self.dim = dim
66
+ self.batch_size = batch_size
67
+
68
+
69
+ class PerceptionEncoderConfig(VisionEncoderConfig):
70
+ def __init__(
71
+ self,
72
+ dim: int = 1024,
73
+ batch_size: int = 300,
74
+ name: str = "PE-Core-L14-336",
75
+ normalize_feature: bool = True,
76
+ interpolation_mode: str = "BICUBIC",
77
+ image_size: int = 336,
78
+ ):
79
+ super().__init__(dim=dim, batch_size=batch_size)
80
+ self.name = name
81
+ self.normalize_feature = normalize_feature
82
+ self.interpolation_mode = interpolation_mode
83
+ self.image_size = image_size
84
+
85
+
86
+ class TransformerConfig:
87
+ def __init__(
88
+ self,
89
+ dim: int = 2048,
90
+ n_heads: int = 16,
91
+ n_layers: int = 16,
92
+ dropout: float = 0.1,
93
+ norm_eps: float = 1.0e-05,
94
+ qk_norm: bool = True,
95
+ fc_bias: bool = False,
96
+ ffn_exp: int = 4,
97
+ ffn_dim_multiplier: int = 1,
98
+ multiple_of: int = 64,
99
+ non_linearity: str = "swiglu",
100
+ use_rope: bool = True,
101
+ max_positions: int = 10000,
102
+ frequency_embedding_dim: int = 256,
103
+ timestep_non_linearity: str = "swiglu",
104
+ t_block_non_linearity: str = "silu",
105
+ t_block_bias: bool = True,
106
+ context_dim: int = 2048,
107
+ context_non_linearity: str = "swiglu",
108
+ context_embedder_dropout: float = 0.0,
109
+ context_norm: bool = False,
110
+ out_channels: int = 256,
111
+ in_channels: Optional[int] = None,
112
+ ):
113
+ self.dim = dim
114
+ self.n_heads = n_heads
115
+ self.n_layers = n_layers
116
+ self.dropout = dropout
117
+ self.norm_eps = norm_eps
118
+ self.qk_norm = qk_norm
119
+ self.fc_bias = fc_bias
120
+ self.ffn_exp = ffn_exp
121
+ self.ffn_dim_multiplier = ffn_dim_multiplier
122
+ self.multiple_of = multiple_of
123
+ self.non_linearity = non_linearity
124
+ self.use_rope = use_rope
125
+ self.max_positions = max_positions
126
+ self.frequency_embedding_dim = frequency_embedding_dim
127
+ self.timestep_non_linearity = timestep_non_linearity
128
+ self.t_block_non_linearity = t_block_non_linearity
129
+ self.t_block_bias = t_block_bias
130
+ self.context_dim = context_dim
131
+ self.context_non_linearity = context_non_linearity
132
+ self.context_embedder_dropout = context_embedder_dropout
133
+ self.context_norm = context_norm
134
+ self.out_channels = out_channels
135
+ self.in_channels = in_channels
136
+
137
+
138
+ class RankerConfig:
139
+ kind: str
140
+
141
+
142
+ class ImageBindRankerConfig(RankerConfig):
143
+ kind: str = "imagebind"
144
+
145
+ def __init__(self, checkpoint: Optional[str] = None):
146
+ self.checkpoint = checkpoint
147
+
148
+
149
+ class ClapRankerConfig(RankerConfig):
150
+ kind: str = "clap"
151
+
152
+ def __init__(self, checkpoint: Optional[str] = None):
153
+ self.checkpoint = checkpoint
154
+
155
+
156
+ class JudgeRankerConfig(RankerConfig):
157
+ kind: str = "judge"
158
+
159
+ def __init__(self, checkpoint_or_model_id: str = "facebook/sam-audio-judge"):
160
+ self.checkpoint_or_model_id = checkpoint_or_model_id
161
+
162
+
163
+ class SoundActivityRankerConfig(RankerConfig):
164
+ kind: str = "sound_activity"
165
+
166
+ def __init__(
167
+ self,
168
+ threshold_mode: str = "rel_to_max",
169
+ sil_threshold: float = -40,
170
+ metric: str = "iou",
171
+ ):
172
+ self.threshold_mode = threshold_mode
173
+ self.sil_threshold = sil_threshold
174
+ self.metric = metric
175
+
176
+
177
+ class EnsembleRankerConfig(RankerConfig):
178
+ kind: str = "ensemble"
179
+
180
+ def __init__(self, rankers: dict[str, Tuple[RankerConfig, float]]):
181
+ self.rankers = rankers
182
+
183
+
184
+ def parse_ranker_config(config_dict: dict):
185
+ kind = config_dict.pop("kind")
186
+ match kind:
187
+ case ImageBindRankerConfig.kind:
188
+ return ImageBindRankerConfig(**config_dict)
189
+ case ClapRankerConfig.kind:
190
+ return ClapRankerConfig(**config_dict)
191
+ case JudgeRankerConfig.kind:
192
+ return JudgeRankerConfig(**config_dict)
193
+ case SoundActivityRankerConfig.kind:
194
+ return SoundActivityRankerConfig(**config_dict)
195
+ case EnsembleRankerConfig.kind:
196
+ return EnsembleRankerConfig(
197
+ {
198
+ k: (parse_ranker_config(v), w)
199
+ for k, (v, w) in config_dict["rankers"].items()
200
+ }
201
+ )
202
+
203
+
204
+ class SAMAudioConfig:
205
+ def __init__(
206
+ self,
207
+ in_channels: int = 768,
208
+ audio_codec=None,
209
+ text_encoder=None,
210
+ vision_encoder=None,
211
+ transformer=None,
212
+ num_anchors: int = 3,
213
+ anchor_embedding_dim: int = 128,
214
+ visual_ranker=None,
215
+ text_ranker=None,
216
+ span_predictor: Optional[str] = "pe-a-frame-large",
217
+ ):
218
+ self.in_channels = in_channels
219
+ self.audio_codec = DACVAEConfig(**(audio_codec or {}))
220
+ self.text_encoder = T5EncoderConfig(**(text_encoder or {}))
221
+ self.vision_encoder = PerceptionEncoderConfig(**(vision_encoder or {}))
222
+ self.transformer = TransformerConfig(**(transformer or {}))
223
+ self.num_anchors = num_anchors
224
+ self.anchor_embedding_dim = anchor_embedding_dim
225
+ self.visual_ranker = (
226
+ None if visual_ranker is None else parse_ranker_config(visual_ranker)
227
+ )
228
+ self.text_ranker = (
229
+ None if text_ranker is None else parse_ranker_config(text_ranker)
230
+ )
231
+ self.span_predictor = span_predictor
232
+
233
+
234
+ class SAMAudioJudgeConfig:
235
+ def __init__(
236
+ self,
237
+ audio_codec: DACVAEConfig = None,
238
+ transformer: PEAVTransformerConfig = None,
239
+ text_model: ModernBertConfig = None,
240
+ finetune_transformer: PEAVTransformerConfig = None,
241
+ nth_text_layer: int = 22,
242
+ bottleneck_dim: int = 256,
243
+ ):
244
+ self.audio_codec = DACVAEConfig(**(audio_codec or {}))
245
+ self.transformer = PEAVTransformerConfig(**(transformer or {}))
246
+ self.text_model = ModernBertConfig(**(text_model or {}))
247
+ self.finetune_transformer = PEAVTransformerConfig(
248
+ **(finetune_transformer or {})
249
+ )
250
+ self.nth_text_layer = nth_text_layer
251
+ self.bottleneck_dim = bottleneck_dim
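A small illustration (with made-up values, and assuming the package's dependencies are installed) of the dictionaries `parse_ranker_config` accepts, including the nested `(config, weight)` pairs used by the ensemble ranker:

```python
# Illustrative ranker configuration dicts for parse_ranker_config.
from sam_audio.model.config import parse_ranker_config

text_ranker = parse_ranker_config({"kind": "clap", "checkpoint": None})

ensemble = parse_ranker_config(
    {
        "kind": "ensemble",
        "rankers": {
            # name -> (ranker config dict, ensemble weight)
            "clap": ({"kind": "clap", "checkpoint": None}, 0.5),
            "activity": ({"kind": "sound_activity", "metric": "iou"}, 0.5),
        },
    }
)
print(type(text_ranker).__name__, list(ensemble.rankers))
# ClapRankerConfig ['clap', 'activity']
```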
sam_audio/model/judge.py ADDED
@@ -0,0 +1,135 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from core.audio_visual_encoder.transformer import BaseModelOutputWithPooling
8
+ from core.audio_visual_encoder.transformer import Transformer as PEAVTransformer
9
+ from transformers import AutoModel
10
+
11
+ from .base import BaseModel
12
+ from .codec import DACVAEEncoder
13
+ from .config import SAMAudioJudgeConfig
14
+
15
+
16
+ @dataclass
17
+ class SAMAudioJudgeOutput:
18
+ r"""
19
+ overall (torch.Tensor, optional): Overall score tensor of shape (batch_size, 1).
20
+ recall (torch.Tensor, optional): Recall score tensor of shape (batch_size, 1).
21
+ precision (torch.Tensor, optional): Precision score tensor of shape (batch_size, 1).
22
+ faithfulness (torch.Tensor, optional): Faithfulness score tensor of shape (batch_size, 1).
23
+ text_model_output (BaseModelOutputWithPooling): Output from the text model.
24
+ audio_model_output (BaseModelOutputWithPooling): Output from the audio model.
25
+ """
26
+
27
+ overall: Optional[torch.Tensor] = None
28
+ recall: Optional[torch.Tensor] = None
29
+ precision: Optional[torch.Tensor] = None
30
+ faithfulness: Optional[torch.Tensor] = None
31
+ text_model_output: BaseModelOutputWithPooling = None
32
+ audio_model_output: BaseModelOutputWithPooling = None
33
+
34
+
35
+ class SAMAudioJudgeModel(BaseModel):
36
+ config_cls = SAMAudioJudgeConfig
37
+ revision = "sam_audio"
38
+
39
+ def __init__(self, config: SAMAudioJudgeConfig):
40
+ super().__init__()
41
+ self.config = config
42
+ self.data_proj = torch.nn.Linear(
43
+ config.audio_codec.codebook_dim, config.transformer.hidden_size
44
+ )
45
+ self.audio_codec = DACVAEEncoder(config.audio_codec)
46
+ self.transformer = PEAVTransformer(config.transformer)
47
+ self.finetune_transformer = PEAVTransformer(config.finetune_transformer)
48
+ self.text_model = AutoModel.from_config(config.text_model)
49
+ self.cat_audio_proj = torch.nn.Linear(
50
+ 2 * config.transformer.hidden_size, config.bottleneck_dim
51
+ )
52
+ self.text_proj1 = torch.nn.Linear(
53
+ in_features=config.text_model.hidden_size,
54
+ out_features=config.transformer.hidden_size,
55
+ bias=False,
56
+ )
57
+ self.text_proj2 = torch.nn.Linear(
58
+ in_features=config.transformer.hidden_size,
59
+ out_features=config.bottleneck_dim,
60
+ )
61
+ self.layer_norm = torch.nn.LayerNorm(config.bottleneck_dim)
62
+ self.proj_audio_and_text = torch.nn.Linear(
63
+ 2 * config.bottleneck_dim, config.bottleneck_dim
64
+ )
65
+ self.finetune_data_proj = torch.nn.Linear(
66
+ config.bottleneck_dim, config.finetune_transformer.hidden_size
67
+ )
68
+ self.head = torch.nn.Linear(
69
+ config.finetune_transformer.hidden_size, 4, bias=False
70
+ )
71
+ self.mean = torch.nn.Parameter(torch.zeros(4, requires_grad=False))
72
+ self.std = torch.nn.Parameter(torch.ones(4, requires_grad=False))
73
+
74
+ def _get_text_output(self, input_ids, attention_mask):
75
+ nth_layer = self.config.nth_text_layer
76
+ output = self.text_model(
77
+ input_ids=input_ids,
78
+ attention_mask=attention_mask,
79
+ output_hidden_states=nth_layer is not None,
80
+ )
81
+ if nth_layer is None:
82
+ text_model_output = output.last_hidden_state
83
+ else:
84
+ text_model_output = output.hidden_states[nth_layer]
85
+
86
+ return BaseModelOutputWithPooling(
87
+ last_hidden_state=text_model_output, pooler_output=text_model_output[:, 0]
88
+ )
89
+
90
+ def forward(
91
+ self,
92
+ input_ids: torch.Tensor, # tokenized text
93
+ input_values: torch.Tensor, # input audio waveform
94
+ separated_values: torch.Tensor, # separated audio waveform
95
+ attention_mask: Optional[torch.Tensor] = None, # text attention mask
96
+ padding_mask: Optional[torch.Tensor] = None, # audio padding mask
97
+ ) -> SAMAudioJudgeOutput:
98
+ text_features = self.text_proj1(
99
+ self._get_text_output(input_ids, attention_mask).pooler_output
100
+ )
101
+ stacked_audios = torch.cat([input_values, separated_values], dim=0)
102
+ stacked_codec_features = self.audio_codec(stacked_audios)
103
+ feature_padding_mask = None
104
+ if padding_mask is not None:
105
+ feature_padding_mask = padding_mask[
106
+ :, :: self.config.audio_codec.hop_length
107
+ ]
108
+ stacked_features = self.transformer(
109
+ self.data_proj(stacked_codec_features.transpose(1, 2)),
110
+ padding_mask=feature_padding_mask,
111
+ )
112
+ input_features, hyp_features = stacked_features.last_hidden_state.chunk(2, 0)
113
+ audio_features = self.cat_audio_proj(
114
+ torch.cat([hyp_features, input_features], dim=2)
115
+ )
116
+ expanded_text = (
117
+ self.layer_norm(self.text_proj2(text_features))
118
+ .unsqueeze(1)
119
+ .expand_as(audio_features)
120
+ )
121
+ audio_and_text = self.proj_audio_and_text(
122
+ torch.cat([audio_features, expanded_text], dim=2)
123
+ )
124
+ finetune_transformer_output = self.finetune_transformer(
125
+ self.finetune_data_proj(audio_and_text), padding_mask=feature_padding_mask
126
+ )
127
+ result = self.head(finetune_transformer_output.last_hidden_state)
128
+ if feature_padding_mask is not None:
129
+ feature_padding_mask = feature_padding_mask.unsqueeze(-1)
130
+ pooled = torch.masked.mean(result, mask=feature_padding_mask, dim=1)
131
+ de_normalized = pooled * self.std + self.mean
132
+ return SAMAudioJudgeOutput(*de_normalized.chunk(4, dim=1))
133
+
134
+
135
+ __all__ = ["SAMAudioJudgeModel", "SAMAudioJudgeOutput"]
sam_audio/model/model.py ADDED
@@ -0,0 +1,362 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ import re
5
+ from dataclasses import dataclass
6
+ from typing import Any, Dict, Optional
7
+
8
+ import torch
9
+ from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
10
+ from torchdiffeq import odeint
11
+
12
+ from sam_audio.model.align import AlignModalities
13
+ from sam_audio.model.base import BaseModel
14
+ from sam_audio.model.codec import DACVAE
15
+ from sam_audio.model.config import SAMAudioConfig
16
+ from sam_audio.model.text_encoder import T5TextEncoder
17
+ from sam_audio.model.transformer import DiT
18
+ from sam_audio.model.vision_encoder import PerceptionEncoder
19
+ from sam_audio.processor import Batch
20
+ from sam_audio.ranking import create_ranker
21
+
22
+ DFLT_ODE_OPT = {"method": "midpoint", "options": {"step_size": 2 / 32}}
23
+
24
+
25
+ class SinusoidalEmbedding(torch.nn.Module):
26
+ def __init__(self, dim, theta=10000):
27
+ super().__init__()
28
+ assert (dim % 2) == 0
29
+ half_dim = dim // 2
30
+ inv_freq = torch.exp(
31
+ -math.log(theta) * torch.arange(half_dim).float() / half_dim
32
+ )
33
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
34
+
35
+ def forward(self, x, pos=None):
36
+ if pos is None:
37
+ seq_len, device = x.shape[1], x.device
38
+ pos = torch.arange(seq_len, device=device)
39
+
40
+ emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
41
+ emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
42
+ return emb
43
+
44
+
45
+ class EmbedAnchors(torch.nn.Module):
46
+ def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
47
+ super().__init__()
48
+ self.embed = torch.nn.Embedding(
49
+ num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
50
+ )
51
+ self.gate = torch.nn.Parameter(torch.tensor([0.0]))
52
+ self.proj = torch.nn.Linear(embedding_dim, out_dim, bias=False)
53
+
54
+ def forward(
55
+ self,
56
+ x: torch.Tensor,
57
+ anchor_ids: Optional[torch.Tensor] = None,
58
+ anchor_alignment: Optional[torch.Tensor] = None,
59
+ ):
60
+ if anchor_ids is None:
61
+ return x
62
+
63
+ embs = self.embed(anchor_ids.gather(1, anchor_alignment))
64
+ proj = self.proj(embs)
65
+ return x + self.gate.tanh() * proj
66
+
67
+
68
+ @dataclass
69
+ class SeparationResult:
70
+ target: torch.Tensor
71
+ residual: torch.Tensor
72
+ noise: torch.Tensor
73
+
74
+
75
+ class SAMAudio(BaseModel):
76
+ config_cls = SAMAudioConfig
77
+ revision = None
78
+
79
+ def __init__(self, cfg: SAMAudioConfig):
80
+ super().__init__()
81
+ self.audio_codec = DACVAE(cfg.audio_codec)
82
+ self.text_encoder = T5TextEncoder(cfg.text_encoder)
83
+ self.vision_encoder = PerceptionEncoder(cfg.vision_encoder)
84
+ self.transformer = DiT(cfg.transformer)
85
+ self.proj = torch.nn.Linear(cfg.in_channels, cfg.transformer.dim)
86
+ self.align_masked_video = AlignModalities(
87
+ cfg.vision_encoder.dim, cfg.transformer.dim
88
+ )
89
+ self.embed_anchors = EmbedAnchors(
90
+ cfg.num_anchors, cfg.anchor_embedding_dim, cfg.transformer.dim
91
+ )
92
+ self.memory_proj = torch.nn.Linear(cfg.text_encoder.dim, cfg.transformer.dim)
93
+ self.timestep_emb = SinusoidalEmbedding(cfg.transformer.dim)
94
+ self.visual_ranker = create_ranker(cfg.visual_ranker)
95
+ self.text_ranker = create_ranker(cfg.text_ranker)
96
+ if cfg.span_predictor is not None:
97
+ self.span_predictor = PEAudioFrame.from_config(
98
+ cfg.span_predictor, pretrained=True
99
+ )
100
+ self.span_predictor_transform = PEAudioFrameTransform.from_config(
101
+ cfg.span_predictor
102
+ )
103
+
104
+ @property
105
+ def sample_rate(self):
106
+ return self.audio_codec.sample_rate
107
+
108
+ def align_inputs(
109
+ self,
110
+ noisy_audio,
111
+ audio_features: torch.Tensor,
112
+ masked_video_features: Optional[torch.Tensor] = None,
113
+ anchor_ids: Optional[torch.Tensor] = None,
114
+ anchor_alignment: Optional[torch.Tensor] = None,
115
+ ):
116
+ x = torch.cat(
117
+ [
118
+ noisy_audio,
119
+ torch.zeros_like(audio_features),
120
+ audio_features,
121
+ ],
122
+ dim=2,
123
+ )
124
+
125
+ projected = self.proj(x)
126
+ aligned = self.align_masked_video(projected, masked_video_features)
127
+ aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)
128
+ return aligned
129
+
130
+ def forward(
131
+ self,
132
+ noisy_audio: torch.Tensor,
133
+ audio_features: torch.Tensor,
134
+ text_features: torch.Tensor,
135
+ time: torch.Tensor,
136
+ masked_video_features: Optional[torch.Tensor] = None,
137
+ text_mask: Optional[torch.Tensor] = None,
138
+ anchor_ids: Optional[torch.Tensor] = None,
139
+ anchor_alignment: Optional[torch.Tensor] = None,
140
+ audio_pad_mask: Optional[torch.Tensor] = None,
141
+ ):
142
+ """
143
+ Forward pass for the model. Represents one function evaluation of the ODE.
144
+ In the below descriptions, B is batch size, T is sequence length, C is channel size.
145
+ Note that the size of C and T may vary across arguments (ex. text_features vs. audio_features),
146
+ it is used only to designate a Channel or time/sequence-length dimension respectively.
147
+
148
+ Args:
149
+ noisy_audio (torch.Tensor): Noisy audio input tensor (being denoised).
150
+ audio_features (torch.Tensor): Clean audio features [B x T x C].
151
+ text_features (torch.Tensor): Encoded text features tensor [B x T x C].
152
+ time (torch.Tensor): Timestep tensor for positional encoding [B].
153
+ masked_video_features (Optional[torch.Tensor], optional): Masked video features tensor. [B x C x T].
154
+ text_mask (Optional[torch.Tensor], optional): Padding mask for text features. [B x T].
155
+ anchor_ids (Optional[torch.Tensor], optional): Anchor IDs tensor. Defaults to None [B x T].
156
+ anchor_alignment (Optional[torch.Tensor], optional): Anchor alignment tensor. B x T.
157
+ audio_pad_mask (Optional[torch.Tensor], optional): Padding mask for audio input. [B x T].
158
+
159
+ Returns:
160
+ torch.Tensor
161
+ """
162
+ aligned_inputs = self.align_inputs(
163
+ noisy_audio,
164
+ audio_features,
165
+ masked_video_features=masked_video_features,
166
+ anchor_ids=anchor_ids,
167
+ anchor_alignment=anchor_alignment,
168
+ )
169
+
170
+ memory = timestep_emb = self.timestep_emb(time, pos=time).unsqueeze(1)
171
+ if text_features is not None:
172
+ memory = self.memory_proj(text_features) + timestep_emb
173
+
174
+ return self.transformer(
175
+ aligned_inputs,
176
+ time,
177
+ padding_mask=audio_pad_mask,
178
+ memory=memory,
179
+ memory_padding_mask=text_mask,
180
+ )
181
+
182
+ def _get_audio_features(self, audios: torch.Tensor):
183
+ audio_features = self.audio_codec(audios).transpose(1, 2)
184
+ return torch.cat([audio_features, audio_features], dim=2)
185
+
186
+ def _get_video_features(self, video, audio_features):
187
+ B, T, _ = audio_features.shape
188
+ if video is None:
189
+ return audio_features.new_zeros(B, self.vision_encoder.dim, T)
190
+ else:
191
+ return self.vision_encoder(video).transpose(1, 2)
192
+
193
+ def _repeat_for_reranking(self, tensor, candidates):
194
+ if candidates > 1:
195
+ B = tensor.size(0)
196
+ rest = tensor.shape[1:]
197
+ return (
198
+ tensor.unsqueeze(1)
199
+ .expand(B, candidates, *rest)
200
+ .reshape(B * candidates, *rest)
201
+ )
202
+ else:
203
+ return tensor
204
+
205
+ def _unrepeat_from_reranking(self, tensor, candidates):
206
+ return tensor[::candidates]
207
+
208
+ def _get_forward_args(self, batch: Batch, candidates: int = 1):
209
+ audio_features = self._get_audio_features(batch.audios)
210
+ text_features, text_mask = self.text_encoder(batch.descriptions)
211
+ masked_video_features = self._get_video_features(
212
+ batch.masked_video, audio_features
213
+ )
214
+
215
+ return {
216
+ "audio_features": self._repeat_for_reranking(audio_features, candidates),
217
+ "text_features": self._repeat_for_reranking(text_features, candidates),
218
+ "text_mask": self._repeat_for_reranking(text_mask, candidates),
219
+ "masked_video_features": self._repeat_for_reranking(
220
+ masked_video_features, candidates
221
+ ),
222
+ "anchor_ids": self._repeat_for_reranking(batch.anchor_ids, candidates),
223
+ "anchor_alignment": self._repeat_for_reranking(
224
+ batch.anchor_alignment, candidates
225
+ ),
226
+ "audio_pad_mask": self._repeat_for_reranking(
227
+ batch.audio_pad_mask, candidates
228
+ ),
229
+ }
230
+
231
+ def predict_spans(
232
+ self, batch: Batch, audio_features: torch.Tensor, audio_pad_mask: torch.Tensor
233
+ ) -> Batch:
234
+ input = self.span_predictor_transform(text=batch.descriptions).to(
235
+ audio_features.device
236
+ )
237
+ output = self.span_predictor(
238
+ input_features=audio_features[:, :, :128],
239
+ padding_mask=audio_pad_mask,
240
+ return_spans=True,
241
+ **input,
242
+ )
243
+ anchors = [[["+"] + anchor for anchor in anchors] for anchors in output.spans]
244
+ batch.process_anchors(anchors)
245
+ return batch
246
+
247
+ @torch.inference_mode()
248
+ def separate(
249
+ self,
250
+ batch: Batch,
251
+ noise: Optional[torch.Tensor] = None,
252
+ ode_opt: Dict[str, Any] = DFLT_ODE_OPT,
253
+ reranking_candidates: int = 1,
254
+ predict_spans: bool = False,
255
+ ) -> SeparationResult:
256
+ # Encode audio
257
+ forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
258
+
259
+ if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
260
+ batch = self.predict_spans(
261
+ batch=batch,
262
+ audio_features=self._unrepeat_from_reranking(
263
+ forward_args["audio_features"], reranking_candidates
264
+ ),
265
+ audio_pad_mask=self._unrepeat_from_reranking(
266
+ forward_args["audio_pad_mask"], reranking_candidates
267
+ ),
268
+ )
269
+
270
+ audio_features = forward_args["audio_features"]
271
+ B, T, C = audio_features.shape
272
+ C = C // 2 # we stack audio_features, so the actual channels is half
273
+
274
+ if noise is None:
275
+ noise = torch.randn_like(audio_features)
276
+
277
+ def vector_field(t, noisy_audio):
278
+ res = self.forward(
279
+ noisy_audio=noisy_audio,
280
+ time=t.expand(noisy_audio.size(0)),
281
+ **forward_args,
282
+ )
283
+ return res
284
+
285
+ states = odeint(
286
+ vector_field,
287
+ noise,
288
+ torch.tensor([0.0, 1.0], device=noise.device),
289
+ **ode_opt,
290
+ )
291
+ generated_features = states[-1].transpose(1, 2)
292
+ # generated_features has shape [B, 2C, T]. Reshape to stack along the batch dimension
293
+ wavs = self.audio_codec.decode(generated_features.reshape(2 * B, C, T)).view(
294
+ B, 2, -1
295
+ )
296
+
297
+ bsz = wavs.size(0) // reranking_candidates
298
+ sizes = self.audio_codec.feature_idx_to_wav_idx(batch.sizes)
299
+ target_wavs = self.unbatch(
300
+ wavs[:, 0].view(bsz, reranking_candidates, -1), sizes
301
+ )
302
+ residual_wavs = self.unbatch(
303
+ wavs[:, 1].view(bsz, reranking_candidates, -1), sizes
304
+ )
305
+
306
+ if (
307
+ reranking_candidates > 1
308
+ and batch.masked_video is not None
309
+ and self.visual_ranker is not None
310
+ ):
311
+ scores = self.visual_ranker(
312
+ extracted_audio=target_wavs,
313
+ videos=batch.masked_video,
314
+ sample_rate=self.audio_codec.sample_rate,
315
+ )
316
+ idxs = scores.argmax(dim=1)
317
+ elif reranking_candidates > 1 and self.text_ranker is not None:
318
+ input_audio = [
319
+ audio[:, :size].expand(reranking_candidates, -1)
320
+ for audio, size in zip(batch.audios, sizes, strict=False)
321
+ ]
322
+ scores = self.text_ranker(
323
+ extracted_audio=target_wavs,
324
+ input_audio=input_audio,
325
+ descriptions=batch.descriptions,
326
+ sample_rate=self.audio_codec.sample_rate,
327
+ )
328
+ idxs = scores.argmax(dim=1)
329
+ else:
330
+ idxs = torch.zeros(bsz, dtype=torch.long, device=noise.device)
331
+
332
+ return SeparationResult(
333
+ target=[wav[idx] for wav, idx in zip(target_wavs, idxs, strict=False)],
334
+ residual=[
335
+ wavs[idx] for wavs, idx in zip(residual_wavs, idxs, strict=False)
336
+ ],
337
+ noise=noise,
338
+ )
339
+
340
+ def unbatch(self, wavs: torch.Tensor, sizes: torch.Tensor, time_dim: int = -1):
341
+ result = []
342
+ for row, size in zip(wavs, sizes, strict=False):
343
+ result.append(row.narrow(dim=time_dim, start=0, length=size))
344
+ return result
345
+
346
+ def load_state_dict(self, state_dict, strict=True):
347
+ if strict:
348
+ missing_keys, unexpected_keys = super().load_state_dict(
349
+ state_dict, strict=False
350
+ )
351
+ # These submodules are loaded directly from HF and are not stored in the checkpoint
352
+ skip_regex = re.compile(
353
+ "(^text_encoder|^visual_ranker|^text_ranker|^span_predictor)"
354
+ )
355
+ missing_keys = [x for x in missing_keys if not re.search(skip_regex, x)]
356
+ if len(missing_keys) > 0 or len(unexpected_keys) > 0:
357
+ raise RuntimeError(
358
+ f"Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}"
359
+ )
360
+
361
+
362
+ __all__ = ["SAMAudio"]
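A minimal standalone sketch (plain PyTorch, no SAM-Audio imports) of the candidate-reranking batching trick used by `_repeat_for_reranking` / `_unrepeat_from_reranking` above; the tensor sizes and helper names are illustrative only.

```python
import torch

def repeat_for_reranking(t: torch.Tensor, candidates: int) -> torch.Tensor:
    # [B, ...] -> [B * candidates, ...], with the copies of each item kept adjacent,
    # so a single ODE solve produces `candidates` samples per input.
    if candidates <= 1:
        return t
    B, rest = t.size(0), t.shape[1:]
    return t.unsqueeze(1).expand(B, candidates, *rest).reshape(B * candidates, *rest)

def unrepeat_from_reranking(t: torch.Tensor, candidates: int) -> torch.Tensor:
    # Keep one row per original item (every `candidates`-th row).
    return t[::candidates]

x = torch.arange(6.0).reshape(3, 2)          # B=3, C=2
rep = repeat_for_reranking(x, candidates=4)  # -> [12, 2]
assert torch.equal(unrepeat_from_reranking(rep, 4), x)
```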
sam_audio/model/patcher.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ import math
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+
11
+ def pad1d(
12
+ x: torch.Tensor,
13
+ paddings: Tuple[int, int],
14
+ mode: str = "constant",
15
+ value: float = 0.0,
16
+ ):
17
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small inputs.
+ If the input is smaller than the padding, we insert extra 0 padding to the right before the reflection happens.
20
+ """
21
+ length = x.shape[-1]
22
+ padding_left, padding_right = paddings
23
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
24
+ if mode == "reflect":
25
+ max_pad = max(padding_left, padding_right)
26
+ extra_pad = 0
27
+ if length <= max_pad:
28
+ extra_pad = max_pad - length + 1
29
+ x = F.pad(x, (0, extra_pad))
30
+ padded = F.pad(x, paddings, mode, value)
31
+ end = padded.shape[-1] - extra_pad
32
+ return padded[..., :end]
33
+ else:
34
+ return F.pad(x, paddings, mode, value)
35
+
36
+
37
+ def get_extra_padding_for_conv1d(
38
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
39
+ ) -> int:
40
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py
41
+ """See `pad_for_conv1d`."""
42
+ length = x.shape[-1]
43
+ n_frames = (length - kernel_size + padding_total) / stride + 1
44
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
45
+ return ideal_length - length
46
+
47
+
48
+ class Conv1d(torch.nn.Conv1d):
49
+ def __init__(self, *args, **kwargs):
50
+ super().__init__(*args, **kwargs)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ kernel_size = self.kernel_size[0]
54
+ stride = self.stride[0]
55
+ dilation = self.dilation[0]
56
+ kernel_size = (
57
+ kernel_size - 1
58
+ ) * dilation + 1 # effective kernel size with dilations
59
+ padding_total = kernel_size - stride
60
+ extra_padding = get_extra_padding_for_conv1d(
61
+ x, kernel_size, stride, padding_total
62
+ )
63
+ # Asymmetric padding required for odd strides
64
+ padding_right = padding_total // 2
65
+ padding_left = padding_total - padding_right
66
+ x = pad1d(x, (padding_left, padding_right + extra_padding))
67
+ return super().forward(x)
68
+
69
+
70
+ class ConvBlock1d(torch.nn.Module):
71
+ def __init__(
72
+ self,
73
+ in_channels: int,
74
+ out_channels: int,
75
+ *,
76
+ kernel_size: int = 3,
77
+ stride: int = 1,
78
+ dilation: int = 1,
79
+ num_groups: int = 8,
80
+ ) -> None:
81
+ super().__init__()
82
+
83
+ self.groupnorm = torch.nn.GroupNorm(
84
+ num_groups=num_groups, num_channels=in_channels
85
+ )
86
+ self.activation = torch.nn.SiLU()
87
+ self.project = Conv1d(
88
+ in_channels=in_channels,
89
+ out_channels=out_channels,
90
+ kernel_size=kernel_size,
91
+ stride=stride,
92
+ dilation=dilation,
93
+ )
94
+
95
+ def forward(
96
+ self,
97
+ x: torch.Tensor,
98
+ ) -> torch.Tensor:
99
+ x = self.groupnorm(x)
100
+ x = self.activation(x)
101
+ return self.project(x)
102
+
103
+
104
+ class ResnetBlock1d(torch.nn.Module):
105
+ def __init__(
106
+ self,
107
+ in_channels: int,
108
+ out_channels: int,
109
+ *,
110
+ kernel_size: int = 3,
111
+ stride: int = 1,
112
+ dilation: int = 1,
113
+ num_groups: int = 8,
114
+ ) -> None:
115
+ super().__init__()
116
+
117
+ self.block1 = ConvBlock1d(
118
+ in_channels=in_channels,
119
+ out_channels=out_channels,
120
+ kernel_size=kernel_size,
121
+ stride=stride,
122
+ dilation=dilation,
123
+ num_groups=num_groups,
124
+ )
125
+
126
+ self.block2 = ConvBlock1d(
127
+ in_channels=out_channels,
128
+ out_channels=out_channels,
129
+ num_groups=num_groups,
130
+ )
131
+
132
+ self.to_out = (
133
+ Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
134
+ if in_channels != out_channels
135
+ else torch.nn.Identity()
136
+ )
137
+
138
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
139
+ h = self.block1(x)
140
+ h = self.block2(h)
141
+ return h + self.to_out(x)
142
+
143
+
144
+ class Patcher(torch.nn.Module):
145
+ def __init__(
146
+ self,
147
+ in_channels: int,
148
+ out_channels: int,
149
+ patch_size: int,
150
+ ):
151
+ super().__init__()
152
+ assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
153
+ assert out_channels % patch_size == 0, assert_message
154
+ self.patch_size = patch_size
155
+ self.block = ResnetBlock1d(
156
+ in_channels=in_channels,
157
+ out_channels=out_channels // patch_size,
158
+ num_groups=1,
159
+ )
160
+
161
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
162
+ x = self.block(x)
163
+ x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
164
+ return x
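As a quick illustration of the `rearrange` performed in `Patcher.forward`, the sketch below shows how time patches are folded into the channel dimension (with the default `patch_size=1` used by `DiT` later in this commit the operation is effectively a no-op reshape); the shapes are made up for the example.

```python
import torch
from einops import rearrange

x = torch.randn(2, 8, 12)                        # [B, C, L * p] with p = 4
y = rearrange(x, "b c (l p) -> b (c p) l", p=4)  # fold each patch of p steps into channels
print(y.shape)                                   # torch.Size([2, 32, 3])
```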
sam_audio/model/rope.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ import math
4
+ from typing import Tuple
5
+
6
+ import torch
7
+
8
+
9
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
10
+ """
11
+ Reshape frequency tensor for broadcasting it with another tensor.
12
+
13
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
14
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
15
+
16
+ Args:
17
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
18
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
19
+ seq_dim (int): Sequence dimension index.
20
+
21
+ Returns:
22
+ torch.Tensor: Reshaped frequency tensor.
23
+ """
24
+ ndim = x.ndim
25
+ assert 0 <= seq_dim < ndim
26
+ assert freqs_cis.shape == (
27
+ x.shape[seq_dim],
28
+ x.shape[-3],
29
+ 2,
30
+ 2,
31
+ ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
32
+ shape = [
33
+ d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
34
+ ] + [2, 2]
35
+ return freqs_cis.view(*shape)
36
+
37
+
38
+ def apply_rotary_emb(
39
+ xq: torch.Tensor,
40
+ xk: torch.Tensor,
41
+ seq_dim: int,
42
+ freqs_cis: torch.Tensor,
43
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
44
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
45
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
46
+ freqs_cis = reshape_for_broadcast(
47
+ freqs_cis, xq_, seq_dim
48
+ ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
49
+ xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
50
+ xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
51
+ return xq_out.type_as(xq), xk_out.type_as(xk)
52
+
53
+
54
+ class RotaryEmbedding(torch.nn.Module):
55
+ """
56
+ RotaryEmbedding Module
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ theta: float,
62
+ head_dim: int,
63
+ max_seqlen: int = 1024,
64
+ scale_factor: int = 1,
65
+ low_freq_factor: int = 1,
66
+ high_freq_factor: int = 32,
67
+ old_context_len: int = 8192,
68
+ ):
69
+ super().__init__()
70
+
71
+ self.theta = theta
72
+ self.head_dim = head_dim
73
+ self.max_seqlen = max_seqlen
74
+ self.scale_factor = scale_factor
75
+ self.low_freq_factor = low_freq_factor
76
+ self.high_freq_factor = high_freq_factor
77
+ self.old_context_len = old_context_len
78
+ if scale_factor != 1:
79
+ self.low_freq_wavelen = old_context_len / low_freq_factor
80
+ self.high_freq_wavelen = old_context_len / high_freq_factor
81
+ assert self.low_freq_wavelen >= self.high_freq_wavelen
82
+
83
+ def reset_parameters(self):
84
+ freqs_cis = self.precompute_freqs_cis(
85
+ dim=self.head_dim, end=self.max_seqlen, theta=self.theta
86
+ )
87
+ S, D, _, _ = freqs_cis.shape
88
+ # S D 2 2 -> 1 S 1 D 2 2
89
+ freqs_cis = freqs_cis.view(1, S, 1, D, 2, 2)
90
+ self.register_buffer(
91
+ "freqs_cis",
92
+ freqs_cis,
93
+ persistent=False,
94
+ )
95
+
96
+ def apply_scaling(self, freqs):
97
+ if self.scale_factor == 1:
98
+ return freqs
99
+ new_freqs = []
100
+ for freq in freqs:
101
+ wavelen = 2 * math.pi / freq
102
+ if wavelen < self.high_freq_wavelen:
103
+ new_freqs.append(freq)
104
+ elif wavelen > self.low_freq_wavelen:
105
+ new_freqs.append(freq / self.scale_factor)
106
+ else:
107
+ assert self.low_freq_wavelen != self.high_freq_wavelen
108
+ smooth = (self.old_context_len / wavelen - self.low_freq_factor) / (
109
+ self.high_freq_factor - self.low_freq_factor
110
+ )
111
+ new_freqs.append(
112
+ (1 - smooth) * freq / self.scale_factor + smooth * freq
113
+ )
114
+ return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
115
+
116
+ def precompute_freqs_cis(
117
+ self,
118
+ dim: int,
119
+ end: int,
120
+ theta: float = 10000.0,
121
+ ):
122
+ """
123
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
124
+
125
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
126
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
127
+ The returned tensor contains complex values in complex64 data type.
128
+
129
+ Args:
130
+ dim (int): Dimension of the frequency tensor.
131
+ end (int): End index for precomputing frequencies.
132
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
133
+
134
+ Returns:
135
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
136
+ """
137
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
138
+ freqs = self.apply_scaling(freqs)
139
+
140
+ t = torch.arange(end, device=freqs.device)
141
+ freqs = torch.outer(t, freqs).float()
142
+
143
+ cos, sin = freqs.cos(), freqs.sin()
144
+
145
+ return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
146
+
147
+ def forward(self, x: torch.Tensor, bhle: bool = False, **kwargs):
148
+ if bhle:
149
+ x = x.transpose(1, 2) # (B H L E) -> (B L H E)
150
+ seqlen = x.size(1)
151
+ x_ = x.reshape(*x.shape[:-1], -1, 1, 2) # B L H E -> B L H E/2 1 2
152
+ x_out = (x_ * self.freqs_cis[:, :seqlen]).sum(5).flatten(3)
153
+ if bhle:
154
+ x_out = x_out.transpose(1, 2)
155
+ return x_out.type_as(x)
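The sketch below is a standalone rewrite of `precompute_freqs_cis` without the context-length scaling branch: each (position, channel-pair) entry is a 2x2 rotation matrix `[[cos, -sin], [sin, cos]]`, and multiplying a pair of channels by it (as in `apply_rotary_emb` / `forward` above) rotates that pair by a position-dependent angle. Shapes are illustrative.

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    angles = torch.outer(torch.arange(end).float(), freqs)    # [end, dim/2]
    cos, sin = angles.cos(), angles.sin()
    return torch.stack((cos, -sin, sin, cos), dim=-1).view(end, dim // 2, 2, 2)

fc = precompute_freqs_cis(dim=8, end=16)        # [16, 4, 2, 2] rotation matrices
x = torch.randn(16, 4, 1, 2)                    # positions x pairs x 1 x (2-d pair)
rotated = (x * fc).sum(-1)                      # apply each 2x2 rotation -> [16, 4, 2]
print(fc.shape, rotated.shape)
```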
sam_audio/model/text_encoder.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import transformers
7
+
8
+ from sam_audio.model.config import T5EncoderConfig
9
+
10
+
11
+ class T5TextEncoder(torch.nn.Module):
12
+ def __init__(self, cfg: T5EncoderConfig):
13
+ super().__init__()
14
+ self.model = transformers.T5EncoderModel.from_pretrained(cfg.name)
15
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(cfg.name)
16
+ self.pad_mode = cfg.pad_mode
17
+ self.max_length = cfg.max_length
18
+
19
+ def forward(self, texts: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
20
+ device = next(self.model.parameters()).device
21
+ encoded = self.tokenizer(
22
+ texts,
23
+ truncation=True,
24
+ max_length=self.max_length,
25
+ padding=self.pad_mode,
26
+ return_tensors="pt",
27
+ )
28
+
29
+ input_ids = encoded["input_ids"].to(device)
30
+ attention_mask = encoded["attention_mask"].to(device)
31
+ res = self.model(
32
+ input_ids=input_ids,
33
+ attention_mask=attention_mask,
34
+ output_hidden_states=True,
35
+ )["last_hidden_state"]
36
+
37
+ return res, attention_mask.bool()
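A hedged sketch of what `T5TextEncoder.forward` boils down to, using the Hugging Face `transformers` API directly; the checkpoint name below is an assumption for illustration, the real one comes from `T5EncoderConfig.name`.

```python
import torch
import transformers

name = "google/flan-t5-base"  # hypothetical; the model uses cfg.name
tok = transformers.AutoTokenizer.from_pretrained(name)
enc = transformers.T5EncoderModel.from_pretrained(name).eval()

texts = ["a dog barking in the rain", "solo violin"]
batch = tok(texts, padding="longest", truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    hidden = enc(**batch).last_hidden_state   # [B, T, C] text features
mask = batch["attention_mask"].bool()         # [B, T], True on real tokens
print(hidden.shape, mask.shape)
```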
sam_audio/model/transformer.py ADDED
@@ -0,0 +1,524 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ import math
4
+ from functools import partial
5
+ from typing import List, Optional, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+
12
+ from .config import TransformerConfig
13
+ from .patcher import Patcher
14
+ from .rope import RotaryEmbedding
15
+
16
+
17
+ def gate(x, gate):
18
+ return x * gate
19
+
20
+
21
+ def modulate(x, shift, scale):
22
+ return x * (1 + scale) + shift
23
+
24
+
25
+ def get_nonlinearity(kind: str):
26
+ return {
27
+ "relu": F.relu,
28
+ "gelu": F.gelu,
29
+ "swiglu": None,
30
+ "approx_gelu": partial(F.gelu, approximate="tanh"),
31
+ "srelu": lambda x: F.relu(x) ** 2,
32
+ "silu": F.silu,
33
+ }[kind]
34
+
35
+
36
+ class RMSNorm(torch.nn.Module):
37
+ def __init__(self, dim: int, eps: float = 1e-5):
38
+ super().__init__()
39
+ self.eps = eps
40
+ self.weight = torch.nn.Parameter(torch.ones(dim))
41
+
42
+ def _norm(self, x):
43
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44
+
45
+ def forward(self, x):
46
+ output = self._norm(x.float())
47
+ return (output * self.weight).type_as(x)
48
+
49
+
50
+ class ProjectionLayer(torch.nn.Module):
51
+ def __init__(
52
+ self,
53
+ in_dim: int,
54
+ out_dim: int,
55
+ non_linearity: str,
56
+ dropout: float,
57
+ fc_bias: bool = False,
58
+ ):
59
+ super().__init__()
60
+
61
+ self.swiglu = non_linearity == "swiglu"
62
+ self.dropout = dropout
63
+ self.w1 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
64
+
65
+ self.w2 = torch.nn.Linear(out_dim, out_dim, bias=fc_bias)
66
+ if self.swiglu:
67
+ self.w3 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
68
+
69
+ # non-linearity
70
+ self.non_linearity = get_nonlinearity(non_linearity)
71
+
72
+ def forward(self, x):
73
+ hidden1 = self.w1(x)
74
+ if self.swiglu:
75
+ hidden3 = self.w3(x)
76
+ hidden = F.silu(hidden1) * hidden3
77
+ else:
78
+ hidden = self.non_linearity(hidden1)
79
+ hidden = F.dropout(hidden, p=self.dropout, training=self.training)
80
+ return self.w2(hidden)
81
+
82
+
83
+ class Attention(nn.Module):
84
+ def __init__(
85
+ self,
86
+ dim: int,
87
+ head_dim: int,
88
+ n_heads: int,
89
+ n_kv_heads: int,
90
+ norm_eps: float = 1e-5,
91
+ use_qk_norm: bool = False,
92
+ fc_bias: bool = False,
93
+ ):
94
+ super().__init__()
95
+ assert n_heads % n_kv_heads == 0
96
+
97
+ self.head_dim = head_dim
98
+ self.n_heads = n_heads
99
+ self.n_kv_heads = n_kv_heads
100
+ self.use_qk_norm = use_qk_norm
101
+
102
+ self.wq = torch.nn.Linear(dim, n_heads * head_dim, bias=fc_bias)
103
+ self.wk, self.wv = [
104
+ torch.nn.Linear(
105
+ dim,
106
+ n_kv_heads * head_dim,
107
+ bias=fc_bias,
108
+ )
109
+ for _ in range(2)
110
+ ]
111
+ self.wo = torch.nn.Linear(
112
+ n_heads * head_dim,
113
+ dim,
114
+ bias=fc_bias,
115
+ )
116
+
117
+ if self.use_qk_norm is True:
118
+ self.q_norm = RMSNorm(head_dim, eps=norm_eps)
119
+ self.k_norm = RMSNorm(head_dim, eps=norm_eps)
120
+
121
+ def reshape_heads(self, x: torch.Tensor, heads: int) -> torch.Tensor:
122
+ B, T, C = x.shape
123
+ # B x T x C -> B x T x C/H x H
124
+ x = x.reshape(B, T, C // heads, heads)
125
+ # B x T x C/H x H -> B x H x T x C/H
126
+ return x.permute(0, 3, 1, 2)
127
+
128
+ def forward(
129
+ self,
130
+ x: torch.Tensor,
131
+ cross_x: Optional[torch.Tensor] = None,
132
+ key_padding_mask: Optional[torch.Tensor] = None,
133
+ rope: Optional[RotaryEmbedding] = None,
134
+ ):
135
+ # x: B, T, E
136
+ xq = self.wq(x)
137
+ if cross_x is not None:
138
+ xk, xv = self.wk(cross_x), self.wv(cross_x)
139
+ else:
140
+ xk, xv = self.wk(x), self.wv(x)
141
+
142
+ xk = self.reshape_heads(xk, self.n_kv_heads)
143
+ xv = self.reshape_heads(xv, self.n_kv_heads)
144
+ xq = self.reshape_heads(xq, self.n_heads)
145
+ if self.use_qk_norm:
146
+ xq = self.q_norm(xq)
147
+ xk = self.k_norm(xk)
148
+
149
+ if rope is not None:
150
+ xq = rope(xq, bhle=True)
151
+ xk = rope(xk, bhle=True)
152
+
153
+ attn_mask = None
154
+
155
+ if key_padding_mask is not None:
156
+ attn_mask = key_padding_mask[:, None, None, :]
157
+
158
+ output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask)
159
+
160
+ output = rearrange(output, "b h n d -> b n (h d)")
161
+ return self.wo(output)
162
+
163
+
164
+ class FeedForward(torch.nn.Module):
165
+ def __init__(
166
+ self,
167
+ dim: int,
168
+ hidden_dim: int,
169
+ ffn_dim_multiplier: float,
170
+ multiple_of: int,
171
+ dropout: float,
172
+ non_linearity: str = "swiglu",
173
+ fc_bias: bool = False,
174
+ ):
175
+ super().__init__()
176
+ self.dropout = dropout
177
+ self.swiglu = non_linearity == "swiglu"
178
+ # swiglu hidden dim factor multiplier (same #params as relu / gelu)
179
+ if self.swiglu:
180
+ hidden_dim = int(2 * hidden_dim / 3)
181
+
182
+ # custom dim factor multiplier
183
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
184
+ # round hidden dimension to `multiple_of`
185
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
186
+ # layers
187
+ self.w1 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
188
+ self.w2 = torch.nn.Linear(hidden_dim, dim, bias=fc_bias)
189
+ if self.swiglu:
190
+ self.w3 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
191
+
192
+ # non-linearity
193
+ self.non_linearity = get_nonlinearity(non_linearity)
194
+
195
+ def forward(
196
+ self,
197
+ x,
198
+ ):
199
+ hidden1 = self.w1(x)
200
+ if self.swiglu:
201
+ hidden3 = self.w3(x)
202
+ hidden = F.silu(hidden1) * hidden3
203
+ else:
204
+ hidden = self.non_linearity(hidden1)
205
+ hidden = F.dropout(hidden, p=self.dropout, training=self.training)
206
+ return self.w2(hidden)
207
+
208
+
209
+ class TimestepEmbedder(torch.nn.Module):
210
+ def __init__(
211
+ self,
212
+ dim: int,
213
+ frequency_embedding_dim: int,
214
+ non_linearity: str,
215
+ dropout: float,
216
+ fc_bias: bool,
217
+ max_period: int = 10000,
218
+ ):
219
+ super().__init__()
220
+ self.frequency_embedding_size = frequency_embedding_dim
221
+ self.projection = ProjectionLayer(
222
+ in_dim=frequency_embedding_dim,
223
+ out_dim=dim,
224
+ non_linearity=non_linearity,
225
+ dropout=dropout,
226
+ fc_bias=fc_bias,
227
+ )
228
+ half = frequency_embedding_dim // 2
229
+ freqs = torch.exp(
230
+ -math.log(max_period)
231
+ * torch.arange(start=0, end=half, dtype=torch.float32)
232
+ / half
233
+ )
234
+ self.register_buffer("freqs", freqs, persistent=False)
235
+
236
+ def timestep_embedding(self, t, dim):
237
+ """
238
+ Create sinusoidal timestep embeddings.
239
+ :param t: a 1-D Tensor of N indices, one per batch element.
240
+ These may be fractional.
241
+ :param dim: the dimension of the output.
242
+ :param max_period: controls the minimum frequency of the embeddings.
243
+ :return: an (N, D) Tensor of positional embeddings.
244
+ """
245
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
246
+ self.freqs = self.freqs.to(device=t.device)
247
+ args = t[:, None].float() * self.freqs[None]
248
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
249
+ if dim % 2:
250
+ embedding = torch.cat(
251
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
252
+ )
253
+ return embedding.to(t)
254
+
255
+ def forward(self, t):
256
+ x = self.timestep_embedding(t, self.frequency_embedding_size)
257
+ return self.projection(x)
258
+
259
+
260
+ class ContextEmbedder(torch.nn.Module):
261
+ def __init__(
262
+ self,
263
+ in_dim: int,
264
+ out_dim: int,
265
+ non_linearity: str,
266
+ dropout: float,
267
+ fc_bias: bool,
268
+ norm_eps: float = 1e-5,
269
+ context_norm: bool = False,
270
+ ):
271
+ super().__init__()
272
+ self.context_norm = context_norm
273
+ if context_norm:
274
+ self.norm = RMSNorm(in_dim, norm_eps)
275
+
276
+ self.projection = ProjectionLayer(
277
+ in_dim=in_dim,
278
+ out_dim=out_dim,
279
+ non_linearity=non_linearity,
280
+ dropout=dropout,
281
+ fc_bias=fc_bias,
282
+ )
283
+
284
+ def forward(self, x):
285
+ if self.context_norm:
286
+ x = self.norm(x)
287
+ h = self.projection(x)
288
+ return h
289
+
290
+
291
+ class DiTBlock(torch.nn.Module):
292
+ def __init__(
293
+ self,
294
+ dim: int,
295
+ n_heads: int,
296
+ n_kv_heads: Optional[int] = None,
297
+ dropout: float = 0.0,
298
+ norm_eps: float = 1e-5,
299
+ qk_norm: bool = False,
300
+ fc_bias: bool = False,
301
+ ffn_exp: int = 1,
302
+ ffn_dim_multiplier: int = 4,
303
+ multiple_of: int = 64,
304
+ non_linearity: str = "silu",
305
+ no_cross_attention: bool = False,
306
+ ):
307
+ super().__init__()
308
+ assert dim % n_heads == 0
309
+ self.n_heads = n_heads
310
+ self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
311
+ self.dim = dim
312
+ self.dropout = dropout
313
+ self.head_dim = dim // n_heads
314
+
315
+ assert self.n_heads % self.n_kv_heads == 0
316
+
317
+ self.attention = Attention(
318
+ dim=dim,
319
+ head_dim=self.head_dim,
320
+ n_heads=self.n_heads,
321
+ n_kv_heads=self.n_kv_heads,
322
+ norm_eps=norm_eps,
323
+ use_qk_norm=qk_norm,
324
+ fc_bias=fc_bias,
325
+ )
326
+ self.feed_forward = FeedForward(
327
+ dim=dim,
328
+ hidden_dim=int(ffn_exp * dim),
329
+ ffn_dim_multiplier=ffn_dim_multiplier,
330
+ multiple_of=multiple_of,
331
+ dropout=dropout,
332
+ non_linearity=non_linearity,
333
+ fc_bias=fc_bias,
334
+ )
335
+
336
+ self.attention_norm, self.ffn_norm = [RMSNorm(dim, norm_eps) for _ in range(2)]
337
+
338
+ self.cross_attention = None
339
+ if not no_cross_attention:
340
+ self.cross_attention = Attention(
341
+ dim=dim,
342
+ head_dim=self.head_dim,
343
+ n_heads=self.n_heads,
344
+ n_kv_heads=self.n_heads,
345
+ norm_eps=norm_eps,
346
+ use_qk_norm=qk_norm,
347
+ fc_bias=fc_bias,
348
+ )
349
+
350
+ self.scale_shift_table = nn.Parameter(
351
+ torch.randn(6, self.dim) / self.dim**0.5,
352
+ )
353
+
354
+ def forward(
355
+ self,
356
+ x: torch.Tensor,
357
+ cross_x: Optional[torch.Tensor],
358
+ t: torch.Tensor,
359
+ padding_mask: Optional[torch.Tensor],
360
+ memory_padding_mask: Optional[torch.Tensor],
361
+ rope: Optional[RotaryEmbedding] = None,
362
+ ):
363
+ biases = self.scale_shift_table[None] + t.reshape(x.size(0), 6, -1)
364
+ (
365
+ shift_msa,
366
+ scale_msa,
367
+ gate_msa,
368
+ shift_mlp,
369
+ scale_mlp,
370
+ gate_mlp,
371
+ ) = biases.chunk(6, dim=1)
372
+
373
+ assert self.attention is not None and self.attention_norm is not None
374
+ h_attn = self.attention(
375
+ modulate(self.attention_norm(x), shift_msa, scale_msa),
376
+ key_padding_mask=padding_mask,
377
+ rope=rope,
378
+ )
379
+
380
+ h = x + gate(h_attn, gate_msa)
381
+
382
+ if self.cross_attention is not None:
383
+ h_cross = self.cross_attention(
384
+ x=h,
385
+ cross_x=cross_x,
386
+ key_padding_mask=memory_padding_mask,
387
+ )
388
+ h = h + h_cross # residual
389
+ h_ff = self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
390
+ out = h + gate(h_ff, gate_mlp)
391
+ return out
392
+
393
+
394
+ class DiT(torch.nn.Module):
395
+ def __init__(self, config: TransformerConfig):
396
+ super().__init__()
397
+ self.dropout = config.dropout
398
+ if config.in_channels is not None:
399
+ self.data_proj = torch.nn.Linear(config.in_channels, config.dim)
400
+
401
+ # embeddings
402
+ self.rope_embeddings = None
403
+ # rotary embeddings
404
+ if config.use_rope:
405
+ self.rope_embeddings = RotaryEmbedding(
406
+ theta=max(10000, 2 * config.max_positions),
407
+ head_dim=config.dim // config.n_heads,
408
+ max_seqlen=config.max_positions,
409
+ )
410
+ self.rope_embeddings.reset_parameters()
411
+
412
+ # transformer blocks
413
+ self.layers = nn.ModuleList()
414
+ for _ in range(config.n_layers):
415
+ self.layers.append(
416
+ DiTBlock(
417
+ dim=config.dim,
418
+ n_heads=config.n_heads,
419
+ dropout=config.dropout,
420
+ norm_eps=config.norm_eps,
421
+ qk_norm=config.qk_norm,
422
+ fc_bias=config.fc_bias,
423
+ ffn_exp=config.ffn_exp,
424
+ ffn_dim_multiplier=config.ffn_dim_multiplier,
425
+ multiple_of=config.multiple_of,
426
+ non_linearity=config.non_linearity,
427
+ )
428
+ )
429
+
430
+ self.norm = RMSNorm(config.dim, config.norm_eps)
431
+
432
+ # output layer
433
+ self.output = torch.nn.Linear(
434
+ config.dim, config.out_channels, bias=config.fc_bias
435
+ )
436
+
437
+ self.x_embedder = Patcher(
438
+ in_channels=config.dim,
439
+ out_channels=config.dim,
440
+ patch_size=1,
441
+ )
442
+
443
+ self.y_embedder = ContextEmbedder(
444
+ in_dim=config.context_dim,
445
+ out_dim=config.dim,
446
+ non_linearity=config.context_non_linearity,
447
+ dropout=config.context_embedder_dropout,
448
+ fc_bias=config.fc_bias,
449
+ norm_eps=config.norm_eps,
450
+ context_norm=config.context_norm,
451
+ )
452
+
453
+ self.t_embedder = TimestepEmbedder(
454
+ config.dim,
455
+ config.frequency_embedding_dim,
456
+ non_linearity=config.timestep_non_linearity,
457
+ dropout=config.dropout,
458
+ fc_bias=config.fc_bias,
459
+ max_period=10000,
460
+ )
461
+
462
+ self.t_block_non_linearity = get_nonlinearity(config.t_block_non_linearity)
463
+ self.t_block = torch.nn.Linear(
464
+ config.dim,
465
+ config.dim * 6,
466
+ bias=config.t_block_bias,
467
+ )
468
+
469
+ self.final_layer_scale_shift_table = nn.Parameter(
470
+ torch.randn(2, config.dim) / config.dim**0.5,
471
+ )
472
+
473
+ def forward(
474
+ self,
475
+ x: torch.Tensor,
476
+ time: torch.Tensor,
477
+ *,
478
+ padding_mask: Optional[torch.Tensor] = None,
479
+ memory: Optional[torch.Tensor] = None,
480
+ memory_padding_mask: Optional[torch.Tensor] = None,
481
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
482
+ x = rearrange(x, "b l c-> b c l")
483
+ h = self.x_embedder(x)
484
+ h = rearrange(h, "b c l -> b l c")
485
+ original_N = h.shape[1]
486
+ N = h.shape[1]
487
+
488
+ h = F.dropout(h, p=self.dropout, training=self.training)
489
+
490
+ t = self.t_embedder(time) # B -> B D
491
+
492
+ t0 = self.t_block_non_linearity(t)
493
+ t0 = self.t_block(t0) # B D -> B 6D
494
+
495
+ y = self.y_embedder(memory)
496
+
497
+ for layer in self.layers:
498
+ h = layer(
499
+ x=h,
500
+ cross_x=y,
501
+ t=t0,
502
+ padding_mask=padding_mask,
503
+ memory_padding_mask=memory_padding_mask,
504
+ rope=self.rope_embeddings,
505
+ )
506
+
507
+ shift, scale = (self.final_layer_scale_shift_table[None] + t[:, None]).chunk(
508
+ 2, dim=1
509
+ )
510
+
511
+ # output layer
512
+ if self.norm is not None:
513
+ h = self.norm(h)
514
+
515
+ h = modulate(h, shift, scale)
516
+
517
+ h = F.dropout(h, p=self.dropout, training=self.training)
518
+
519
+ output = self.output(h)
520
+
521
+ N = output.shape[1]
522
+ if original_N != N:
523
+ output = output[:, -original_N:]
524
+ return output
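The conditioning path in `DiTBlock` follows an adaLN-style pattern: the timestep embedding is projected to `6 * dim`, added to a learned `scale_shift_table`, and split into shift/scale/gate triples for the attention and MLP branches. The sketch below reproduces just that bookkeeping with made-up sizes (attention and the MLP themselves are omitted).

```python
import torch

def modulate(x, shift, scale):
    return x * (1 + scale) + shift

B, T, D = 2, 10, 16
x = torch.randn(B, T, D)
t_emb = torch.randn(B, 6 * D)                 # output of t_block(t_embedder(time))
table = torch.randn(6, D) / D**0.5            # scale_shift_table
biases = table[None] + t_emb.reshape(B, 6, D)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = biases.chunk(6, dim=1)

h_attn = modulate(x, shift_msa, scale_msa)              # would feed self-attention here
h = x + h_attn * gate_msa                               # gated residual
h = h + modulate(h, shift_mlp, scale_mlp) * gate_mlp    # gated MLP branch (MLP omitted)
print(h.shape)                                          # torch.Size([2, 10, 16])
```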
sam_audio/model/vision_encoder.py ADDED
@@ -0,0 +1,113 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from abc import ABCMeta, abstractmethod
4
+
5
+ import torch
6
+ import torchvision
7
+ from core.vision_encoder import pe
8
+ from torch.nn.utils.rnn import pad_sequence
9
+
10
+ from sam_audio.model.config import (
11
+ PerceptionEncoderConfig,
12
+ VisionEncoderConfig,
13
+ )
14
+
15
+
16
+ class RescaleTransform(object):
17
+ """Rescale the image in a sample to a given size.
18
+
19
+ Args:
20
+ output_size (tuple or int): Desired output size. If tuple, output is
21
+ matched to output_size. If int, smaller of image edges is matched
22
+ to output_size keeping aspect ratio the same.
23
+ """
24
+
25
+ def __init__(self, output_size, interpolation):
26
+ assert isinstance(output_size, (int, tuple))
27
+ self.output_size = output_size
28
+ if isinstance(output_size, int):
29
+ self.output_size = (output_size, output_size)
30
+ self.interpolation = interpolation
31
+
32
+ def __call__(self, sample):
33
+ # sample: [T, C, H, W]
34
+ sample = torch.nn.functional.interpolate(
35
+ sample.float(), size=self.output_size, mode=self.interpolation.value
36
+ )
37
+ return sample
38
+
39
+
40
+ class VisionEncoder(torch.nn.Module, metaclass=ABCMeta):
41
+ def __init__(self, cfg: VisionEncoderConfig):
42
+ super().__init__()
43
+ self.batch_size = cfg.batch_size
44
+ self.dim = cfg.dim
45
+ self.transform = self.get_transform()
46
+
47
+ @torch.no_grad()
48
+ def forward(self, videos: list[torch.Tensor]) -> torch.Tensor:
49
+ """
50
+ Encodes a list of input videos. Each element of the list is a video represented
51
+ as a tensor [T, C, H, W]
52
+ Args:
53
+ videos (list[torch.Tensor]): List of input video tensors to be processed.
54
+
55
+ Returns:
56
+ torch.Tensor: Encoded feature representations of the input tensors.
57
+ The output is padded along the time dimension for variable length videos
58
+ """
59
+ result = []
60
+ for video in videos:
61
+ video = self.transform(video)
62
+ if self.batch_size > 0 and video.size(0) > self.batch_size:
63
+ res = []
64
+ for i in range(0, video.size(0), self.batch_size):
65
+ res.append(self.encode(video[i : i + self.batch_size]))
66
+ result.append(torch.cat(res, dim=0))
67
+ else:
68
+ result.append(self.encode(video))
69
+ return pad_sequence(result, batch_first=True, padding_value=0.0)
70
+
71
+ @abstractmethod
72
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
73
+ pass
74
+
75
+ @abstractmethod
76
+ def get_transform(self):
77
+ pass
78
+
79
+
80
+ class PerceptionEncoder(VisionEncoder):
81
+ def __init__(self, cfg: PerceptionEncoderConfig):
82
+ self.normalize_feature = cfg.normalize_feature
83
+ self.interpolation_mode = cfg.interpolation_mode
84
+ self.image_size = cfg.image_size
85
+ super().__init__(cfg)
86
+ self.model = pe.CLIP.from_config(cfg.name)
87
+
88
+ def encode(self, x):
89
+ image_features = self.model.encode_image(x, normalize=self.normalize_feature)
90
+ return image_features
91
+
92
+ def get_transform(self):
93
+ T = torchvision.transforms
94
+ try:
95
+ interp = getattr(T.InterpolationMode, self.interpolation_mode.upper())
96
+ except AttributeError as err:
97
+ raise ValueError(
98
+ f"Unsupported interpolation_mode: {self.interpolation_mode}"
99
+ ) from err
100
+ crop = [
101
+ T.Resize(
102
+ (self.image_size, self.image_size),
103
+ interpolation=interp,
104
+ )
105
+ ]
106
+
107
+ return T.Compose(
108
+ crop
109
+ + [
110
+ T.Lambda(lambda x: x.float() / 255.0),
111
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
112
+ ]
113
+ )
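The sketch below shows the padding behaviour of `VisionEncoder.forward` for variable-length videos, with a dummy per-frame encoder standing in for the Perception Encoder; frame counts and the feature dimension are invented for the example.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def encode(frames: torch.Tensor) -> torch.Tensor:
    # Stand-in encoder: [T, C, H, W] -> [T, D] with D=8.
    return frames.float().mean(dim=(1, 2, 3)).unsqueeze(-1).expand(-1, 8)

videos = [torch.randint(0, 255, (t, 3, 32, 32)) for t in (5, 9, 7)]
feats = [encode(v) for v in videos]
padded = pad_sequence(feats, batch_first=True, padding_value=0.0)
print(padded.shape)                            # torch.Size([3, 9, 8])
```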
sam_audio/processor.py ADDED
@@ -0,0 +1,382 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ import json
4
+ import logging
5
+ import math
6
+ import os
7
+ from typing import Callable, List, Optional, Tuple
8
+
9
+ import torch
10
+ import torchaudio
11
+ from huggingface_hub import hf_hub_download
12
+ from torch.nn.utils.rnn import pad_sequence
13
+ from torchcodec.decoders import AudioDecoder, VideoDecoder
14
+ from transformers import AutoTokenizer, BatchFeature
15
+
16
+ from sam_audio.model.config import SAMAudioConfig, SAMAudioJudgeConfig
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ Anchor = Tuple[str, float, float]
21
+
22
+
23
+ def batch_audio(
24
+ audios: list[str | torch.Tensor], audio_sampling_rate: int = 48_000
25
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
26
+ wavs = []
27
+ for audio in audios:
28
+ if isinstance(audio, str):
29
+ wav, sr = torchaudio.load(audio)
30
+ if sr != audio_sampling_rate:
31
+ wav = torchaudio.functional.resample(wav, sr, audio_sampling_rate)
32
+ else:
33
+ wav = audio
34
+ wavs.append(wav.mean(0))
35
+ sizes = torch.tensor([wav.size(-1) for wav in wavs])
36
+ return pad_sequence(wavs, batch_first=True).unsqueeze(1), sizes
37
+
38
+
39
+ class Batch:
40
+ def __init__(
41
+ self,
42
+ audios: torch.Tensor,
43
+ sizes: torch.Tensor,
44
+ wav_sizes: torch.Tensor,
45
+ descriptions: list[str],
46
+ hop_length: int,
47
+ audio_sampling_rate: int,
48
+ anchors: Optional[list[list[Anchor]]] = None,
49
+ audio_pad_mask: Optional[torch.Tensor] = None,
50
+ masked_video: Optional[torch.Tensor] = None,
51
+ ):
52
+ self.audios = audios
53
+ self.sizes = sizes
54
+ self.wav_sizes = wav_sizes
55
+ self.descriptions = descriptions
56
+ self.audio_pad_mask = audio_pad_mask
57
+ self.masked_video = masked_video
58
+ self.hop_length = hop_length
59
+ self.audio_sampling_rate = audio_sampling_rate
60
+ self.process_anchors(anchors)
61
+ assert self.audios.size(0) == len(self.descriptions)
62
+
63
+ def _wav_to_feature_idx(self, wav_idx: int):
64
+ return math.ceil(wav_idx / self.hop_length)
65
+
66
+ def to(self, device: torch.device):
67
+ self.audios = self.audios.to(device)
68
+ self.anchor_ids = self.anchor_ids.to(device)
69
+ self.anchor_alignment = self.anchor_alignment.to(device)
70
+ self.sizes = self.sizes.to(device)
71
+ self.wav_sizes = self.wav_sizes.to(device)
72
+ if self.audio_pad_mask is not None:
73
+ self.audio_pad_mask = self.audio_pad_mask.to(device)
74
+ if self.masked_video is not None:
75
+ self.masked_video = [v.to(device) for v in self.masked_video]
76
+ return self
77
+
78
+ def process_anchors(self, anchors: Optional[list[list[Anchor]]]):
79
+ batch_size = len(self.audios)
80
+ anchor_dict = {"<null>": 0, "+": 1, "-": 2, "<pad>": 3}
81
+ if anchors is None:
82
+ anchor_ids = torch.full(
83
+ (batch_size, 2), anchor_dict["<null>"], dtype=torch.long
84
+ )
85
+ anchor_ids[:, 1] = anchor_dict["<pad>"]
86
+ anchor_alignment = torch.full(
87
+ (
88
+ batch_size,
89
+ self.audio_pad_mask.size(-1),
90
+ ),
91
+ 0,
92
+ dtype=torch.long,
93
+ )
94
+ anchor_alignment[~self.audio_pad_mask] = 1 # point to pad token
95
+ else:
96
+ anchor_alignment = torch.full(
97
+ (
98
+ batch_size,
99
+ self.audio_pad_mask.size(-1),
100
+ ),
101
+ 0,
102
+ dtype=torch.long,
103
+ )
104
+ anchor_alignment[~self.audio_pad_mask] = 1 # point to pad token
105
+ ids = []
106
+
107
+ for i, anchor_list in enumerate(anchors):
108
+ current = [anchor_dict["<null>"], anchor_dict["<pad>"]]
109
+ for token, start_time, end_time in anchor_list:
110
+ start_idx = self._wav_to_feature_idx(
111
+ start_time * self.audio_sampling_rate
112
+ )
113
+ end_idx = self._wav_to_feature_idx(
114
+ end_time * self.audio_sampling_rate
115
+ )
116
+ anchor_alignment[i, start_idx:end_idx] = len(current)
117
+ current.append(anchor_dict[token])
118
+ ids.append(torch.tensor(current))
119
+ anchor_ids = pad_sequence(
120
+ ids, batch_first=True, padding_value=anchor_dict["<pad>"]
121
+ )
122
+ self.anchor_ids = anchor_ids
123
+ self.anchor_alignment = anchor_alignment
124
+ self.anchors = anchors
125
+
126
+
127
+ def mask_from_sizes(sizes: torch.Tensor) -> torch.Tensor:
128
+ return torch.arange(sizes.max()).expand(len(sizes), -1) < sizes.unsqueeze(1)
129
+
130
+
131
+ def load_video(
132
+ sizes: torch.Tensor,
133
+ videos: List[str],
134
+ feature_idx_to_wav_idx: Callable[[torch.Tensor], torch.Tensor],
135
+ audio_sampling_rate: int,
136
+ ) -> list[torch.Tensor]:
137
+ all_frames = []
138
+ for size, video in zip(sizes, videos, strict=False):
139
+ audio_timestamps = (
140
+ feature_idx_to_wav_idx(torch.arange(size)) / audio_sampling_rate
141
+ )
142
+ if isinstance(video, str):
143
+ decoder = VideoDecoder(video, dimension_order="NCHW")
144
+ data = decoder.get_frames_in_range(0, len(decoder))
145
+ diffs = (audio_timestamps[None] - data.pts_seconds[:, None]).abs()
146
+ frame_idxs = diffs.argmin(dim=0)
147
+ frames = data.data[frame_idxs]
148
+ else:
149
+ assert video.size(1) == 3, (
150
+ f"Expected video tensor to be in NCHW format, but found {video.size(1)} channels"
151
+ )
152
+ idx = torch.linspace(0, video.size(0) - 1, int(size)).round().long()
153
+ frames = video[idx]
154
+ all_frames.append(frames)
155
+ return all_frames
156
+
157
+
158
+ class Processor:
159
+ config_cls: Callable
160
+
161
+ def __init__(self, audio_hop_length: int, audio_sampling_rate: int):
162
+ self.audio_hop_length = audio_hop_length
163
+ self.audio_sampling_rate = audio_sampling_rate
164
+
165
+ @classmethod
166
+ def _get_config(cls, model_name_or_path: str):
167
+ if os.path.exists(model_name_or_path):
168
+ config_path = os.path.join(model_name_or_path, "config.json")
169
+ else:
170
+ config_path = hf_hub_download(
171
+ repo_id=model_name_or_path,
172
+ filename="config.json",
173
+ revision=cls.revision,
174
+ )
175
+ with open(config_path) as fin:
176
+ config = cls.config_cls(**json.load(fin))
177
+ return config
178
+
179
+ @classmethod
180
+ def from_pretrained(cls, model_name_or_path: str) -> "Processor":
181
+ config = cls._get_config(model_name_or_path)
182
+ return cls(
183
+ audio_hop_length=config.audio_codec.hop_length,
184
+ audio_sampling_rate=config.audio_codec.sample_rate,
185
+ )
186
+
187
+ def feature_to_wav_idx(self, feature_idx):
188
+ return feature_idx * self.audio_hop_length
189
+
190
+ def wav_to_feature_idx(self, wav_idx):
191
+ if torch.is_tensor(wav_idx):
192
+ ceil = torch.ceil
193
+ else:
194
+ ceil = math.ceil
195
+ return ceil(wav_idx / self.audio_hop_length)
196
+
197
+ def mask_videos(
198
+ self,
199
+ videos: List[str | torch.Tensor],
200
+ masks: List[str | torch.Tensor],
201
+ ) -> list[torch.Tensor]:
202
+ video = [VideoDecoder(v)[:] if isinstance(v, str) else v for v in videos]
203
+ video_mask = [VideoDecoder(v)[:] if isinstance(v, str) else v for v in masks]
204
+ return [v * m.eq(0) for v, m in zip(video, video_mask, strict=False)]
205
+
206
+
207
+ class SAMAudioProcessor(Processor):
208
+ config_cls = SAMAudioConfig
209
+ revision = None
210
+
211
+ def __call__(
212
+ self,
213
+ descriptions: list[str],
214
+ audios: list[str | torch.Tensor],
215
+ anchors: Optional[list[list[Anchor]]] = None,
216
+ masked_videos: Optional[list[str | torch.Tensor]] = None,
217
+ ):
218
+ """
219
+ Processes input data for the model.
220
+
221
+ Args:
222
+ descriptions (list[str]): List of text descriptions corresponding to each audio sample.
223
+ audios (list[str | torch.Tensor]): List of audio file paths or tensors.
+ If a tensor:
+ - should have shape (channels, time), where channels=1 for mono and 2 for stereo.
+ - should be sampled at 48_000 Hz
227
+ anchors (Optional[list[list[Anchor]]], optional): List of anchors for each sample,
228
+ where each anchor is a tuple (token, start_time, end_time).
229
+ masked_videos (Optional[list[str | torch.Tensor]], optional): List of masked video file paths or tensors.
230
+ If a tensor, should have shape (N, C, H, W)
231
+
232
+ Returns:
233
+ Batch: A Batch object containing processed audio, sizes, descriptions, anchor ids, anchor alignment, audio pad mask, and optionally masked video.
234
+ """
235
+
236
+ assert len(descriptions) == len(audios)
237
+ assert anchors is None or len(descriptions) == len(anchors)
238
+ assert masked_videos is None or len(descriptions) == len(masked_videos)
239
+
240
+ audios, wav_sizes = batch_audio(audios, self.audio_sampling_rate)
241
+
242
+ sizes = self.wav_to_feature_idx(wav_sizes)
243
+ audio_pad_mask = mask_from_sizes(sizes)
244
+ masked_video = None
245
+ if masked_videos is not None:
246
+ masked_video = load_video(
247
+ sizes, masked_videos, self.feature_to_wav_idx, self.audio_sampling_rate
248
+ )
249
+
250
+ return Batch(
251
+ audios=audios,
252
+ sizes=sizes,
253
+ descriptions=descriptions,
254
+ audio_pad_mask=audio_pad_mask,
255
+ anchors=anchors,
256
+ masked_video=masked_video,
257
+ hop_length=self.audio_hop_length,
258
+ audio_sampling_rate=self.audio_sampling_rate,
259
+ wav_sizes=wav_sizes,
260
+ )
261
+
262
+
263
+ class SAMAudioJudgeProcessor(Processor):
264
+ config_cls = SAMAudioJudgeConfig
265
+ revision = "sam_audio"
266
+
267
+ def __init__(
268
+ self,
269
+ audio_hop_length: int,
270
+ audio_sampling_rate: int,
271
+ tokenizer: AutoTokenizer,
272
+ ):
273
+ super().__init__(audio_hop_length, audio_sampling_rate)
274
+ self.tokenizer = tokenizer
275
+
276
+ @classmethod
277
+ def from_pretrained(cls, model_name_or_path: str) -> "SAMAudioJudgeProcessor":
278
+ config = cls._get_config(model_name_or_path)
279
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
280
+ return cls(
281
+ audio_hop_length=config.audio_codec.hop_length,
282
+ audio_sampling_rate=config.audio_codec.sample_rate,
283
+ tokenizer=tokenizer,
284
+ )
285
+
286
+ def _reflect_pad(self, wav):
287
+ if wav.ndim == 1:
288
+ wav = wav.unsqueeze(0)
289
+ if wav.size(-1) % self.audio_hop_length == 0:
290
+ return wav
291
+ p1d = (0, self.audio_hop_length - (wav.size(-1) % self.audio_hop_length))
292
+ return torch.nn.functional.pad(wav, p1d, mode="reflect")
293
+
294
+ def _load_audio(self, path: str):
295
+ ad = AudioDecoder(path, sample_rate=self.audio_sampling_rate, num_channels=1)
296
+ return ad.get_all_samples().data
297
+
298
+ def _process_audio(
299
+ self,
300
+ raw_audio,
301
+ sampling_rate: Optional[int] = None,
302
+ ):
303
+ from_file = False
304
+ if isinstance(raw_audio, str):
305
+ raw_audio = [raw_audio]
306
+
307
+ if isinstance(raw_audio, (list, tuple)) and isinstance(raw_audio[0], str):
308
+ loaded = []
309
+ for audio_file in raw_audio:
310
+ loaded.append(self._load_audio(audio_file))
311
+ raw_audio = loaded
312
+ from_file = True
313
+
314
+ if sampling_rate is not None:
315
+ if sampling_rate != self.audio_sampling_rate:
316
+ raise ValueError(
317
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
318
+ f" {self.audio_sampling_rate}. Please make sure that the provided audio input was sampled with"
319
+ f" {self.audio_sampling_rate} and not {sampling_rate}."
320
+ )
321
+ elif not from_file:
322
+ logger.warning(
323
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
324
+ "Failing to do so can result in silent errors that might be hard to debug."
325
+ )
326
+
327
+ if isinstance(raw_audio, list):
328
+ raw_audio = [self._reflect_pad(x).T for x in raw_audio]
329
+ else:
330
+ raw_audio = self._reflect_pad(raw_audio).T
331
+
332
+ # verify inputs are valid
333
+ for example in raw_audio:
334
+ if example.ndim > 2:
335
+ raise ValueError(
336
+ f"Expected input shape (channels, num_samples), but got shape ({example.shape})"
337
+ )
338
+
339
+ lengths = torch.tensor([x.size(0) for x in raw_audio])
340
+ input_values = pad_sequence(raw_audio, batch_first=True).transpose(1, 2)
341
+ padding_mask = torch.arange(lengths.max())[None] < lengths[:, None]
342
+
343
+ return BatchFeature(
344
+ {"input_values": input_values, "padding_mask": padding_mask}
345
+ )
346
+
347
+ def __call__(
348
+ self,
349
+ text: Optional[str] = None,
350
+ input_audio: Optional[
351
+ str | list[str] | torch.Tensor | list[torch.Tensor]
352
+ ] = None,
353
+ separated_audio: Optional[
354
+ str | list[str] | torch.Tensor | list[torch.Tensor]
355
+ ] = None,
356
+ sampling_rate: Optional[int] = None,
357
+ **kwargs,
358
+ ):
359
+ batch = BatchFeature()
360
+ if text is not None:
361
+ batch.update(
362
+ self.tokenizer(
363
+ text,
364
+ return_tensors="pt",
365
+ padding="longest",
366
+ max_length=512,
367
+ truncation=True,
368
+ )
369
+ )
370
+
371
+ if input_audio is not None:
372
+ batch.update(self._process_audio(input_audio, sampling_rate))
373
+
374
+ if separated_audio is not None:
375
+ batch["separated_values"] = self._process_audio(
376
+ separated_audio, sampling_rate
377
+ )["input_values"]
378
+
379
+ return batch
380
+
381
+
382
+ __all__ = ["SAMAudioProcessor", "SAMAudioJudgeProcessor", "Batch"]
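A small sketch of the size bookkeeping the processor performs: wav lengths are converted to feature lengths with the codec hop length (`wav_to_feature_idx`) and then turned into a padding mask (`mask_from_sizes`). The hop length and clip lengths below are assumptions for illustration only.

```python
import torch

hop_length = 960                                        # hypothetical codec hop length
wav_sizes = torch.tensor([48_000, 36_000, 50_400])      # samples per clip
feat_sizes = torch.ceil(wav_sizes / hop_length).long()  # feature frames per clip
pad_mask = torch.arange(feat_sizes.max()).expand(len(feat_sizes), -1) < feat_sizes.unsqueeze(1)
print(feat_sizes.tolist(), pad_mask.shape)              # [50, 38, 53] torch.Size([3, 53])
```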
sam_audio/ranking/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from sam_audio.model.config import (
4
+ ClapRankerConfig,
5
+ EnsembleRankerConfig,
6
+ ImageBindRankerConfig,
7
+ JudgeRankerConfig,
8
+ )
9
+ from sam_audio.ranking.clap import ClapRanker
10
+ from sam_audio.ranking.imagebind import ImageBindRanker
11
+ from sam_audio.ranking.judge import JudgeRanker
12
+ from sam_audio.ranking.ranker import EnsembleRanker
13
+
14
+
15
+ def create_ranker(config):
16
+ if isinstance(config, ImageBindRankerConfig):
17
+ return ImageBindRanker(config)
18
+ elif isinstance(config, ClapRankerConfig):
19
+ return ClapRanker(config)
20
+ elif isinstance(config, JudgeRankerConfig):
21
+ return JudgeRanker(config)
22
+ elif isinstance(config, EnsembleRankerConfig):
23
+ ranker_cfgs, weights = zip(*config.rankers.values(), strict=False)
24
+ return EnsembleRanker(
25
+ rankers=[create_ranker(cfg) for cfg in ranker_cfgs],
26
+ weights=weights,
27
+ )
28
+ else:
29
+ assert config is None
30
+ return None
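`EnsembleRanker` itself lives in `sam_audio/ranking/ranker.py` and is not shown in this section; the sketch below is only a guess at how per-ranker candidate scores could be combined with the `weights` passed above, offered for intuition rather than as the actual implementation.

```python
import torch

def ensemble_scores(per_ranker_scores: list[torch.Tensor], weights: list[float]) -> torch.Tensor:
    # per_ranker_scores: list of [batch, candidates] tensors, one per ranker.
    stacked = torch.stack(per_ranker_scores)             # [R, B, K]
    w = torch.tensor(weights).view(-1, 1, 1)             # [R, 1, 1]
    return (w * stacked).sum(dim=0)                      # weighted sum -> [B, K]

scores = ensemble_scores([torch.rand(2, 8), torch.rand(2, 8)], weights=[0.7, 0.3])
print(scores.argmax(dim=1))                              # best candidate per batch item
```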
sam_audio/ranking/clap.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ import torch
4
+ import torchaudio
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ from sam_audio.model.config import ClapRankerConfig
8
+ from sam_audio.ranking.ranker import Ranker
9
+
10
+
11
+ def get_model(device="cpu"):
12
+ import laion_clap
13
+
14
+ model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-tiny").to(device)
15
+ checkpoint_file = hf_hub_download(
16
+ repo_id="lukewys/laion_clap", filename="630k-best.pt"
17
+ )
18
+ state_dict = torch.load(checkpoint_file, map_location=device, weights_only=False)[
19
+ "state_dict"
20
+ ]
21
+ if next(iter(state_dict.items()))[0].startswith("module"):
22
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
23
+
24
+ if "text_branch.embeddings.position_ids" in state_dict:
25
+ del state_dict["text_branch.embeddings.position_ids"]
26
+
27
+ model.model.load_state_dict(state_dict)
28
+ return model.eval()
29
+
30
+
31
+ class ClapRanker(Ranker):
32
+ def __init__(self, config: ClapRankerConfig):
33
+ from laion_clap.training import data
34
+
35
+ self.laion_data_module = data
36
+ super().__init__()
37
+ self.config = config
38
+ self.model = get_model()
39
+
40
+ def _prepare_audio(self, audio, sample_rate):
41
+ audio_features = []
42
+ for candidates in audio:
43
+ if sample_rate != 48_000:
44
+ candidates = torchaudio.functional.resample(
45
+ candidates, sample_rate, 48000
46
+ )
47
+
48
+ quantized = self.laion_data_module.int16_to_float32_torch(
49
+ self.laion_data_module.float32_to_int16_torch(candidates)
50
+ ).float()
51
+ for sample in quantized:
52
+ temp_dict = {}
53
+ temp_dict = self.laion_data_module.get_audio_features(
54
+ temp_dict,
55
+ sample,
56
+ 480000,
57
+ data_truncating=(
58
+ "fusion" if self.model.enable_fusion else "rand_trunc"
59
+ ),
60
+ data_filling="repeatpad",
61
+ audio_cfg=self.model.model_cfg["audio_cfg"],
62
+ require_grad=False,
63
+ )
64
+ audio_features.append(temp_dict)
65
+ return audio_features
66
+
67
+ @torch.inference_mode()
68
+ def forward(
69
+ self,
70
+ extracted_audio: list[torch.Tensor],
71
+ descriptions: list[str],
72
+ sample_rate: int = 48_000,
73
+ **kwargs,
74
+ ):
75
+ audio_embed = self.model.model.get_audio_embedding(
76
+ self._prepare_audio(extracted_audio, sample_rate)
77
+ )
78
+ text_embed = self.model.get_text_embedding(descriptions, use_tensor=True)
79
+ bsz = len(extracted_audio)
80
+ candidates = len(audio_embed) // bsz
81
+ audio_embed = audio_embed.reshape(bsz, candidates, -1)
82
+ text_embed = text_embed.reshape(bsz, -1, 1)
83
+ scores = audio_embed @ text_embed
84
+ return scores.squeeze(-1)
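The last few lines of `ClapRanker.forward` reduce to a batched dot product between candidate audio embeddings and one text embedding per item; the standalone sketch below reproduces that reshape with random embeddings and made-up sizes.

```python
import torch

B, K, D = 2, 8, 512                         # batch, reranking candidates, embedding dim
audio_embed = torch.randn(B * K, D)         # one embedding per candidate
text_embed = torch.randn(B, D)              # one embedding per description

scores = audio_embed.reshape(B, K, D) @ text_embed.reshape(B, D, 1)  # [B, K, 1]
scores = scores.squeeze(-1)                                          # [B, K]
print(scores.argmax(dim=1))                                          # best candidate per item
```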
sam_audio/ranking/imagebind.py ADDED
@@ -0,0 +1,197 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
+
+ import math
+ from typing import List, Union
+
+ import torch
+ import torchaudio
+
+ from sam_audio.model.config import ImageBindRankerConfig
+ from sam_audio.ranking.ranker import Ranker
+
+ try:
+     from imagebind.data import (
+         ConstantClipsPerVideoSampler,
+         NormalizeVideo,
+         SpatialCrop,
+         get_clip_timepoints,
+         load_and_transform_video_data,
+         pv_transforms,
+         transforms,
+         waveform2melspec,
+     )
+     from imagebind.models.imagebind_model import ModalityType, imagebind_huge
+
+     __imagebind_exists__ = True
+ except ImportError:
+     __imagebind_exists__ = False
+
+
+ def load_and_transform_audio_data(
+     audios: List[Union[str, torch.Tensor]],
+     input_sample_rate=None,
+     num_mel_bins=128,
+     target_length=204,
+     sample_rate=16000,
+     clip_duration=2,
+     clips_per_video=3,
+     mean=-4.268,
+     std=9.138,
+     device=None,
+ ):
+     if audios is None:
+         return None
+
+     audio_outputs = []
+     clip_sampler = ConstantClipsPerVideoSampler(
+         clip_duration=clip_duration, clips_per_video=clips_per_video
+     )
+
+     for audio in audios:
+         if isinstance(audio, str):
+             waveform, input_sample_rate = torchaudio.load(audio)
+         else:
+             assert torch.is_tensor(audio)
+             assert sample_rate is not None
+             # Preprocessing needs to be done in full precision
+             waveform = audio.float()
+             if waveform.ndim == 1:
+                 waveform = waveform[None]
+         if sample_rate != input_sample_rate:
+             waveform = torchaudio.functional.resample(
+                 waveform, orig_freq=input_sample_rate, new_freq=sample_rate
+             )
+         all_clips_timepoints = get_clip_timepoints(
+             clip_sampler, waveform.size(1) / sample_rate
+         )
+         all_clips = []
+         for clip_timepoints in all_clips_timepoints:
+             waveform_clip = waveform[
+                 :,
+                 int(clip_timepoints[0] * sample_rate) : int(
+                     clip_timepoints[1] * sample_rate
+                 ),
+             ]
+             waveform_melspec = waveform2melspec(
+                 waveform_clip, sample_rate, num_mel_bins, target_length
+             )
+             all_clips.append(waveform_melspec)
+
+         normalize = transforms.Normalize(mean=mean, std=std)
+         all_clips = [normalize(ac).to(device) for ac in all_clips]
+
+         all_clips = torch.stack(all_clips, dim=0)
+         audio_outputs.append(all_clips)
+
+     return torch.stack(audio_outputs, dim=0)
+
+
+ class VideoTransform:
+     def __init__(self, clip_duration=2, clips_per_video=5):
+         self.clip_duration = clip_duration
+         self.clips_per_video = clips_per_video
+         self.clip_sampler = ConstantClipsPerVideoSampler(
+             clip_duration=clip_duration, clips_per_video=clips_per_video
+         )
+         self.video_transform = transforms.Compose(
+             [
+                 pv_transforms.ShortSideScale(224),
+                 NormalizeVideo(
+                     mean=(0.48145466, 0.4578275, 0.40821073),
+                     std=(0.26862954, 0.26130258, 0.27577711),
+                 ),
+             ]
+         )
+         self.spatial_crop = SpatialCrop(224, num_crops=3)
+
+     def load_video_fast(self, videos, durations, **kwargs):
+         result = []
+         for video, duration in zip(videos, durations, strict=False):
+             nframes = video.size(0)
+             fps = video.size(0) / duration
+             timepoints = get_clip_timepoints(
+                 self.clip_sampler,
+                 duration,
+             )
+             # Instead of loading 5 2s clips, and then sub-sampling frames, we figure
+             # out the indices of the final clips we want and only decode those.
+             all_idxs = []
+             for start_time, end_time in timepoints:
+                 idxs = torch.arange(
+                     min(int(math.ceil(fps * start_time)), nframes - 1),
+                     min(int(math.ceil(fps * end_time)), nframes),
+                 )
+                 ts = (
+                     torch.linspace(0, idxs.size(0) - 1, self.clip_duration)
+                     .clamp(max=idxs.size(0) - 1)
+                     .long()
+                 )
+                 all_idxs.append(idxs[ts])
+             all_idxs = torch.cat(all_idxs)
+             fast_frames = video[all_idxs].transpose(0, 1)
+             result.append(fast_frames.chunk(self.clips_per_video, dim=1))
+         return result
+
+     def transform_video(self, batch, device=None):
+         device = device or torch.device("cpu")
+         video_outputs = []
+         for all_video in batch:
+             all_video = [
+                 self.video_transform(clip.to(device) / 255.0) for clip in all_video
+             ]
+             all_video = self.spatial_crop(all_video)
+             all_video = torch.stack(all_video, dim=0)
+             video_outputs.append(all_video)
+         return torch.stack(video_outputs, dim=0)
+
+     def __call__(self, videos, durations, device=None):
+         return self.transform_video(
+             self.load_video_fast(videos, durations), device=device
+         )
+
+
+ class ImageBindRanker(Ranker):
+     def __init__(self, cfg: ImageBindRankerConfig):
+         super().__init__()
+         assert __imagebind_exists__, (
+             "Install ImageBind in order to use this ranker: https://github.com/facebookresearch/ImageBind/tree/main"
+         )
+
+         self.model = imagebind_huge(pretrained=cfg.checkpoint is None)
+         if cfg.checkpoint is not None:
+             self.model.load_state_dict(torch.load(cfg.checkpoint, map_location="cpu"))
+         self.model = self.model.eval()
+         self.video_transform = VideoTransform()
+
+     @torch.inference_mode()
+     def forward(
+         self,
+         extracted_audio: list[torch.Tensor],
+         videos: list[torch.Tensor | str],
+         sample_rate: int = 48_000,
+         **kwargs,
+     ):
+         audio_data = torch.cat(
+             [
+                 load_and_transform_audio_data(x, input_sample_rate=sample_rate)
+                 for x in extracted_audio
+             ],
+             dim=0,
+         )
+         if isinstance(videos[0], str):
+             video_data = load_and_transform_video_data(videos)
+         else:
+             durations = [x.size(-1) / sample_rate for x in extracted_audio]
+             video_data = self.video_transform(videos, durations, audio_data.device)
+
+         inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
+         embs = self.model(inputs)
+         audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
+         audio_embs, video_embs = (
+             audio_embs / ((audio_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
+             video_embs / ((video_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
+         )
+         bsz = len(extracted_audio)
+         candidates = len(audio_embs) // bsz
+         scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
+         return scores
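The final scoring in `ImageBindRanker.forward` divides each embedding by its L2 norm before the matmul, so the returned scores are cosine similarities between each candidate's audio and the corresponding video. A small sketch showing that this manual normalization matches `torch.nn.functional.normalize` (random embeddings, illustrative only):

```python
# Illustrative only: the manual normalization in ImageBindRanker.forward is an
# L2 normalization, so the subsequent matmul yields cosine similarities.
import torch

embs = torch.randn(8, 1024)                           # stand-in ImageBind embeddings
manual = embs / ((embs**2).sum(dim=-1, keepdim=True) ** 0.5)
builtin = torch.nn.functional.normalize(embs, dim=-1)
print(torch.allclose(manual, builtin, atol=1e-6))     # True
```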
sam_audio/ranking/judge.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
+
+ import torch
+
+ from ..model.config import JudgeRankerConfig
+ from ..model.judge import SAMAudioJudgeModel
+ from ..processor import SAMAudioJudgeProcessor
+ from .ranker import Ranker
+
+
+ class JudgeRanker(Ranker):
+     def __init__(self, config: JudgeRankerConfig):
+         super().__init__()
+         self.config = config
+         self.model = SAMAudioJudgeModel.from_pretrained(config.checkpoint_or_model_id)
+         self.processor = SAMAudioJudgeProcessor.from_pretrained(
+             config.checkpoint_or_model_id
+         )
+
+     @torch.inference_mode()
+     def forward(
+         self,
+         input_audio: list[torch.Tensor],
+         extracted_audio: list[torch.Tensor],
+         descriptions: list[str],
+         sample_rate: int = 48_000,
+         **kwargs,
+     ):
+         bsz, ncandidates = len(input_audio), len(input_audio[0])
+         input_seqs = [x[None] for candidates in input_audio for x in candidates]
+         extracted_seqs = [x[None] for candidates in extracted_audio for x in candidates]
+         repeated_descriptions = [x for x in descriptions for _ in range(ncandidates)]
+         processed = self.processor(
+             text=repeated_descriptions,
+             input_audio=input_seqs,
+             separated_audio=extracted_seqs,
+             return_tensors="pt",
+             padding=True,
+             sampling_rate=sample_rate,
+         )
+         res = self.model(**processed.to(input_audio[0].device))
+         return res.overall.view(bsz, ncandidates)
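`JudgeRanker.forward` flattens the per-item candidate lists into one large batch and repeats each description once per candidate, so the judge model scores every candidate in a single pass before the result is reshaped back to `(bsz, ncandidates)`. A toy sketch of the repeat pattern (hypothetical descriptions, illustrative only):

```python
# Illustrative only: how descriptions are repeated so the flattened candidate
# batch lines up with one description per candidate.
descriptions = ["dog barking", "rain on a window"]    # hypothetical prompts, bsz = 2
ncandidates = 3

repeated_descriptions = [x for x in descriptions for _ in range(ncandidates)]
print(repeated_descriptions)
# ['dog barking', 'dog barking', 'dog barking',
#  'rain on a window', 'rain on a window', 'rain on a window']
```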
sam_audio/ranking/ranker.py ADDED
@@ -0,0 +1,36 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
+
+ from abc import ABCMeta, abstractmethod
+ from typing import List
+
+ import torch
+
+
+ class Ranker(torch.nn.Module, metaclass=ABCMeta):
+     @abstractmethod
+     def forward(self, audio: list[torch.Tensor], **kwargs) -> torch.Tensor:
+         """
+         Args:
+             audio: (list[torch.Tensor]) where each element in the list corresponds to
+                 the candidates for the i'th generation (num_candidates, num_frames)
+         Returns:
+             (torch.Tensor) of shape (batch_size, num_candidates) corresponding to the ranking scores
+         """
+         pass
+
+
+ class EnsembleRanker(Ranker):
+     def __init__(self, rankers: List[Ranker], weights: List[float]):
+         super().__init__()
+         assert len(rankers) == len(weights)
+         self.rankers = torch.nn.ModuleList(rankers)
+         self.weights = weights
+
+     def forward(self, **kwargs) -> torch.Tensor:
+         result = None
+         for weight, ranker in zip(self.weights, self.rankers, strict=False):
+             if result is None:
+                 result = weight * ranker(**kwargs)
+             else:
+                 result += weight * ranker(**kwargs)
+         return result
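`EnsembleRanker` combines rankers by taking a weighted sum of their per-candidate score tensors. A hypothetical usage sketch (not part of the commit; `ConstantRanker` is invented purely for illustration):

```python
# Hypothetical usage sketch of EnsembleRanker; ConstantRanker is a stand-in
# ranker that returns a fixed score for every candidate.
import torch

from sam_audio.ranking.ranker import EnsembleRanker, Ranker


class ConstantRanker(Ranker):
    def __init__(self, value: float):
        super().__init__()
        self.value = value

    def forward(self, audio, **kwargs):
        # one score per candidate for each batch item
        return torch.full((len(audio), audio[0].shape[0]), self.value)


ensemble = EnsembleRanker(
    rankers=[ConstantRanker(1.0), ConstantRanker(3.0)],
    weights=[0.75, 0.25],
)
audio = [torch.randn(4, 48_000)]          # 1 batch item with 4 candidate waveforms
print(ensemble(audio=audio))              # 0.75 * 1.0 + 0.25 * 3.0 = 1.5 per candidate
```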
sam_audio/ranking/sound_activity.py ADDED
@@ -0,0 +1,129 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
+
+ from io import BytesIO
+ from typing import Tuple, Union
+
+ import torch
+ from torchcodec.encoders import AudioEncoder
+
+ from ..model.config import SoundActivityRankerConfig
+ from .ranker import Ranker
+
+ try:
+     import pydub
+ except ImportError:
+     pydub = None
+
+
+ def get_peak_rms(audio, win_ms=250, hop_ms=100):
+     """
+     win_ms and hop_ms are in milliseconds
+     """
+     last_slice_start = len(audio) - win_ms
+     slice_starts = range(0, last_slice_start + 1, hop_ms)
+     peak_rms = -1
+     for i in slice_starts:
+         audio_slice = audio[i : i + win_ms]
+         peak_rms = max(peak_rms, audio_slice.rms / audio.max_possible_amplitude)
+     # Ensure peak_rms is positive
+     peak_rms = max(peak_rms, 0)
+     return peak_rms
+
+
+ def torch_tensor_to_pydub(wav: torch.Tensor, sample_rate: int):
+     bytesio = BytesIO()
+     encoder = AudioEncoder(wav, sample_rate=sample_rate)
+     encoder.to_file_like(bytesio, format="wav")
+     bytesio.seek(0)
+     audio = pydub.AudioSegment.from_file(bytesio, format="wav")
+     return audio
+
+
+ def detect_nonsilent(
+     path: Union[str, Tuple[torch.Tensor, int]],  # either a file path or pair wav & sr
+     min_sil_ms=250,
+     sil_threshold=-40,
+     threshold_mode="rel_to_max",
+ ):
+     TH_MODES = {"abs", "rel_to_max"}
+     SAMPLE_RATE = 24_000
+     assert threshold_mode in TH_MODES, f"{threshold_mode=} not in {TH_MODES}"
+     if isinstance(path, str):
+         audio = pydub.AudioSegment.from_file(path)
+     else:  # tuple of (tensor, sr)
+         audio = torch_tensor_to_pydub(path[0], path[1])
+     audio = audio.set_frame_rate(SAMPLE_RATE)
+     if threshold_mode == "rel_to_max":
+         peak_rms = get_peak_rms(audio)
+         sil_threshold = sil_threshold + pydub.utils.ratio_to_db(
+             peak_rms
+         )  # convert to absolute db threshold
+     elif threshold_mode == "abs":
+         pass
+     else:
+         raise NotImplementedError(f"Unknown threshold_mode '{threshold_mode}'")
+     spans = pydub.silence.detect_nonsilent(
+         audio, min_silence_len=min_sil_ms, silence_thresh=sil_threshold, seek_step=10
+     )
+     spans = [(round(start / 1000, 3), round(end / 1000, 3)) for start, end in spans]
+     return spans
+
+
+ def compute_iou_recall_precision(hyp_spans, ref_spans):
+     def span_length(span):
+         return span[1] - span[0]
+
+     def intersection_length(span1, span2):
+         return max(0, min(span1[1], span2[1]) - max(span1[0], span2[0]))
+
+     total_hyp_length = sum(span_length(span) for span in hyp_spans)
+     total_ref_length = sum(span_length(span) for span in ref_spans)
+     total_intersection = 0
+     for hyp_span in hyp_spans:
+         for ref_span in ref_spans:
+             total_intersection += intersection_length(hyp_span, ref_span)
+
+     union_spans = hyp_spans + ref_spans  # Combine both lists to compute union
+     union_length = sum(span_length(span) for span in union_spans) - total_intersection
+
+     iou = total_intersection / union_length if union_length > 0 else 0
+     recall = total_intersection / total_ref_length if total_ref_length > 0 else 0
+     precision = total_intersection / total_hyp_length if total_hyp_length > 0 else 0
+
+     return {"iou": iou, "recall": recall, "precision": precision}
+
+
+ class SoundActivityRanker(Ranker):
+     def __init__(self, config: SoundActivityRankerConfig):
+         if pydub is None:
+             raise ImportError(
+                 'Install reranking dependencies: `pip install "sam-audio[reranking]"`'
+             )
+         super().__init__()
+         self.config = config
+
+     @torch.inference_mode()
+     def forward(
+         self,
+         extracted_audio: list[torch.Tensor],
+         spans: list[list[list[float]]],
+         sample_rate: int = 48_000,
+         **kwargs,
+     ):
+         device = extracted_audio[0].device
+         scores = []
+         for wav, current_spans in zip(extracted_audio, spans, strict=True):
+             wav = wav.to(torch.float32).cpu()
+             # get non-silent spans
+             hyp_spans = detect_nonsilent(
+                 (wav, sample_rate),
+                 sil_threshold=self.config.sil_threshold,
+                 threshold_mode=self.config.threshold_mode,
+             )
+             timestamps = [[span[1], span[2]] for span in current_spans]
+             result = compute_iou_recall_precision(hyp_spans, timestamps)
+             scores.append(result[self.config.metric])
+
+         # convert to tensor
+         scores = torch.tensor(scores, device=device)
+         return scores
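`SoundActivityRanker` scores each candidate by how well its detected non-silent regions overlap the prompted time spans, via `compute_iou_recall_precision`. A small worked example on hand-made spans (assumes the reranking extras are installed so the module imports):

```python
# Worked example for compute_iou_recall_precision on hand-made spans (seconds).
# Assumes the reranking extras (pydub, torchcodec) are installed so the import works.
from sam_audio.ranking.sound_activity import compute_iou_recall_precision

hyp_spans = [(0.0, 2.0), (5.0, 6.0)]      # detected non-silent regions, 3.0 s total
ref_spans = [(0.0, 1.0), (5.0, 7.0)]      # prompted spans, 3.0 s total

# intersection = 1.0 + 1.0 = 2.0 s, union = 3.0 + 3.0 - 2.0 = 4.0 s
print(compute_iou_recall_precision(hyp_spans, ref_spans))
# {'iou': 0.5, 'recall': 0.666..., 'precision': 0.666...}
```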