# Uploaded to the Hugging Face Hub by JacobLinCool via huggingface_hub
# (commit 66e0fbe, verified).
import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
class BeatTrackingDataset(Dataset):
    """Frame-level beat / downbeat classification dataset.

    Every sample is a fixed-length waveform window centred on one analysis
    frame of one track, paired with a soft target:

    * ``1.0``  -- the frame an annotated beat (or downbeat) maps to,
    * ``0.25`` -- the frames immediately before and after it,
    * ``0.0``  -- randomly drawn frames that are neither.

    All audio is read once and cached in memory, so ``__getitem__`` only
    slices and pads arrays.
    """

    # Largest FFT window in MultiViewSpectrogram. The window returned by
    # __getitem__ is widened by this many samples so edge frames are complete.
    _MAX_FFT_WINDOW = 1488

    def __init__(
        self,
        hf_dataset,
        target_type="beats",
        sample_rate=16000,
        hop_length=160,
        context_frames=50,
        seed=None,
    ):
        """
        Args:
            hf_dataset: HuggingFace dataset object. Each item must provide
                ``item["audio"]["array"]`` and the annotation list selected by
                ``target_type`` (``item["beats"]`` or ``item["downbeats"]``),
                holding event times in seconds.
            target_type (str): "beats" or "downbeats". Determines which labels
                are treated as positive; any value other than "downbeats"
                falls back to "beats".
            sample_rate (int): Sample rate assumed for the audio; used to
                convert annotation times into frame indices.
            hop_length (int): Number of samples between consecutive frames.
            context_frames (int): Number of frames before and after the center
                frame. Total frames = 2 * context_frames + 1.
                Default 50 means 101 frames (~1s).
            seed (int, optional): Seed for negative-frame sampling. ``None``
                (the default) draws from NumPy's global random state.
        """
        self.sr = sample_rate
        self.hop_length = hop_length
        self.target_type = target_type
        self.context_frames = context_frames
        self.seed = seed
        # Window size in samples: the center frame +/- context frames PLUS the
        # largest FFT window, so spectrogram edges can be computed correctly.
        self.context_samples = (
            self.context_frames * 2 + 1
        ) * hop_length + self._MAX_FFT_WINDOW
        # Cache audio arrays in memory for fast access
        self.audio_cache = []
        # (track_idx, frame_idx, label) tuples, one per training sample
        self.indices = []
        self._prepare_indices(hf_dataset)

    def _prepare_indices(self, hf_dataset):
        """
        Prepares balanced indices and caches audio.
        Uses the same "Fuzzier" training examples strategy as the baseline.

        Appends one audio array per track (in dataset order) to
        ``self.audio_cache`` and the track's positive and negative
        ``(track_idx, frame_idx, label)`` samples to ``self.indices``.
        """
        print(f"Preparing dataset indices for target: {self.target_type}...")
        label_key = "downbeats" if self.target_type == "downbeats" else "beats"
        # Unseeded -> module-level RNG, so np.random.seed() still controls it.
        # Kept local (not stored on self) so the dataset remains picklable
        # when a DataLoader ships it to worker processes.
        rng = np.random if self.seed is None else np.random.RandomState(self.seed)

        for track_idx, item in tqdm(
            enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
        ):
            # Cache audio as a NumPy array (convert first if it is a tensor)
            audio = item["audio"]["array"]
            if hasattr(audio, "numpy"):
                audio = audio.numpy()
            audio = np.asarray(audio)
            self.audio_cache.append(audio)

            # Number of whole frames available in this track
            n_frames = len(audio) // self.hop_length

            gt_times = item[label_key]
            if hasattr(gt_times, "tolist"):
                gt_times = gt_times.tolist()
            # Seconds -> frame index; int() truncates toward zero. The set
            # collapses events that land on the same frame.
            gt_frames = {int(t * self.sr / self.hop_length) for t in gt_times}

            pos_frames = self._add_positive_examples(track_idx, gt_frames, n_frames)
            self._add_negative_examples(track_idx, pos_frames, n_frames, rng)

        print(
            f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
        )

    def _add_positive_examples(self, track_idx, gt_frames, n_frames):
        """Append center (1.0) and neighbour (0.25) samples; return their frames.

        Each candidate frame is bounds-checked on its own, so a neighbour that
        lies inside ``[0, n_frames)`` is kept even when its annotated frame
        falls just outside it. No deduplication is done: a frame adjacent to
        two annotations (or both an annotation and a neighbour) appears once
        per role.
        """
        pos_frames = set()
        for bf in gt_frames:
            for frame, weight in ((bf, 1.0), (bf - 1, 0.25), (bf + 1, 0.25)):
                if 0 <= frame < n_frames:
                    self.indices.append((track_idx, frame, weight))
                    pos_frames.add(frame)
        return pos_frames

    def _add_negative_examples(self, track_idx, pos_frames, n_frames, rng):
        """Append up to ``2 * len(pos_frames)`` negative (0.0) samples.

        Frames are drawn uniformly with replacement from ``[0, n_frames)`` and
        rejected if they are in ``pos_frames``. Drawing stops after
        ``5 * target`` attempts, so tracks that are almost entirely positive
        may get fewer negatives than the 2:1 target.
        """
        num_neg = len(pos_frames) * 2
        max_attempts = num_neg * 5
        count = 0
        attempts = 0
        # When pos_frames is empty, num_neg == 0 and randint is never called,
        # which also covers n_frames == 0.
        while count < num_neg and attempts < max_attempts:
            frame = rng.randint(0, n_frames)
            if frame not in pos_frames:
                self.indices.append((track_idx, frame, 0.0))
                count += 1
            attempts += 1

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(waveform, label)`` for sample ``idx``.

        ``waveform`` is a float32 tensor of exactly ``self.context_samples``
        samples, centred on ``frame_idx * hop_length`` and zero-padded where
        the window extends past either end of the track. ``label`` is a
        float32 tensor of shape ``(1,)``.
        """
        track_idx, frame_idx, label = self.indices[idx]
        # Fast lookup from cache
        audio = self.audio_cache[track_idx]
        audio_len = len(audio)

        # Derive `end` from `start` so the window is always exactly
        # context_samples long, including when context_samples is odd.
        start = frame_idx * self.hop_length - self.context_samples // 2
        end = start + self.context_samples

        pad_left = max(0, -start)
        pad_right = max(0, end - audio_len)
        chunk = audio[max(0, start) : min(audio_len, end)]
        if pad_left > 0 or pad_right > 0:
            chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")

        waveform = torch.tensor(chunk, dtype=torch.float32)
        return waveform, torch.tensor([label], dtype=torch.float32)