import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm


class BeatTrackingDataset(Dataset):
    def __init__(
        self,
        hf_dataset,
        target_type="beats",
        sample_rate=16000,
        hop_length=160,
        context_frames=50,
    ):
        """
        Args:
            hf_dataset: HuggingFace dataset object
            target_type (str): "beats" or "downbeats". Determines which labels are treated as positive.
            sample_rate (int): Audio sample rate in Hz.
            hop_length (int): Hop size in samples between consecutive frames.
            context_frames (int): Number of frames before and after the center frame.
                Total frames = 2 * context_frames + 1.
                Default 50 means 101 frames (~1 s).
        """
        self.sr = sample_rate
        self.hop_length = hop_length
        self.target_type = target_type
        self.context_frames = context_frames

        # Samples needed to cover the full context window, plus a fixed margin of
        # 1488 samples (presumably to give the front-end enough audio at the window edges).
        self.context_samples = (self.context_frames * 2 + 1) * hop_length + 1488

        self.audio_cache = []
        self.indices = []
        self._prepare_indices(hf_dataset)

    def _prepare_indices(self, hf_dataset):
        """
        Prepares balanced indices and caches audio.
        Uses the same "Fuzzier" training-example strategy as the baseline:
        annotated frames get label 1.0, their immediate neighbours get 0.25,
        and roughly twice as many random non-positive frames are added with label 0.0.
        """
        print(f"Preparing dataset indices for target: {self.target_type}...")

        for i, item in tqdm(
            enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
        ):
            # Cache the decoded waveform so __getitem__ can slice it directly.
            audio = item["audio"]["array"]
            if hasattr(audio, "numpy"):
                audio = audio.numpy()
            self.audio_cache.append(audio)

            audio_len = len(audio)
            n_frames = int(audio_len / self.hop_length)

            if self.target_type == "downbeats":
                gt_times = item["downbeats"]
            else:
                gt_times = item["beats"]

            if hasattr(gt_times, "tolist"):
                gt_times = gt_times.tolist()

            # Convert annotation times (seconds) to frame indices.
            gt_frames = set(int(t * self.sr / self.hop_length) for t in gt_times)

            # Positive examples: the annotated frame (1.0) plus its neighbours (0.25).
            pos_frames = set()
            for bf in gt_frames:
                if 0 <= bf < n_frames:
                    self.indices.append((i, bf, 1.0))
                    pos_frames.add(bf)

                if 0 <= bf - 1 < n_frames:
                    self.indices.append((i, bf - 1, 0.25))
                    pos_frames.add(bf - 1)
                if 0 <= bf + 1 < n_frames:
                    self.indices.append((i, bf + 1, 0.25))
                    pos_frames.add(bf + 1)

            # Negative examples: random frames outside the positive set,
            # targeting a 2:1 negative-to-positive ratio (capped by the attempt limit).
            num_pos = len(pos_frames)
            num_neg = num_pos * 2

            count = 0
            attempts = 0
            while count < num_neg and attempts < num_neg * 5:
                f = np.random.randint(0, n_frames)
                if f not in pos_frames:
                    self.indices.append((i, f, 0.0))
                    count += 1
                attempts += 1

        print(
            f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
        )

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        track_idx, frame_idx, label = self.indices[idx]

        audio = self.audio_cache[track_idx]
        audio_len = len(audio)

        # Center the waveform window on the target frame.
        center_sample = frame_idx * self.hop_length
        half_context = self.context_samples // 2

        start = center_sample - half_context
        end = center_sample + half_context

        # Zero-pad when the window extends past either end of the track.
        pad_left = max(0, -start)
        pad_right = max(0, end - audio_len)

        valid_start = max(0, start)
        valid_end = min(audio_len, end)

        chunk = audio[valid_start:valid_end]

        if pad_left > 0 or pad_right > 0:
            chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")

        waveform = torch.tensor(chunk, dtype=torch.float32)
        return waveform, torch.tensor([label], dtype=torch.float32)
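

# Minimal usage sketch: one way this dataset might be wired into a DataLoader,
# assuming a HuggingFace dataset exposing "audio", "beats", and "downbeats" columns.
# The dataset identifier below is a placeholder, not a name taken from this project.
if __name__ == "__main__":
    from datasets import load_dataset
    from torch.utils.data import DataLoader

    hf_ds = load_dataset("user/beat-annotated-audio", split="train")  # hypothetical dataset
    train_ds = BeatTrackingDataset(hf_ds, target_type="beats")

    # Every item has the same fixed length (context_samples), so the default
    # collate function stacks waveforms into (batch, context_samples) and
    # labels into (batch, 1) with values in {0.0, 0.25, 1.0}.
    loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    waveforms, labels = next(iter(loader))
    print(waveforms.shape, labels.shape)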