JacobLinCool committed on
Commit 66e0fbe · verified · 1 Parent(s): 975d5cb

Upload folder using huggingface_hub

Files changed (42)
  1. .gitignore +2 -1
  2. .ruff_cache/0.12.5/14293067367466839361 +0 -0
  3. README.md +68 -37
  4. exp/baseline1/__init__.py +0 -0
  5. exp/baseline1/data.py +128 -0
  6. exp/baseline1/eval.py +322 -0
  7. exp/baseline1/model.py +62 -0
  8. exp/baseline1/train.py +183 -0
  9. exp/baseline1/utils.py +53 -0
  10. exp/baseline2/__init__.py +0 -0
  11. exp/baseline2/data.py +137 -0
  12. exp/baseline2/eval.py +324 -0
  13. exp/baseline2/model.py +139 -0
  14. exp/baseline2/train.py +215 -0
  15. outputs/baseline1/beats/README.md +10 -0
  16. outputs/baseline1/beats/config.json +3 -0
  17. outputs/baseline1/beats/final/README.md +10 -0
  18. outputs/baseline1/beats/final/config.json +3 -0
  19. outputs/baseline1/beats/final/model.safetensors +3 -0
  20. outputs/baseline1/beats/logs/events.out.tfevents.1766351314.msiit232.1284330.0 +3 -0
  21. outputs/baseline1/beats/model.safetensors +3 -0
  22. outputs/baseline1/downbeats/README.md +10 -0
  23. outputs/baseline1/downbeats/config.json +3 -0
  24. outputs/baseline1/downbeats/final/README.md +10 -0
  25. outputs/baseline1/downbeats/final/config.json +3 -0
  26. outputs/baseline1/downbeats/final/model.safetensors +3 -0
  27. outputs/baseline1/downbeats/logs/events.out.tfevents.1766353075.msiit232.1284330.1 +3 -0
  28. outputs/baseline1/downbeats/model.safetensors +3 -0
  29. outputs/baseline2/beats/README.md +10 -0
  30. outputs/baseline2/beats/config.json +15 -0
  31. outputs/baseline2/beats/final/README.md +10 -0
  32. outputs/baseline2/beats/final/config.json +15 -0
  33. outputs/baseline2/beats/final/model.safetensors +3 -0
  34. outputs/baseline2/beats/logs/events.out.tfevents.1766356346.msiit232.1356098.0 +3 -0
  35. outputs/baseline2/beats/model.safetensors +3 -0
  36. outputs/baseline2/downbeats/README.md +10 -0
  37. outputs/baseline2/downbeats/config.json +15 -0
  38. outputs/baseline2/downbeats/final/README.md +10 -0
  39. outputs/baseline2/downbeats/final/config.json +15 -0
  40. outputs/baseline2/downbeats/final/model.safetensors +3 -0
  41. outputs/baseline2/downbeats/logs/events.out.tfevents.1766359276.msiit232.1356098.1 +3 -0
  42. outputs/baseline2/downbeats/model.safetensors +3 -0
.gitignore CHANGED
@@ -10,4 +10,5 @@ wheels/
.venv

outputs/*
- !outputs/baseline/
+ !outputs/baseline1/
+ !outputs/baseline2/
.ruff_cache/0.12.5/14293067367466839361 CHANGED
Binary files a/.ruff_cache/0.12.5/14293067367466839361 and b/.ruff_cache/0.12.5/14293067367466839361 differ
 
README.md CHANGED
@@ -24,7 +24,7 @@ The dataset is derived from Taiko no Tatsujin rhythm game charts, providing high

| Split | Tracks | Duration | Description |
|-------|--------|----------|-------------|
- | `train` | ~900 | 1-3 min each | Training data with beat/downbeat annotations |
+ | `train` | ~1000 | 1-3 min each | Training data with beat/downbeat annotations |
| `test` | ~100 | 1-3 min each | Held-out test set for final evaluation |

### Data Features
@@ -118,6 +118,17 @@ Downbeat Detection:
Combined Weighted F1: X.XXXX (average of beat and downbeat)
```

+ ### Benchmark Results
+
+ Results evaluated on 100 tracks from the test set:
+
+ | Model | Combined F1 | Beat F1 | Downbeat F1 | CMLt (Beat) | CMLt (Downbeat) |
+ |-------|-------------|---------|-------------|-------------|-----------------|
+ | **Baseline 1 (ODCNN)** | 0.0765 | 0.0861 | 0.0669 | 0.0731 | 0.0321 |
+ | **Baseline 2 (ResNet-SE)** | **0.2775** | **0.3292** | **0.2258** | **0.3287** | **0.1146** |
+
+ *Note: Baseline 2 (ResNet-SE) performs significantly better due to its larger context window and deeper architecture.*
+
---

## Quick Start
@@ -128,35 +139,40 @@ Combined Weighted F1: X.XXXX (average of beat and downbeat)
uv sync
```

- ### Train Baseline Model
+ ### Train Models

```bash
- # Train both beat and downbeat models
- uv run -m exp.baseline.train
+ # Train Baseline 1 (ODCNN)
+ uv run -m exp.baseline1.train

- # Train specific model only
- uv run -m exp.baseline.train --target beats
- uv run -m exp.baseline.train --target downbeats
+ # Train Baseline 2 (ResNet-SE)
+ uv run -m exp.baseline2.train
+
+ # Train specific target only (e.g. for Baseline 2)
+ uv run -m exp.baseline2.train --target beats
+ uv run -m exp.baseline2.train --target downbeats
```

### Run Evaluation

```bash
- # Basic evaluation
- uv run -m exp.baseline.eval
+ # Evaluation (replace baseline1 with baseline2 to evaluate the new model)
+ uv run -m exp.baseline1.eval

# Full evaluation with visualization and audio
- uv run -m exp.baseline.eval --visualize --synthesize --summary-plot
+ uv run -m exp.baseline1.eval --visualize --synthesize --summary-plot

# Evaluate on more samples with custom output directory
- uv run -m exp.baseline.eval --num-samples 50 --output-dir outputs/my_eval
+ uv run -m exp.baseline1.eval --num-samples 50 --output-dir outputs/eval_baseline1
```

### Evaluation Options

| Option | Description |
|--------|-------------|
- | `--model-dir DIR` | Model directory (default: `outputs/baseline`) |
+ | `--model-dir DIR` | Model directory (default: `outputs/baseline1`) |
| `--num-samples N` | Number of samples to evaluate (default: 20) |
| `--output-dir DIR` | Output directory (default: `outputs/eval`) |
| `--visualize` | Generate visualization plots for each track |
@@ -175,7 +191,7 @@ uv run -m exp.baseline.eval --num-samples 50 --output-dir outputs/my_eval
Generate plots comparing predicted vs ground truth beats:

```bash
- uv run -m exp.baseline.eval --visualize --viz-tracks 10
+ uv run -m exp.baseline1.eval --visualize --viz-tracks 10
```

Output: `outputs/eval/plots/track_XXX.png`
@@ -185,7 +201,7 @@ Output: `outputs/eval/plots/track_XXX.png`
Generate audio files with click sounds overlaid on the original music:

```bash
- uv run -m exp.baseline.eval --synthesize
+ uv run -m exp.baseline1.eval --synthesize
```

Output files in `outputs/eval/audio/`:
@@ -198,37 +214,48 @@ Output files in `outputs/eval/audio/`:
Generate bar charts summarizing F1 scores and continuity metrics:

```bash
- uv run -m exp.baseline.eval --summary-plot
+ uv run -m exp.baseline1.eval --summary-plot
```

Output: `outputs/eval/evaluation_summary.png`

---

- ## Baseline Model
+ ## Models

- The provided baseline implements the **Onset Detection CNN (ODCNN)** architecture:
+ ### Baseline 1: ODCNN

- ### Architecture
+ A decade-old baseline model: <https://ieeexplore.ieee.org/document/6854953>.
+
+ The original baseline implements the **Onset Detection CNN (ODCNN)** architecture:
+
+ #### Architecture
- **Input**: Multi-view mel spectrogram (3 window sizes: 23ms, 46ms, 93ms)
- **CNN Backbone**: 3 convolutional blocks with max pooling
- **Output**: Frame-level beat/downbeat probability
+ - **Inference**: ±7 frames context (±70ms)

- ### Training Details
+ ### Baseline 2: ResNet-SE

- - **Optimizer**: SGD with momentum (0.9)
- - **Learning Rate**: 0.05 with cosine annealing
- - **Loss**: Binary Cross-Entropy
- - **Epochs**: 50
- - **Batch Size**: 512
+ Inspired by Squeeze-and-Excitation networks: <https://arxiv.org/abs/1709.01507>.
+
+ A modernized architecture designed to capture longer temporal context:
+
+ #### Architecture
+ - **Input**: Mel spectrogram with larger context
+ - **Backbone**: ResNet with Squeeze-and-Excitation (SE) blocks
+ - **Context**: **±50 frames (~1s)** window
+ - **Features**: Deeper network (4 stages) with effective channel attention
+ - **Parameters**: ~400k (small and efficient)

- ### Inference Pipeline
+ ### Training Details

- 1. Compute multi-view mel spectrogram on GPU
- 2. Sliding window inference (±7 frames context = ±70ms)
- 3. Hamming window smoothing
- 4. Peak picking with threshold (0.5) and minimum distance (5 frames)
+ Both models use similar training loops:
+ - **Optimizer**: SGD (Baseline 1) / AdamW (Baseline 2)
+ - **Learning Rate**: Cosine annealing
+ - **Loss**: Binary Cross-Entropy
+ - **Epochs**: 50 (Baseline 1) / 3 (Baseline 2)
+ - **Batch Size**: 512 (Baseline 1) / 128 (Baseline 2)

---

@@ -237,21 +264,25 @@ The provided baseline implements the **Onset Detection CNN (ODCNN)** architectur
```
exp-onset/
├── exp/
- │   ├── baseline/            # Baseline model implementation
+ │   ├── baseline1/           # Baseline 1 (ODCNN)
│   │   ├── model.py           # ODCNN architecture
- │   │   ├── train.py         # Training script
- │   │   ├── eval.py          # Evaluation with viz/audio
- │   │   ├── data.py          # Dataset wrapper
- │   │   └── utils.py         # Spectrogram processing
+ │   │   ├── train.py
+ │   │   ├── eval.py
+ │   │   ├── data.py
+ │   │   └── utils.py
+ │   ├── baseline2/           # Baseline 2 (ResNet-SE)
+ │   │   ├── model.py         # ResNet-SE
+ │   │   ├── train.py
+ │   │   ├── eval.py
+ │   │   └── data.py
│   └── data/
│       ├── load.py            # Dataset loading & preprocessing
│       ├── eval.py            # Evaluation metrics (F1, CML, AML)
│       ├── audio.py           # Click track synthesis
│       └── viz.py             # Visualization utilities
├── outputs/
- │   ├── baseline/            # Trained models
- │   │   ├── beats/           # Beat detection model
- │   │   └── downbeats/       # Downbeat detection model
+ │   ├── baseline1/           # Trained models (Baseline 1)
+ │   ├── baseline2/           # Trained models (Baseline 2)
│   └── eval/                  # Evaluation outputs
│       ├── plots/             # Visualization images
│       ├── audio/             # Click track audio files
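Not part of the commit itself, but for orientation: a minimal sketch of how the uploaded Baseline 1 checkpoints could be loaded and run on a single file, reusing the helpers added in `exp/baseline1/eval.py` below. The audio path is a placeholder, and note that importing `exp.baseline1.eval` also imports the dataset loader from `exp/data/load.py`.

```python
# Sketch: load an uploaded checkpoint and track beats on one waveform.
# Assumes the repo layout from the README; "my_track.wav" is a placeholder.
import torch
import torchaudio

from exp.baseline1.model import ODCNN
from exp.baseline1.eval import get_activation_function, pick_peaks

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ODCNN.from_pretrained("outputs/baseline1/beats").to(device).eval()

waveform, sr = torchaudio.load("my_track.wav")                    # (channels, samples)
waveform = torchaudio.functional.resample(waveform, sr, 16000).mean(dim=0)

activations = get_activation_function(model, waveform, device)    # frame-level probabilities
beat_times = pick_peaks(activations)                               # beat times in seconds
print(beat_times[:10])
```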
exp/baseline1/__init__.py ADDED
File without changes
exp/baseline1/data.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from .utils import extract_context
6
+
7
+
8
+ class BeatTrackingDataset(Dataset):
9
+ def __init__(
10
+ self, hf_dataset, target_type="beats", sample_rate=16000, hop_length=160
11
+ ):
12
+ """
13
+ Args:
14
+ hf_dataset: HuggingFace dataset object
15
+ target_type (str): "beats" or "downbeats". Determines which labels are treated as positive.
16
+ """
17
+ self.sr = sample_rate
18
+ self.hop_length = hop_length
19
+ self.target_type = target_type
20
+
21
+ # Context window size in samples (7 frames = 70ms at 100fps)
22
+ self.context_frames = 7
23
+ self.context_samples = (self.context_frames * 2 + 1) * hop_length + max(
24
+ [368, 736, 1488]
25
+ ) # extra for FFT window
26
+
27
+ # Cache audio arrays in memory for fast access
28
+ self.audio_cache = []
29
+ self.indices = []
30
+ self._prepare_indices(hf_dataset)
31
+
32
+ def _prepare_indices(self, hf_dataset):
33
+ """
34
+ Prepares balanced indices and caches audio.
35
+ Paper Section 4.5: Uses "Fuzzier" training examples (neighbors weighted less).
36
+ """
37
+ print(f"Preparing dataset indices for target: {self.target_type}...")
38
+
39
+ for i, item in tqdm(
40
+ enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
41
+ ):
42
+ # Cache audio array (convert to numpy if tensor)
43
+ audio = item["audio"]["array"]
44
+ if hasattr(audio, "numpy"):
45
+ audio = audio.numpy()
46
+ self.audio_cache.append(audio)
47
+
48
+ # Calculate total frames available in audio
49
+ audio_len = len(audio)
50
+ n_frames = int(audio_len / self.hop_length)
51
+
52
+ # Select ground truth based on target_type
53
+ if self.target_type == "downbeats":
54
+ # Only downbeats are positives
55
+ gt_times = item["downbeats"]
56
+ else:
57
+ # All beats are positives (downbeats are also beats)
58
+ gt_times = item["beats"]
59
+
60
+ # Convert to list if tensor
61
+ if hasattr(gt_times, "tolist"):
62
+ gt_times = gt_times.tolist()
63
+
64
+ gt_frames = set([int(t * self.sr / self.hop_length) for t in gt_times])
65
+
66
+ # --- Positive Examples (with Fuzziness) ---
67
+ # "define a single frame before and after each annotated onset to be additional positive examples"
68
+ pos_frames = set()
69
+ for bf in gt_frames:
70
+ if 0 <= bf < n_frames:
71
+ self.indices.append((i, bf, 1.0)) # Center frame (Sharp onset)
72
+ pos_frames.add(bf)
73
+
74
+ # Neighbors weighted at 0.25
75
+ if 0 <= bf - 1 < n_frames:
76
+ self.indices.append((i, bf - 1, 0.25))
77
+ pos_frames.add(bf - 1)
78
+ if 0 <= bf + 1 < n_frames:
79
+ self.indices.append((i, bf + 1, 0.25))
80
+ pos_frames.add(bf + 1)
81
+
82
+ # --- Negative Examples ---
83
+ # Paper uses "all others as negative", but we balance 2:1 for stable SGD.
84
+ num_pos = len(pos_frames)
85
+ num_neg = num_pos * 2
86
+
87
+ count = 0
88
+ attempts = 0
89
+ while count < num_neg and attempts < num_neg * 5:
90
+ f = np.random.randint(0, n_frames)
91
+ if f not in pos_frames:
92
+ self.indices.append((i, f, 0.0))
93
+ count += 1
94
+ attempts += 1
95
+
96
+ print(
97
+ f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
98
+ )
99
+
100
+ def __len__(self):
101
+ return len(self.indices)
102
+
103
+ def __getitem__(self, idx):
104
+ track_idx, frame_idx, label = self.indices[idx]
105
+
106
+ # Fast lookup from cache
107
+ audio = self.audio_cache[track_idx]
108
+ audio_len = len(audio)
109
+
110
+ # Calculate sample range for context window
111
+ center_sample = frame_idx * self.hop_length
112
+ half_context = self.context_samples // 2
113
+ start = center_sample - half_context
114
+ end = center_sample + half_context
115
+
116
+ # Handle padding if needed
117
+ pad_left = max(0, -start)
118
+ pad_right = max(0, end - audio_len)
119
+ start = max(0, start)
120
+ end = min(audio_len, end)
121
+
122
+ # Extract audio chunk
123
+ chunk = audio[start:end]
124
+ if pad_left > 0 or pad_right > 0:
125
+ chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")
126
+
127
+ waveform = torch.tensor(chunk, dtype=torch.float32)
128
+ return waveform, torch.tensor([label], dtype=torch.float32)
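For readers of the diff: a small sketch (not in the commit) of the frame/label scheme `_prepare_indices` implements above — beat times are quantized to 100 fps frames (sr=16000, hop=160), the exact frame becomes a positive with weight 1.0, and its two neighbours get the "fuzzy" weight 0.25.

```python
# Toy illustration of the quantization and soft labels used above.
sr, hop = 16000, 160                      # 100 frames per second
beats = [0.50, 1.02, 1.51]                # toy annotations (seconds)

labels = {}
for t in beats:
    f = int(t * sr / hop)                 # beat time -> frame index
    labels[f] = 1.0                       # sharp onset frame
    labels.setdefault(f - 1, 0.25)        # fuzzy neighbours
    labels.setdefault(f + 1, 0.25)

print(sorted(labels.items()))
# [(49, 0.25), (50, 1.0), (51, 0.25), (101, 0.25), (102, 1.0), (103, 0.25),
#  (150, 0.25), (151, 1.0), (152, 0.25)]
```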
exp/baseline1/eval.py ADDED
@@ -0,0 +1,322 @@
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from scipy.signal import find_peaks
5
+ import argparse
6
+ import os
7
+
8
+ from .model import ODCNN
9
+ from .utils import MultiViewSpectrogram
10
+ from ..data.load import ds
11
+ from ..data.eval import evaluate_all, format_results
12
+
13
+
14
+ def get_activation_function(model, waveform, device):
15
+ """
16
+ Computes probability curve over time.
17
+ """
18
+ processor = MultiViewSpectrogram().to(device)
19
+ waveform = waveform.unsqueeze(0).to(device)
20
+
21
+ with torch.no_grad():
22
+ spec = processor(waveform)
23
+
24
+ # Normalize
25
+ mean = spec.mean(dim=(2, 3), keepdim=True)
26
+ std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
27
+ spec = (spec - mean) / std
28
+
29
+ # Batchify with sliding window
30
+ spec = torch.nn.functional.pad(spec, (7, 7)) # Pad time
31
+ windows = spec.unfold(3, 15, 1) # (1, 3, 80, Time, 15)
32
+ windows = windows.permute(3, 0, 1, 2, 4).squeeze(1) # (Time, 3, 80, 15)
33
+
34
+ # Inference
35
+ activations = []
36
+ batch_size = 512
37
+ for i in range(0, len(windows), batch_size):
38
+ batch = windows[i : i + batch_size]
39
+ out = model(batch)
40
+ activations.append(out.cpu().numpy())
41
+
42
+ return np.concatenate(activations).flatten()
43
+
44
+
45
+ def pick_peaks(activations, hop_length=160, sr=16000):
46
+ """
47
+ Smooth with Hamming window and report local maxima.
48
+ """
49
+ # Smoothing
50
+ window = np.hamming(5)
51
+ window /= window.sum()
52
+ smoothed = np.convolve(activations, window, mode="same")
53
+
54
+ # Peak Picking
55
+ peaks, _ = find_peaks(smoothed, height=0.5, distance=5)
56
+
57
+ timestamps = peaks * hop_length / sr
58
+ return timestamps.tolist()
59
+
60
+
61
+ def visualize_track(
62
+ audio: np.ndarray,
63
+ sr: int,
64
+ pred_beats: list[float],
65
+ pred_downbeats: list[float],
66
+ gt_beats: list[float],
67
+ gt_downbeats: list[float],
68
+ output_dir: str,
69
+ track_idx: int,
70
+ time_range: tuple[float, float] | None = None,
71
+ ):
72
+ """
73
+ Create and save visualizations for a single track.
74
+ """
75
+ from ..data.viz import plot_waveform_with_beats, save_figure
76
+
77
+ os.makedirs(output_dir, exist_ok=True)
78
+
79
+ # Full waveform plot
80
+ fig = plot_waveform_with_beats(
81
+ audio,
82
+ sr,
83
+ pred_beats,
84
+ gt_beats,
85
+ pred_downbeats,
86
+ gt_downbeats,
87
+ title=f"Track {track_idx}: Beat Comparison",
88
+ time_range=time_range,
89
+ )
90
+ save_figure(fig, os.path.join(output_dir, f"track_{track_idx:03d}.png"))
91
+
92
+
93
+ def synthesize_audio(
94
+ audio: np.ndarray,
95
+ sr: int,
96
+ pred_beats: list[float],
97
+ pred_downbeats: list[float],
98
+ gt_beats: list[float],
99
+ gt_downbeats: list[float],
100
+ output_dir: str,
101
+ track_idx: int,
102
+ click_volume: float = 0.5,
103
+ ):
104
+ """
105
+ Create and save audio files with click tracks for a single track.
106
+ """
107
+ from ..data.audio import create_comparison_audio, save_audio
108
+
109
+ os.makedirs(output_dir, exist_ok=True)
110
+
111
+ # Create comparison audio
112
+ audio_pred, audio_gt, audio_both = create_comparison_audio(
113
+ audio,
114
+ pred_beats,
115
+ pred_downbeats,
116
+ gt_beats,
117
+ gt_downbeats,
118
+ sr=sr,
119
+ click_volume=click_volume,
120
+ )
121
+
122
+ # Save audio files
123
+ save_audio(
124
+ audio_pred, os.path.join(output_dir, f"track_{track_idx:03d}_pred.wav"), sr
125
+ )
126
+ save_audio(audio_gt, os.path.join(output_dir, f"track_{track_idx:03d}_gt.wav"), sr)
127
+ save_audio(
128
+ audio_both, os.path.join(output_dir, f"track_{track_idx:03d}_both.wav"), sr
129
+ )
130
+
131
+
132
+ def main():
133
+ parser = argparse.ArgumentParser(
134
+ description="Evaluate beat tracking models with visualization and audio synthesis"
135
+ )
136
+ parser.add_argument(
137
+ "--model-dir",
138
+ type=str,
139
+ default="outputs/baseline1",
140
+ help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
141
+ )
142
+ parser.add_argument(
143
+ "--num-samples",
144
+ type=int,
145
+ default=116,
146
+ help="Number of samples to evaluate",
147
+ )
148
+ parser.add_argument(
149
+ "--output-dir",
150
+ type=str,
151
+ default="outputs/eval_baseline1",
152
+ help="Directory to save visualizations and audio",
153
+ )
154
+ parser.add_argument(
155
+ "--visualize",
156
+ action="store_true",
157
+ help="Generate visualization plots for each track",
158
+ )
159
+ parser.add_argument(
160
+ "--synthesize",
161
+ action="store_true",
162
+ help="Generate audio files with click tracks",
163
+ )
164
+ parser.add_argument(
165
+ "--viz-tracks",
166
+ type=int,
167
+ default=5,
168
+ help="Number of tracks to visualize/synthesize (default: 5)",
169
+ )
170
+ parser.add_argument(
171
+ "--time-range",
172
+ type=float,
173
+ nargs=2,
174
+ default=None,
175
+ metavar=("START", "END"),
176
+ help="Time range for visualization in seconds (default: full track)",
177
+ )
178
+ parser.add_argument(
179
+ "--click-volume",
180
+ type=float,
181
+ default=0.5,
182
+ help="Volume of click sounds relative to audio (0.0 to 1.0)",
183
+ )
184
+ parser.add_argument(
185
+ "--summary-plot",
186
+ action="store_true",
187
+ help="Generate summary evaluation plot",
188
+ )
189
+ args = parser.parse_args()
190
+
191
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
192
+
193
+ # Load BOTH models using from_pretrained
194
+ beat_model = None
195
+ downbeat_model = None
196
+
197
+ has_beats = False
198
+ has_downbeats = False
199
+
200
+ beats_dir = os.path.join(args.model_dir, "beats")
201
+ downbeats_dir = os.path.join(args.model_dir, "downbeats")
202
+
203
+ if os.path.exists(os.path.join(beats_dir, "model.safetensors")):
204
+ beat_model = ODCNN.from_pretrained(beats_dir).to(DEVICE)
205
+ beat_model.eval()
206
+ has_beats = True
207
+ print(f"Loaded Beat Model from {beats_dir}")
208
+ else:
209
+ print(f"Warning: No beat model found in {beats_dir}")
210
+
211
+ if os.path.exists(os.path.join(downbeats_dir, "model.safetensors")):
212
+ downbeat_model = ODCNN.from_pretrained(downbeats_dir).to(DEVICE)
213
+ downbeat_model.eval()
214
+ has_downbeats = True
215
+ print(f"Loaded Downbeat Model from {downbeats_dir}")
216
+ else:
217
+ print(f"Warning: No downbeat model found in {downbeats_dir}")
218
+
219
+ if not has_beats and not has_downbeats:
220
+ print("No models found. Please run training first.")
221
+ return
222
+
223
+ predictions = []
224
+ ground_truths = []
225
+ audio_data = [] # Store audio for visualization/synthesis
226
+
227
+ # Eval on specified number of tracks
228
+ test_set = ds["train"].select(range(args.num_samples))
229
+
230
+ print("Running evaluation...")
231
+ for i, item in enumerate(tqdm(test_set)):
232
+ waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
233
+ waveform_device = waveform.to(DEVICE)
234
+
235
+ pred_entry = {"beats": [], "downbeats": []}
236
+
237
+ # 1. Predict Beats
238
+ if has_beats:
239
+ act_b = get_activation_function(beat_model, waveform_device, DEVICE)
240
+ pred_entry["beats"] = pick_peaks(act_b)
241
+
242
+ # 2. Predict Downbeats
243
+ if has_downbeats:
244
+ act_d = get_activation_function(downbeat_model, waveform_device, DEVICE)
245
+ pred_entry["downbeats"] = pick_peaks(act_d)
246
+
247
+ predictions.append(pred_entry)
248
+ ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})
249
+
250
+ # Store audio for later visualization/synthesis
251
+ if args.visualize or args.synthesize:
252
+ if i < args.viz_tracks:
253
+ audio_data.append(
254
+ {
255
+ "audio": waveform.numpy(),
256
+ "sr": item["audio"]["sampling_rate"],
257
+ "pred": pred_entry,
258
+ "gt": ground_truths[-1],
259
+ }
260
+ )
261
+
262
+ # Run evaluation
263
+ results = evaluate_all(predictions, ground_truths)
264
+ print(format_results(results))
265
+
266
+ # Create output directory
267
+ if args.visualize or args.synthesize or args.summary_plot:
268
+ os.makedirs(args.output_dir, exist_ok=True)
269
+
270
+ # Generate visualizations
271
+ if args.visualize:
272
+ print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
273
+ viz_dir = os.path.join(args.output_dir, "plots")
274
+ for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
275
+ time_range = tuple(args.time_range) if args.time_range else None
276
+ visualize_track(
277
+ data["audio"],
278
+ data["sr"],
279
+ data["pred"]["beats"],
280
+ data["pred"]["downbeats"],
281
+ data["gt"]["beats"],
282
+ data["gt"]["downbeats"],
283
+ viz_dir,
284
+ i,
285
+ time_range=time_range,
286
+ )
287
+ print(f"Saved visualizations to {viz_dir}")
288
+
289
+ # Generate audio with clicks
290
+ if args.synthesize:
291
+ print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
292
+ audio_dir = os.path.join(args.output_dir, "audio")
293
+ for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
294
+ synthesize_audio(
295
+ data["audio"],
296
+ data["sr"],
297
+ data["pred"]["beats"],
298
+ data["pred"]["downbeats"],
299
+ data["gt"]["beats"],
300
+ data["gt"]["downbeats"],
301
+ audio_dir,
302
+ i,
303
+ click_volume=args.click_volume,
304
+ )
305
+ print(f"Saved audio files to {audio_dir}")
306
+ print(" *_pred.wav - Original audio with predicted beat clicks")
307
+ print(" *_gt.wav - Original audio with ground truth beat clicks")
308
+ print(" *_both.wav - Original audio with both predicted and GT clicks")
309
+
310
+ # Generate summary plot
311
+ if args.summary_plot:
312
+ from ..data.viz import plot_evaluation_summary, save_figure
313
+
314
+ print("\nGenerating summary plot...")
315
+ fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
316
+ summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
317
+ save_figure(fig, summary_path)
318
+ print(f"Saved summary plot to {summary_path}")
319
+
320
+
321
+ if __name__ == "__main__":
322
+ main()
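A self-contained sketch (not in the commit) of the peak-picking step defined above, using the same parameters as `pick_peaks`: a normalized 5-tap Hamming smoother, a 0.5 threshold, and a 5-frame (50 ms) minimum peak distance.

```python
# Synthetic activation curve with broad beat bumps every 0.5 s.
import numpy as np
from scipy.signal import find_peaks

act = np.zeros(300)
for f in (50, 100, 150, 200):
    act[f - 1 : f + 2] = [0.6, 1.0, 0.6]

window = np.hamming(5)
window /= window.sum()
smoothed = np.convolve(act, window, mode="same")     # same smoothing as pick_peaks
peaks, _ = find_peaks(smoothed, height=0.5, distance=5)

print(peaks * 160 / 16000)                           # [0.5 1.  1.5 2. ]
```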
exp/baseline1/model.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+
5
+
6
+ class ODCNN(nn.Module, PyTorchModelHubMixin):
7
+ def __init__(self, dropout_rate=0.5):
8
+ super().__init__()
9
+
10
+ # Input 3 channels, 80 bands
11
+ # Conv 1: 7x3 filters -> 10 maps
12
+ self.conv1 = nn.Conv2d(3, 10, kernel_size=(3, 7))
13
+ self.relu1 = nn.ReLU() # ReLU improvement
14
+ self.pool1 = nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))
15
+
16
+ # Conv 2: 3x3 filters -> 20 maps
17
+ self.conv2 = nn.Conv2d(10, 20, kernel_size=(3, 3))
18
+ self.relu2 = nn.ReLU()
19
+ self.pool2 = nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))
20
+
21
+ # Flatten size calculation based on architecture
22
+ # (20 feature maps * 8 freq bands * 7 time frames)
23
+ self.flatten_size = 20 * 8 * 7
24
+
25
+ # Dropout on FC inputs
26
+ self.dropout = nn.Dropout(p=dropout_rate)
27
+
28
+ # 256 Hidden Units
29
+ self.fc1 = nn.Linear(self.flatten_size, 256)
30
+ self.relu_fc = nn.ReLU()
31
+
32
+ # Output Unit
33
+ self.fc2 = nn.Linear(256, 1)
34
+ self.sigmoid = nn.Sigmoid()
35
+
36
+ def forward(self, x):
37
+ x = self.conv1(x)
38
+ x = self.relu1(x)
39
+ x = self.pool1(x)
40
+
41
+ x = self.conv2(x)
42
+ x = self.relu2(x)
43
+ x = self.pool2(x)
44
+
45
+ x = x.view(x.size(0), -1)
46
+
47
+ x = self.dropout(x)
48
+ x = self.fc1(x)
49
+ x = self.relu_fc(x)
50
+
51
+ x = self.dropout(x)
52
+ x = self.fc2(x)
53
+ x = self.sigmoid(x)
54
+
55
+ return x
56
+
57
+
58
+ if __name__ == "__main__":
59
+ from torchinfo import summary
60
+
61
+ model = ODCNN()
62
+ summary(model, (1, 3, 80, 15))
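As a sanity check on `flatten_size` above, a short sketch (not in the commit) that walks the (B, 3, 80, 15) input used in `train.py` through the layers:

```python
# Shape walk-through for ODCNN.flatten_size:
#   conv1 (3x7): 80 -> 78 mel bins, 15 -> 9 frames   -> (B, 10, 78, 9)
#   pool1 (3x1): 78 -> 26                            -> (B, 10, 26, 9)
#   conv2 (3x3): 26 -> 24, 9 -> 7                    -> (B, 20, 24, 7)
#   pool2 (3x1): 24 -> 8                             -> (B, 20, 8, 7)  => 20*8*7 = 1120
import torch
from exp.baseline1.model import ODCNN

with torch.no_grad():
    out = ODCNN().eval()(torch.randn(2, 3, 80, 15))
print(out.shape)  # torch.Size([2, 1]) -- one beat probability per example
```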
exp/baseline1/train.py ADDED
@@ -0,0 +1,183 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ from tqdm import tqdm
7
+ import argparse
8
+ import os
9
+
10
+ from .model import ODCNN
11
+ from .data import BeatTrackingDataset
12
+ from .utils import MultiViewSpectrogram
13
+ from ..data.load import ds
14
+
15
+
16
+ def train(target_type: str, output_dir: str):
17
+ # Note: Paper uses SGD with Momentum, Dropout, and ReLU
18
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
+ BATCH_SIZE = 512
20
+ EPOCHS = 50
21
+ LR = 0.05
22
+ MOMENTUM = 0.9
23
+ NUM_WORKERS = 4
24
+
25
+ print(f"--- Training Model for target: {target_type} ---")
26
+ print(f"Output directory: {output_dir}")
27
+
28
+ # Create output directory
29
+ os.makedirs(output_dir, exist_ok=True)
30
+
31
+ # TensorBoard writer
32
+ writer = SummaryWriter(log_dir=os.path.join(output_dir, "logs"))
33
+
34
+ # Data - use existing train/test splits
35
+ train_dataset = BeatTrackingDataset(ds["train"], target_type=target_type)
36
+ val_dataset = BeatTrackingDataset(ds["test"], target_type=target_type)
37
+
38
+ train_loader = DataLoader(
39
+ train_dataset,
40
+ batch_size=BATCH_SIZE,
41
+ shuffle=True,
42
+ num_workers=NUM_WORKERS,
43
+ pin_memory=True,
44
+ prefetch_factor=4,
45
+ persistent_workers=True,
46
+ )
47
+ val_loader = DataLoader(
48
+ val_dataset,
49
+ batch_size=BATCH_SIZE,
50
+ shuffle=False,
51
+ num_workers=NUM_WORKERS,
52
+ pin_memory=True,
53
+ prefetch_factor=4,
54
+ persistent_workers=True,
55
+ )
56
+
57
+ print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
58
+
59
+ # Model
60
+ model = ODCNN(dropout_rate=0.5).to(DEVICE)
61
+
62
+ # GPU Spectrogram Preprocessor
63
+ preprocessor = MultiViewSpectrogram(sample_rate=16000, hop_length=160).to(DEVICE)
64
+
65
+ # Optimizer
66
+ optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
67
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
68
+ criterion = nn.BCELoss() # Binary Cross Entropy
69
+
70
+ best_val_loss = float("inf")
71
+ global_step = 0
72
+
73
+ for epoch in range(EPOCHS):
74
+ # Training
75
+ model.train()
76
+ total_train_loss = 0
77
+ for waveform, y in tqdm(
78
+ train_loader,
79
+ desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Train",
80
+ leave=False,
81
+ ):
82
+ waveform, y = waveform.to(DEVICE), y.to(DEVICE)
83
+
84
+ # Compute spectrogram on GPU
85
+ with torch.no_grad():
86
+ spec = preprocessor(waveform) # (B, 3, 80, T)
87
+ # Normalize
88
+ mean = spec.mean(dim=(2, 3), keepdim=True)
89
+ std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
90
+ spec = (spec - mean) / std
91
+ # Extract center context (T should be ~15 frames)
92
+ x = spec[:, :, :, 7:22] # center 15 frames
93
+
94
+ optimizer.zero_grad()
95
+ output = model(x)
96
+ loss = criterion(output, y)
97
+ loss.backward()
98
+ optimizer.step()
99
+
100
+ total_train_loss += loss.item()
101
+ global_step += 1
102
+
103
+ # Log batch loss
104
+ writer.add_scalar("train/batch_loss", loss.item(), global_step)
105
+
106
+ avg_train_loss = total_train_loss / len(train_loader)
107
+
108
+ # Validation
109
+ model.eval()
110
+ total_val_loss = 0
111
+ with torch.no_grad():
112
+ for waveform, y in tqdm(
113
+ val_loader,
114
+ desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Val",
115
+ leave=False,
116
+ ):
117
+ waveform, y = waveform.to(DEVICE), y.to(DEVICE)
118
+
119
+ # Compute spectrogram on GPU
120
+ spec = preprocessor(waveform) # (B, 3, 80, T)
121
+ # Normalize
122
+ mean = spec.mean(dim=(2, 3), keepdim=True)
123
+ std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
124
+ spec = (spec - mean) / std
125
+ # Extract center context
126
+ x = spec[:, :, :, 7:22]
127
+
128
+ output = model(x)
129
+ loss = criterion(output, y)
130
+ total_val_loss += loss.item()
131
+
132
+ avg_val_loss = total_val_loss / len(val_loader)
133
+
134
+ # Log epoch metrics
135
+ writer.add_scalar("train/epoch_loss", avg_train_loss, epoch)
136
+ writer.add_scalar("val/loss", avg_val_loss, epoch)
137
+ writer.add_scalar("train/learning_rate", scheduler.get_last_lr()[0], epoch)
138
+
139
+ # Step the scheduler
140
+ scheduler.step()
141
+
142
+ print(
143
+ f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} - "
144
+ f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
145
+ )
146
+
147
+ # Save best model
148
+ if avg_val_loss < best_val_loss:
149
+ best_val_loss = avg_val_loss
150
+ model.save_pretrained(output_dir)
151
+ print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")
152
+
153
+ writer.close()
154
+
155
+ # Save final model
156
+ final_dir = os.path.join(output_dir, "final")
157
+ model.save_pretrained(final_dir)
158
+ print(f"Saved final model to {final_dir}")
159
+
160
+
161
+ if __name__ == "__main__":
162
+ parser = argparse.ArgumentParser()
163
+ parser.add_argument(
164
+ "--target",
165
+ type=str,
166
+ choices=["beats", "downbeats"],
167
+ default=None,
168
+ help="Train a model for 'beats' or 'downbeats'. If not specified, trains both.",
169
+ )
170
+ parser.add_argument(
171
+ "--output-dir",
172
+ type=str,
173
+ default="outputs/baseline1",
174
+ help="Directory to save model and logs",
175
+ )
176
+ args = parser.parse_args()
177
+
178
+ # Determine which targets to train
179
+ targets = [args.target] if args.target else ["beats", "downbeats"]
180
+
181
+ for target in targets:
182
+ output_dir = os.path.join(args.output_dir, target)
183
+ train(target, output_dir)
exp/baseline1/utils.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchaudio.transforms as T
4
+ import numpy as np
5
+
6
+
7
+ class MultiViewSpectrogram(nn.Module):
8
+ def __init__(self, sample_rate=16000, n_mels=80, hop_length=160):
9
+ super().__init__()
10
+ # Window sizes: 23ms, 46ms, 93ms
11
+ self.win_lengths = [368, 736, 1488]
12
+ self.transforms = nn.ModuleList()
13
+
14
+ for win_len in self.win_lengths:
15
+ n_fft = 2 ** int(np.ceil(np.log2(win_len)))
16
+ mel = T.MelSpectrogram(
17
+ sample_rate=sample_rate,
18
+ n_fft=n_fft,
19
+ win_length=win_len,
20
+ hop_length=hop_length,
21
+ f_min=27.5,
22
+ f_max=16000.0,
23
+ n_mels=n_mels,
24
+ power=1.0,
25
+ center=True,
26
+ )
27
+ self.transforms.append(mel)
28
+
29
+ def forward(self, waveform):
30
+ specs = []
31
+ for transform in self.transforms:
32
+ # Scale magnitudes logarithmically
33
+ s = transform(waveform)
34
+ s = torch.log(s + 1e-9)
35
+ specs.append(s)
36
+ return torch.stack(specs, dim=1)
37
+
38
+
39
+ def extract_context(spec, center_frame, context=7):
40
+ # Context of +/- 70ms (7 frames)
41
+ channels, n_mels, total_time = spec.shape
42
+ start = center_frame - context
43
+ end = center_frame + context + 1
44
+
45
+ pad_left = max(0, -start)
46
+ pad_right = max(0, end - total_time)
47
+
48
+ if pad_left > 0 or pad_right > 0:
49
+ spec = torch.nn.functional.pad(spec, (pad_left, pad_right))
50
+ start += pad_left
51
+ end += pad_left
52
+
53
+ return spec[:, :, start:end]
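A quick shape check (not in the commit) for `MultiViewSpectrogram`: the three mel views share the same hop length, so they stack into a single (B, 3, 80, T) tensor, with T = num_samples // 160 + 1 because `center=True`.

```python
import torch
from exp.baseline1.utils import MultiViewSpectrogram

proc = MultiViewSpectrogram(sample_rate=16000, n_mels=80, hop_length=160)
wave = torch.randn(2, 16000)          # 1 s of audio, batch of 2
spec = proc(wave)
print(spec.shape)                      # torch.Size([2, 3, 80, 101])
```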
exp/baseline2/__init__.py ADDED
File without changes
exp/baseline2/data.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+
6
+
7
+ class BeatTrackingDataset(Dataset):
8
+ def __init__(
9
+ self,
10
+ hf_dataset,
11
+ target_type="beats",
12
+ sample_rate=16000,
13
+ hop_length=160,
14
+ context_frames=50,
15
+ ):
16
+ """
17
+ Args:
18
+ hf_dataset: HuggingFace dataset object
19
+ target_type (str): "beats" or "downbeats". Determines which labels are treated as positive.
20
+ context_frames (int): Number of frames before and after the center frame.
21
+ Total frames = 2 * context_frames + 1.
22
+ Default 50 means 101 frames (~1s).
23
+ """
24
+ self.sr = sample_rate
25
+ self.hop_length = hop_length
26
+ self.target_type = target_type
27
+
28
+ self.context_frames = context_frames
29
+ # Context window size in samples
30
+ # We need enough samples for the center frame +/- context frames
31
+ # PLUS the window size of the largest FFT to compute the edges correctly.
32
+ # Largest window in MultiViewSpectrogram is 1488.
33
+ self.context_samples = (self.context_frames * 2 + 1) * hop_length + 1488
34
+
35
+ # Cache audio arrays in memory for fast access
36
+ self.audio_cache = []
37
+ self.indices = []
38
+ self._prepare_indices(hf_dataset)
39
+
40
+ def _prepare_indices(self, hf_dataset):
41
+ """
42
+ Prepares balanced indices and caches audio.
43
+ Uses the same "Fuzzier" training examples strategy as the baseline.
44
+ """
45
+ print(f"Preparing dataset indices for target: {self.target_type}...")
46
+
47
+ for i, item in tqdm(
48
+ enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
49
+ ):
50
+ # Cache audio array (convert to numpy if tensor)
51
+ audio = item["audio"]["array"]
52
+ if hasattr(audio, "numpy"):
53
+ audio = audio.numpy()
54
+ self.audio_cache.append(audio)
55
+
56
+ # Calculate total frames available in audio
57
+ audio_len = len(audio)
58
+ n_frames = int(audio_len / self.hop_length)
59
+
60
+ # Select ground truth based on target_type
61
+ if self.target_type == "downbeats":
62
+ gt_times = item["downbeats"]
63
+ else:
64
+ gt_times = item["beats"]
65
+
66
+ # Convert to list if tensor
67
+ if hasattr(gt_times, "tolist"):
68
+ gt_times = gt_times.tolist()
69
+
70
+ gt_frames = set([int(t * self.sr / self.hop_length) for t in gt_times])
71
+
72
+ # --- Positive Examples (with Fuzziness) ---
73
+ pos_frames = set()
74
+ for bf in gt_frames:
75
+ if 0 <= bf < n_frames:
76
+ self.indices.append((i, bf, 1.0)) # Center frame
77
+ pos_frames.add(bf)
78
+
79
+ # Neighbors weighted at 0.25
80
+ if 0 <= bf - 1 < n_frames:
81
+ self.indices.append((i, bf - 1, 0.25))
82
+ pos_frames.add(bf - 1)
83
+ if 0 <= bf + 1 < n_frames:
84
+ self.indices.append((i, bf + 1, 0.25))
85
+ pos_frames.add(bf + 1)
86
+
87
+ # --- Negative Examples ---
88
+ # Balance 2:1
89
+ num_pos = len(pos_frames)
90
+ num_neg = num_pos * 2
91
+
92
+ count = 0
93
+ attempts = 0
94
+ while count < num_neg and attempts < num_neg * 5:
95
+ f = np.random.randint(0, n_frames)
96
+ if f not in pos_frames:
97
+ self.indices.append((i, f, 0.0))
98
+ count += 1
99
+ attempts += 1
100
+
101
+ print(
102
+ f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
103
+ )
104
+
105
+ def __len__(self):
106
+ return len(self.indices)
107
+
108
+ def __getitem__(self, idx):
109
+ track_idx, frame_idx, label = self.indices[idx]
110
+
111
+ # Fast lookup from cache
112
+ audio = self.audio_cache[track_idx]
113
+ audio_len = len(audio)
114
+
115
+ # Calculate sample range for context window
116
+ center_sample = frame_idx * self.hop_length
117
+ half_context = self.context_samples // 2
118
+
119
+ # We want the window centered around center_sample
120
+ start = center_sample - half_context
121
+ end = center_sample + half_context
122
+
123
+ # Handle padding if needed
124
+ pad_left = max(0, -start)
125
+ pad_right = max(0, end - audio_len)
126
+
127
+ valid_start = max(0, start)
128
+ valid_end = min(audio_len, end)
129
+
130
+ # Extract audio chunk
131
+ chunk = audio[valid_start:valid_end]
132
+
133
+ if pad_left > 0 or pad_right > 0:
134
+ chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")
135
+
136
+ waveform = torch.tensor(chunk, dtype=torch.float32)
137
+ return waveform, torch.tensor([label], dtype=torch.float32)
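For reference, the context-window arithmetic above works out as follows (a sketch, not in the commit): ±50 frames at 100 fps gives a 101-frame label window, padded by the largest STFT window (1488 samples) so the edge frames can still be computed.

```python
hop, context_frames, max_win = 160, 50, 1488
context_samples = (2 * context_frames + 1) * hop + max_win
print(context_samples, context_samples / 16000)   # 17648 samples, ~1.10 s per example
```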
exp/baseline2/eval.py ADDED
@@ -0,0 +1,324 @@
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from scipy.signal import find_peaks
5
+ import argparse
6
+ import os
7
+
8
+ from .model import ResNet
9
+ from ..baseline1.utils import MultiViewSpectrogram
10
+ from ..data.load import ds
11
+ from ..data.eval import evaluate_all, format_results
12
+
13
+
14
+ def get_activation_function(model, waveform, device):
15
+ """
16
+ Computes probability curve over time.
17
+ """
18
+ processor = MultiViewSpectrogram().to(device)
19
+ waveform = waveform.unsqueeze(0).to(device)
20
+
21
+ with torch.no_grad():
22
+ spec = processor(waveform)
23
+
24
+ # Normalize
25
+ mean = spec.mean(dim=(2, 3), keepdim=True)
26
+ std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
27
+ spec = (spec - mean) / std
28
+
29
+ # Batchify with sliding window
30
+ # Context frames = 50, so total window = 101.
31
+ # Pad time by 50 on each side.
32
+ spec = torch.nn.functional.pad(spec, (50, 50)) # Pad time
33
+ windows = spec.unfold(3, 101, 1) # (1, 3, 80, Time, 101)
34
+ windows = windows.permute(3, 0, 1, 2, 4).squeeze(1) # (Time, 3, 80, 101)
35
+
36
+ # Inference
37
+ activations = []
38
+ batch_size = 128 # Reduced batch size
39
+ for i in range(0, len(windows), batch_size):
40
+ batch = windows[i : i + batch_size]
41
+ out = model(batch)
42
+ activations.append(out.cpu().numpy())
43
+
44
+ return np.concatenate(activations).flatten()
45
+
46
+
47
+ def pick_peaks(activations, hop_length=160, sr=16000):
48
+ """
49
+ Smooth with Hamming window and report local maxima.
50
+ """
51
+ # Smoothing
52
+ window = np.hamming(5)
53
+ window /= window.sum()
54
+ smoothed = np.convolve(activations, window, mode="same")
55
+
56
+ # Peak Picking
57
+ peaks, _ = find_peaks(smoothed, height=0.5, distance=5)
58
+
59
+ timestamps = peaks * hop_length / sr
60
+ return timestamps.tolist()
61
+
62
+
63
+ def visualize_track(
64
+ audio: np.ndarray,
65
+ sr: int,
66
+ pred_beats: list[float],
67
+ pred_downbeats: list[float],
68
+ gt_beats: list[float],
69
+ gt_downbeats: list[float],
70
+ output_dir: str,
71
+ track_idx: int,
72
+ time_range: tuple[float, float] | None = None,
73
+ ):
74
+ """
75
+ Create and save visualizations for a single track.
76
+ """
77
+ from ..data.viz import plot_waveform_with_beats, save_figure
78
+
79
+ os.makedirs(output_dir, exist_ok=True)
80
+
81
+ # Full waveform plot
82
+ fig = plot_waveform_with_beats(
83
+ audio,
84
+ sr,
85
+ pred_beats,
86
+ gt_beats,
87
+ pred_downbeats,
88
+ gt_downbeats,
89
+ title=f"Track {track_idx}: Beat Comparison",
90
+ time_range=time_range,
91
+ )
92
+ save_figure(fig, os.path.join(output_dir, f"track_{track_idx:03d}.png"))
93
+
94
+
95
+ def synthesize_audio(
96
+ audio: np.ndarray,
97
+ sr: int,
98
+ pred_beats: list[float],
99
+ pred_downbeats: list[float],
100
+ gt_beats: list[float],
101
+ gt_downbeats: list[float],
102
+ output_dir: str,
103
+ track_idx: int,
104
+ click_volume: float = 0.5,
105
+ ):
106
+ """
107
+ Create and save audio files with click tracks for a single track.
108
+ """
109
+ from ..data.audio import create_comparison_audio, save_audio
110
+
111
+ os.makedirs(output_dir, exist_ok=True)
112
+
113
+ # Create comparison audio
114
+ audio_pred, audio_gt, audio_both = create_comparison_audio(
115
+ audio,
116
+ pred_beats,
117
+ pred_downbeats,
118
+ gt_beats,
119
+ gt_downbeats,
120
+ sr=sr,
121
+ click_volume=click_volume,
122
+ )
123
+
124
+ # Save audio files
125
+ save_audio(
126
+ audio_pred, os.path.join(output_dir, f"track_{track_idx:03d}_pred.wav"), sr
127
+ )
128
+ save_audio(audio_gt, os.path.join(output_dir, f"track_{track_idx:03d}_gt.wav"), sr)
129
+ save_audio(
130
+ audio_both, os.path.join(output_dir, f"track_{track_idx:03d}_both.wav"), sr
131
+ )
132
+
133
+
134
+ def main():
135
+ parser = argparse.ArgumentParser(
136
+ description="Evaluate beat tracking models with visualization and audio synthesis"
137
+ )
138
+ parser.add_argument(
139
+ "--model-dir",
140
+ type=str,
141
+ default="outputs/baseline2",
142
+ help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
143
+ )
144
+ parser.add_argument(
145
+ "--num-samples",
146
+ type=int,
147
+ default=116,
148
+ help="Number of samples to evaluate",
149
+ )
150
+ parser.add_argument(
151
+ "--output-dir",
152
+ type=str,
153
+ default="outputs/eval_baseline2",
154
+ help="Directory to save visualizations and audio",
155
+ )
156
+ parser.add_argument(
157
+ "--visualize",
158
+ action="store_true",
159
+ help="Generate visualization plots for each track",
160
+ )
161
+ parser.add_argument(
162
+ "--synthesize",
163
+ action="store_true",
164
+ help="Generate audio files with click tracks",
165
+ )
166
+ parser.add_argument(
167
+ "--viz-tracks",
168
+ type=int,
169
+ default=5,
170
+ help="Number of tracks to visualize/synthesize (default: 5)",
171
+ )
172
+ parser.add_argument(
173
+ "--time-range",
174
+ type=float,
175
+ nargs=2,
176
+ default=None,
177
+ metavar=("START", "END"),
178
+ help="Time range for visualization in seconds (default: full track)",
179
+ )
180
+ parser.add_argument(
181
+ "--click-volume",
182
+ type=float,
183
+ default=0.5,
184
+ help="Volume of click sounds relative to audio (0.0 to 1.0)",
185
+ )
186
+ parser.add_argument(
187
+ "--summary-plot",
188
+ action="store_true",
189
+ help="Generate summary evaluation plot",
190
+ )
191
+ args = parser.parse_args()
192
+
193
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
194
+
195
+ # Load BOTH models using from_pretrained
196
+ beat_model = None
197
+ downbeat_model = None
198
+
199
+ has_beats = False
200
+ has_downbeats = False
201
+
202
+ beats_dir = os.path.join(args.model_dir, "beats")
203
+ downbeats_dir = os.path.join(args.model_dir, "downbeats")
204
+
205
+ if os.path.exists(os.path.join(beats_dir, "model.safetensors")):
206
+ beat_model = ResNet.from_pretrained(beats_dir).to(DEVICE)
207
+ beat_model.eval()
208
+ has_beats = True
209
+ print(f"Loaded Beat Model from {beats_dir}")
210
+ else:
211
+ print(f"Warning: No beat model found in {beats_dir}")
212
+
213
+ if os.path.exists(os.path.join(downbeats_dir, "model.safetensors")):
214
+ downbeat_model = ResNet.from_pretrained(downbeats_dir).to(DEVICE)
215
+ downbeat_model.eval()
216
+ has_downbeats = True
217
+ print(f"Loaded Downbeat Model from {downbeats_dir}")
218
+ else:
219
+ print(f"Warning: No downbeat model found in {downbeats_dir}")
220
+
221
+ if not has_beats and not has_downbeats:
222
+ print("No models found. Please run training first.")
223
+ return
224
+
225
+ predictions = []
226
+ ground_truths = []
227
+ audio_data = [] # Store audio for visualization/synthesis
228
+
229
+ # Eval on specified number of tracks
230
+ test_set = ds["train"].select(range(args.num_samples))
231
+
232
+ print("Running evaluation...")
233
+ for i, item in enumerate(tqdm(test_set)):
234
+ waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
235
+ waveform_device = waveform.to(DEVICE)
236
+
237
+ pred_entry = {"beats": [], "downbeats": []}
238
+
239
+ # 1. Predict Beats
240
+ if has_beats:
241
+ act_b = get_activation_function(beat_model, waveform_device, DEVICE)
242
+ pred_entry["beats"] = pick_peaks(act_b)
243
+
244
+ # 2. Predict Downbeats
245
+ if has_downbeats:
246
+ act_d = get_activation_function(downbeat_model, waveform_device, DEVICE)
247
+ pred_entry["downbeats"] = pick_peaks(act_d)
248
+
249
+ predictions.append(pred_entry)
250
+ ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})
251
+
252
+ # Store audio for later visualization/synthesis
253
+ if args.visualize or args.synthesize:
254
+ if i < args.viz_tracks:
255
+ audio_data.append(
256
+ {
257
+ "audio": waveform.numpy(),
258
+ "sr": item["audio"]["sampling_rate"],
259
+ "pred": pred_entry,
260
+ "gt": ground_truths[-1],
261
+ }
262
+ )
263
+
264
+ # Run evaluation
265
+ results = evaluate_all(predictions, ground_truths)
266
+ print(format_results(results))
267
+
268
+ # Create output directory
269
+ if args.visualize or args.synthesize or args.summary_plot:
270
+ os.makedirs(args.output_dir, exist_ok=True)
271
+
272
+ # Generate visualizations
273
+ if args.visualize:
274
+ print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
275
+ viz_dir = os.path.join(args.output_dir, "plots")
276
+ for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
277
+ time_range = tuple(args.time_range) if args.time_range else None
278
+ visualize_track(
279
+ data["audio"],
280
+ data["sr"],
281
+ data["pred"]["beats"],
282
+ data["pred"]["downbeats"],
283
+ data["gt"]["beats"],
284
+ data["gt"]["downbeats"],
285
+ viz_dir,
286
+ i,
287
+ time_range=time_range,
288
+ )
289
+ print(f"Saved visualizations to {viz_dir}")
290
+
291
+ # Generate audio with clicks
292
+ if args.synthesize:
293
+ print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
294
+ audio_dir = os.path.join(args.output_dir, "audio")
295
+ for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
296
+ synthesize_audio(
297
+ data["audio"],
298
+ data["sr"],
299
+ data["pred"]["beats"],
300
+ data["pred"]["downbeats"],
301
+ data["gt"]["beats"],
302
+ data["gt"]["downbeats"],
303
+ audio_dir,
304
+ i,
305
+ click_volume=args.click_volume,
306
+ )
307
+ print(f"Saved audio files to {audio_dir}")
308
+ print(" *_pred.wav - Original audio with predicted beat clicks")
309
+ print(" *_gt.wav - Original audio with ground truth beat clicks")
310
+ print(" *_both.wav - Original audio with both predicted and GT clicks")
311
+
312
+ # Generate summary plot
313
+ if args.summary_plot:
314
+ from ..data.viz import plot_evaluation_summary, save_figure
315
+
316
+ print("\nGenerating summary plot...")
317
+ fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
318
+ summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
319
+ save_figure(fig, summary_path)
320
+ print(f"Saved summary plot to {summary_path}")
321
+
322
+
323
+ if __name__ == "__main__":
324
+ main()
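A small sketch (not in the commit) of the `unfold`-based sliding-window batching used in `get_activation_function` above; the 500-frame spectrogram is synthetic.

```python
# Pad the time axis by 50 frames on each side, then unfold 101-frame windows
# so every original frame gets one centred window.
import torch

spec = torch.randn(1, 3, 80, 500)                   # (B, C, mels, T) after normalization
padded = torch.nn.functional.pad(spec, (50, 50))    # pad last (time) dimension
windows = padded.unfold(3, 101, 1)                  # (1, 3, 80, 500, 101)
windows = windows.permute(3, 0, 1, 2, 4).squeeze(1)
print(windows.shape)                                 # torch.Size([500, 3, 80, 101])
```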
exp/baseline2/model.py ADDED
@@ -0,0 +1,139 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+
5
+
6
+ class SEBlock(nn.Module):
7
+ def __init__(self, channels, reduction=16):
8
+ super().__init__()
9
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
10
+ self.fc = nn.Sequential(
11
+ nn.Linear(channels, channels // reduction, bias=False),
12
+ nn.ReLU(inplace=True),
13
+ nn.Linear(channels // reduction, channels, bias=False),
14
+ nn.Sigmoid(),
15
+ )
16
+
17
+ def forward(self, x):
18
+ b, c, _, _ = x.size()
19
+ y = self.avg_pool(x).view(b, c)
20
+ y = self.fc(y).view(b, c, 1, 1)
21
+ return x * y.expand_as(x)
22
+
23
+
24
+ class ResBlock(nn.Module):
25
+ def __init__(self, in_channels, out_channels, stride=1, downsample=None):
26
+ super().__init__()
27
+ self.conv1 = nn.Conv2d(
28
+ in_channels,
29
+ out_channels,
30
+ kernel_size=3,
31
+ stride=stride,
32
+ padding=1,
33
+ bias=False,
34
+ )
35
+ self.bn1 = nn.BatchNorm2d(out_channels)
36
+ self.relu = nn.ReLU(inplace=True)
37
+ self.conv2 = nn.Conv2d(
38
+ out_channels, out_channels, kernel_size=3, padding=1, bias=False
39
+ )
40
+ self.bn2 = nn.BatchNorm2d(out_channels)
41
+ self.se = SEBlock(out_channels)
42
+ self.downsample = downsample
43
+
44
+ def forward(self, x):
45
+ identity = x
46
+ if self.downsample is not None:
47
+ identity = self.downsample(x)
48
+
49
+ out = self.conv1(x)
50
+ out = self.bn1(out)
51
+ out = self.relu(out)
52
+
53
+ out = self.conv2(out)
54
+ out = self.bn2(out)
55
+ out = self.se(out)
56
+
57
+ out += identity
58
+ out = self.relu(out)
59
+ return out
60
+
61
+
62
+ class ResNet(nn.Module, PyTorchModelHubMixin):
63
+ def __init__(
64
+ self, layers=[2, 2, 2, 2], channels=[16, 24, 48, 96], dropout_rate=0.5
65
+ ):
66
+ super().__init__()
67
+ self.in_channels = 16
68
+
69
+ # Stem
70
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
71
+ self.bn1 = nn.BatchNorm2d(16)
72
+ self.relu = nn.ReLU(inplace=True)
73
+
74
+ # Stages
75
+ self.layer1 = self._make_layer(channels[0], layers[0], stride=1)
76
+ self.layer2 = self._make_layer(channels[1], layers[1], stride=2)
77
+ self.layer3 = self._make_layer(channels[2], layers[2], stride=2)
78
+ self.layer4 = self._make_layer(channels[3], layers[3], stride=2)
79
+
80
+ self.dropout = nn.Dropout(p=dropout_rate)
81
+
82
+ # Final classification head
83
+ # H, W will reduce. Assuming input is (3, 80, 101)
84
+ # L1: (16, 80, 101) (stride 1)
85
+ # L2: (32, 40, 51) (stride 2)
86
+ # L3: (64, 20, 26) (stride 2)
87
+ # L4: (128, 10, 13) (stride 2)
88
+
89
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
90
+ self.fc = nn.Linear(channels[3], 1)
91
+ self.sigmoid = nn.Sigmoid()
92
+
93
+ def _make_layer(self, out_channels, blocks, stride=1):
94
+ downsample = None
95
+ if stride != 1 or self.in_channels != out_channels:
96
+ downsample = nn.Sequential(
97
+ nn.Conv2d(
98
+ self.in_channels,
99
+ out_channels,
100
+ kernel_size=1,
101
+ stride=stride,
102
+ bias=False,
103
+ ),
104
+ nn.BatchNorm2d(out_channels),
105
+ )
106
+
107
+ layers = []
108
+ layers.append(ResBlock(self.in_channels, out_channels, stride, downsample))
109
+ self.in_channels = out_channels
110
+ for _ in range(1, blocks):
111
+ layers.append(ResBlock(self.in_channels, out_channels))
112
+
113
+ return nn.Sequential(*layers)
114
+
115
+ def forward(self, x):
116
+ # x: (B, 3, 80, 101)
117
+ x = self.conv1(x)
118
+ x = self.bn1(x)
119
+ x = self.relu(x)
120
+
121
+ x = self.layer1(x)
122
+ x = self.layer2(x)
123
+ x = self.layer3(x)
124
+ x = self.layer4(x)
125
+
126
+ x = self.avgpool(x) # (B, 128, 1, 1)
127
+ x = torch.flatten(x, 1) # (B, 128)
128
+ x = self.dropout(x)
129
+ x = self.fc(x)
130
+ x = self.sigmoid(x)
131
+
132
+ return x
133
+
134
+
135
+ if __name__ == "__main__":
136
+ from torchinfo import summary
137
+
138
+ model = ResNet()
139
+ summary(model, (1, 3, 80, 101))
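A minimal check (not in the commit) that the `SEBlock` above is shape-preserving: it pools each channel to a scalar, bottlenecks through `channels // reduction` units, and rescales the input feature map channel-wise.

```python
import torch
from exp.baseline2.model import SEBlock

x = torch.randn(2, 48, 20, 26)             # (B, C, H, W), e.g. a stage-3 feature map
y = SEBlock(channels=48, reduction=16)(x)  # per-channel gates in (0, 1)
print(y.shape)                              # torch.Size([2, 48, 20, 26])
```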
exp/baseline2/train.py ADDED
@@ -0,0 +1,215 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ from tqdm import tqdm
7
+ import argparse
8
+ import os
9
+
10
+ from .model import ResNet
11
+ from .data import BeatTrackingDataset
12
+ from ..baseline1.utils import MultiViewSpectrogram
13
+ from ..data.load import ds
14
+
15
+
16
+ def train(target_type: str, output_dir: str):
17
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+ BATCH_SIZE = 128 # Reduced batch size due to larger context
19
+ EPOCHS = 3
20
+ LR = 0.001 # Adjusted LR for Adam (ResNet usually prefers Adam/AdamW)
21
+ NUM_WORKERS = 4
22
+ CONTEXT_FRAMES = 50 # +/- 50 frames -> 101 frames total
23
+ PATIENCE = 5 # Early stopping patience
24
+
25
+ print(f"--- Training Model for target: {target_type} ---")
26
+ print(f"Output directory: {output_dir}")
27
+
28
+ # Create output directory
29
+ os.makedirs(output_dir, exist_ok=True)
30
+
31
+ # TensorBoard writer
32
+ writer = SummaryWriter(log_dir=os.path.join(output_dir, "logs"))
33
+
34
+ # Data
35
+ train_dataset = BeatTrackingDataset(
36
+ ds["train"], target_type=target_type, context_frames=CONTEXT_FRAMES
37
+ )
38
+ val_dataset = BeatTrackingDataset(
39
+ ds["test"], target_type=target_type, context_frames=CONTEXT_FRAMES
40
+ )
41
+
42
+ train_loader = DataLoader(
43
+ train_dataset,
44
+ batch_size=BATCH_SIZE,
45
+ shuffle=True,
46
+ num_workers=NUM_WORKERS,
47
+ pin_memory=True,
48
+ prefetch_factor=4,
49
+ persistent_workers=True,
50
+ )
51
+ val_loader = DataLoader(
52
+ val_dataset,
53
+ batch_size=BATCH_SIZE,
54
+ shuffle=False,
55
+ num_workers=NUM_WORKERS,
56
+ pin_memory=True,
57
+ prefetch_factor=4,
58
+ persistent_workers=True,
59
+ )
60
+
61
+ print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
62
+
63
+ # Model
64
+ model = ResNet(dropout_rate=0.5).to(DEVICE)
65
+
66
+ # GPU Spectrogram Preprocessor
67
+ preprocessor = MultiViewSpectrogram(sample_rate=16000, hop_length=160).to(DEVICE)
68
+
69
+ # Optimizer - Using AdamW for ResNet
70
+ optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
71
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
72
+ criterion = nn.BCELoss() # Binary Cross Entropy
73
+
74
+ best_val_loss = float("inf")
75
+ patience_counter = 0
76
+ global_step = 0
77
+
78
+ for epoch in range(EPOCHS):
79
+ # Training
80
+ model.train()
81
+ total_train_loss = 0
82
+ for waveform, y in tqdm(
83
+ train_loader,
84
+ desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Train",
85
+ leave=False,
86
+ ):
87
+ waveform, y = waveform.to(DEVICE), y.to(DEVICE)
88
+
89
+ # Compute spectrogram on GPU
90
+ with torch.no_grad():
91
+ spec = preprocessor(waveform) # (B, 3, 80, T_raw)
92
+ # Normalize
93
+ mean = spec.mean(dim=(2, 3), keepdim=True)
94
+ std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
95
+ spec = (spec - mean) / std
96
+
97
+ T_curr = spec.shape[-1]
98
+ target_T = CONTEXT_FRAMES * 2 + 1
99
+
100
+ if T_curr > target_T:
101
+ start = (T_curr - target_T) // 2
102
+ x = spec[:, :, :, start : start + target_T]
103
+ elif T_curr < target_T:
104
+ # This shouldn't happen if dataset is correct, but just in case pad
105
+ pad = target_T - T_curr
106
+ x = torch.nn.functional.pad(spec, (0, pad))
107
+ else:
108
+ x = spec
109
+
110
+ optimizer.zero_grad()
111
+ output = model(x)
112
+ loss = criterion(output, y)
113
+ loss.backward()
114
+ optimizer.step()
115
+
116
+ total_train_loss += loss.item()
117
+ global_step += 1
118
+
119
+ # Log batch loss
120
+ writer.add_scalar("train/batch_loss", loss.item(), global_step)
121
+
122
+ avg_train_loss = total_train_loss / len(train_loader)
123
+
124
+ # Validation
125
+ model.eval()
126
+ total_val_loss = 0
127
+ with torch.no_grad():
128
+ for waveform, y in tqdm(
129
+ val_loader,
130
+ desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Val",
131
+ leave=False,
132
+ ):
133
+ waveform, y = waveform.to(DEVICE), y.to(DEVICE)
134
+
135
+ # Compute spectrogram on GPU
136
+ spec = preprocessor(waveform) # (B, 3, 80, T)
137
+ # Normalize
138
+ mean = spec.mean(dim=(2, 3), keepdim=True)
139
+ std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
140
+ spec = (spec - mean) / std
141
+
142
+ T_curr = spec.shape[-1]
143
+ target_T = CONTEXT_FRAMES * 2 + 1
144
+
145
+ if T_curr > target_T:
146
+ start = (T_curr - target_T) // 2
147
+ x = spec[:, :, :, start : start + target_T]
148
+ else:
149
+ pad = target_T - T_curr
150
+ x = torch.nn.functional.pad(spec, (0, pad))
151
+
152
+ output = model(x)
153
+ loss = criterion(output, y)
154
+ total_val_loss += loss.item()
155
+
156
+ avg_val_loss = total_val_loss / len(val_loader)
157
+
158
+ # Log epoch metrics
159
+ writer.add_scalar("train/epoch_loss", avg_train_loss, epoch)
160
+ writer.add_scalar("val/loss", avg_val_loss, epoch)
161
+ writer.add_scalar("train/learning_rate", scheduler.get_last_lr()[0], epoch)
162
+
163
+ # Step the scheduler
164
+ scheduler.step()
165
+
166
+ print(
167
+ f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} - "
168
+ f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
169
+ )
170
+
171
+ # Save best model
172
+ if avg_val_loss < best_val_loss:
173
+ best_val_loss = avg_val_loss
174
+ patience_counter = 0
175
+ model.save_pretrained(output_dir)
176
+ print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")
177
+ else:
178
+ patience_counter += 1
179
+ print(f" -> No improvement (patience: {patience_counter}/{PATIENCE})")
180
+
181
+ if patience_counter >= PATIENCE:
182
+ print("Early stopping triggered.")
183
+ break
184
+
185
+ writer.close()
186
+
187
+ # Save final model
188
+ final_dir = os.path.join(output_dir, "final")
189
+ model.save_pretrained(final_dir)
190
+ print(f"Saved final model to {final_dir}")
191
+
192
+
193
+ if __name__ == "__main__":
194
+ parser = argparse.ArgumentParser()
195
+ parser.add_argument(
196
+ "--target",
197
+ type=str,
198
+ choices=["beats", "downbeats"],
199
+ default=None,
200
+ help="Train a model for 'beats' or 'downbeats'. If not specified, trains both.",
201
+ )
202
+ parser.add_argument(
203
+ "--output-dir",
204
+ type=str,
205
+ default="outputs/baseline2",
206
+ help="Directory to save model and logs",
207
+ )
208
+ args = parser.parse_args()
209
+
210
+ # Determine which targets to train
211
+ targets = [args.target] if args.target else ["beats", "downbeats"]
212
+
213
+ for target in targets:
214
+ output_dir = os.path.join(args.output_dir, target)
215
+ train(target, output_dir)
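The spectrogram normalization and the center-crop/pad to `CONTEXT_FRAMES * 2 + 1` frames appear twice above, once in the training loop and once in validation. A possible refactor, sketched below and not part of this commit, is to pull that logic into a shared helper; `prepare_window` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def prepare_window(spec: torch.Tensor, context_frames: int = 50) -> torch.Tensor:
    """Normalize a (B, 3, 80, T) spectrogram batch and center-crop or
    right-pad it to 2 * context_frames + 1 frames, as train.py does inline."""
    mean = spec.mean(dim=(2, 3), keepdim=True)
    std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
    spec = (spec - mean) / std

    target_t = context_frames * 2 + 1
    t_curr = spec.shape[-1]
    if t_curr > target_t:
        start = (t_curr - target_t) // 2
        return spec[:, :, :, start : start + target_t]
    if t_curr < target_t:
        return F.pad(spec, (0, target_t - t_curr))
    return spec
```

Both loops could then reduce to `x = prepare_window(preprocessor(waveform), CONTEXT_FRAMES)` under the appropriate gradient context.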
outputs/baseline1/beats/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline1/beats/config.json ADDED
@@ -0,0 +1,3 @@
+{
+  "dropout_rate": 0.5
+}
outputs/baseline1/beats/final/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline1/beats/final/config.json ADDED
@@ -0,0 +1,3 @@
+{
+  "dropout_rate": 0.5
+}
outputs/baseline1/beats/final/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0ee01ee41360f0b486e16d6022f896a19f9ead901be0180bdbd9cad2a3b8597
+size 1159372
outputs/baseline1/beats/logs/events.out.tfevents.1766351314.msiit232.1284330.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b2d91a22ba01091bf072f5a5e8f12fc7d49801d6538914c973ccb2700978934
+size 17749022
outputs/baseline1/beats/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e7a0d5178bc5dfeee6da26345e7956aeb6bf64a21be7e541db4bcc37b290249
+size 1159372
outputs/baseline1/downbeats/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline1/downbeats/config.json ADDED
@@ -0,0 +1,3 @@
+{
+  "dropout_rate": 0.5
+}
outputs/baseline1/downbeats/final/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline1/downbeats/final/config.json ADDED
@@ -0,0 +1,3 @@
+{
+  "dropout_rate": 0.5
+}
outputs/baseline1/downbeats/final/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:870e3425ffd366be9a0e8fafcda62fa28b2c25917c8354570edc53a67e132d38
+size 1159372
outputs/baseline1/downbeats/logs/events.out.tfevents.1766353075.msiit232.1284330.1 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8744916b2c1a8255cd5379e6956a4ad2acbf8bcc1fcfaed21ca11285a771550c
+size 4272622
outputs/baseline1/downbeats/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8895be0bff1c3210f46b04c490596490fe03081728e17fffb33c80369b472134
+size 1159372
outputs/baseline2/beats/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline2/beats/config.json ADDED
@@ -0,0 +1,15 @@
+{
+  "channels": [
+    16,
+    24,
+    48,
+    96
+  ],
+  "dropout_rate": 0.5,
+  "layers": [
+    2,
+    2,
+    2,
+    2
+  ]
+}
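The fields in this `config.json` mirror the keyword arguments of `ResNet.__init__` above; `PyTorchModelHubMixin` records them at `save_pretrained` time so the architecture can be rebuilt without re-specifying hyperparameters. A minimal sketch follows, assuming the repository root is the working directory; note that this only reconstructs the architecture, whereas `ResNet.from_pretrained("outputs/baseline2/beats")` would also load the weights from `model.safetensors`.

```python
import json

from exp.baseline2.model import ResNet

# Rebuild the beat model with the exact hyperparameters stored by save_pretrained.
with open("outputs/baseline2/beats/config.json") as f:
    cfg = json.load(f)

model = ResNet(**cfg)  # channels=[16, 24, 48, 96], layers=[2, 2, 2, 2], dropout_rate=0.5
```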
outputs/baseline2/beats/final/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline2/beats/final/config.json ADDED
@@ -0,0 +1,15 @@
+{
+  "channels": [
+    16,
+    24,
+    48,
+    96
+  ],
+  "dropout_rate": 0.5,
+  "layers": [
+    2,
+    2,
+    2,
+    2
+  ]
+}
outputs/baseline2/beats/final/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e60f98fe8152c2238d2bf7b2bc800a13aaf8a3cb665c3bfaa8e7dbc656362212
+size 1629940
outputs/baseline2/beats/logs/events.out.tfevents.1766356346.msiit232.1356098.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0b974a684478fa38672c299c1a441df5c051abc37e30e6f06d26502378c7c1d
+size 4245699
outputs/baseline2/beats/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e60f98fe8152c2238d2bf7b2bc800a13aaf8a3cb665c3bfaa8e7dbc656362212
+size 1629940
outputs/baseline2/downbeats/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline2/downbeats/config.json ADDED
@@ -0,0 +1,15 @@
+{
+  "channels": [
+    16,
+    24,
+    48,
+    96
+  ],
+  "dropout_rate": 0.5,
+  "layers": [
+    2,
+    2,
+    2,
+    2
+  ]
+}
outputs/baseline2/downbeats/final/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline2/downbeats/final/config.json ADDED
@@ -0,0 +1,15 @@
+{
+  "channels": [
+    16,
+    24,
+    48,
+    96
+  ],
+  "dropout_rate": 0.5,
+  "layers": [
+    2,
+    2,
+    2,
+    2
+  ]
+}
outputs/baseline2/downbeats/final/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72c3a009b2e5d067d69d53755d96cee23a6305de3aa9b8336c3b817a0f0f8e77
+size 1629940
outputs/baseline2/downbeats/logs/events.out.tfevents.1766359276.msiit232.1356098.1 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc03f9d0a5e525864dd519cd176d0ed71520315a89b9853403a72afac4e77921
+size 1011363
outputs/baseline2/downbeats/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72c3a009b2e5d067d69d53755d96cee23a6305de3aa9b8336c3b817a0f0f8e77
+size 1629940