|
|
from datasets import load_dataset, Audio |
|
|
|
|
|
# Value forwarded as `num_proc` to the `filter` and `map` calls further down.
N_PROC = None

# Load the dataset, then drop the columns that the rest of this script never
# reads (only `audio`, `oni` and `metadata` are used below).
ds = load_dataset("JacobLinCool/taiko-1000-parsed").remove_columns(
    ["tja", "hard", "normal", "easy", "ura"]
)
|
|
|
|
|
|
|
|
def filter_out_broken(example):
    """Return True when the example's audio array can be read.

    Used as a row predicate: a row is kept only if looking up
    ``example["audio"]["array"]`` succeeds.

    Args:
        example: A mapping-like row that is expected to expose
            ``example["audio"]["array"]``.

    Returns:
        bool: True if the lookup succeeds, False if it raises any
        ``Exception``.
    """
    try:
        # NOTE(review): presumably this lookup is what triggers audio decoding
        # for lazily decoded rows -- confirm against the installed `datasets`.
        example["audio"]["array"]
    except Exception:
        # Deliberately broad: any failure while reading the audio marks the
        # row as broken. `Exception` (rather than a bare `except`) lets
        # KeyboardInterrupt / SystemExit propagate so a long run can still be
        # aborted.
        return False
    else:
        return True
|
|
|
|
|
|
|
|
# Keep only the rows accepted by `filter_out_broken`, then cast the `audio`
# column to an `Audio` feature with a 16000 Hz sampling rate.
ds = ds.filter(
    filter_out_broken,
    num_proc=N_PROC,
    batch_size=32,
    writer_batch_size=32,
).cast_column("audio", Audio(sampling_rate=16000))
|
|
|
|
|
|
|
|
def _resolve_segment_bpms(segments, default_bpm=120.0):
    """Return one BPM per segment, aligned index-for-index with *segments*.

    A segment with notes uses the BPM of its first note. A segment without
    notes borrows the first-note BPM of the nearest later segment that has
    notes. If no BPM is found, or the found BPM is not positive,
    *default_bpm* is used.

    A single reverse pass keeps this linear in the number of segments, where
    scanning forward from every empty segment would be quadratic on long runs
    of empty measures.
    """
    resolved = [default_bpm] * len(segments)
    upcoming = None  # first-note BPM of the closest segment at or after idx
    for idx in range(len(segments) - 1, -1, -1):
        notes = segments[idx]["notes"]
        if notes:
            upcoming = notes[0]["bpm"]
        if upcoming is None or upcoming <= 0:
            continue  # keep the default for this segment
        resolved[idx] = upcoming
    return resolved


def build_beat_and_downbeat_labels(example):
    """Extract beat and downbeat times from the "oni" chart segments.

    Each segment is treated as one measure:

    - Downbeats: the segment's ``timestamp``.
    - Beats: ``measure_num`` evenly spaced times starting at the segment's
      ``timestamp``, spaced by ``(60 / bpm) * (4 / measure_den)``.

    The BPM of a segment comes from its first note. A segment without notes
    uses the first note of the next segment that has notes, and 120.0 is used
    when no positive BPM can be found. BPM changes inside a segment are not
    taken into account.

    Args:
        example: Row providing ``example["metadata"]["TITLE"]`` and
            ``example["oni"]["segments"]``, where every segment has the keys
            ``timestamp``, ``measure_num``, ``measure_den`` and ``notes``
            (each note having a ``bpm`` key).

    Returns:
        dict: ``{"title": ..., "beats": [...], "downbeats": [...]}`` where both
        lists are de-duplicated and sorted ascending. Times are in the same
        unit as the segment timestamps (seconds, per the BPM arithmetic).

    Raises:
        ZeroDivisionError: If a segment's ``measure_den`` is 0.
    """
    title = example["metadata"]["TITLE"]
    segments = example["oni"]["segments"]

    # Sets de-duplicate times shared by neighbouring segments as we go.
    beats = set()
    downbeats = set()

    for segment, bpm in zip(segments, _resolve_segment_bpms(segments)):
        start = segment["timestamp"]
        downbeats.add(start)

        # 60 / bpm is the length of a quarter note; 4 / measure_den rescales
        # it to the note value named by the time-signature denominator.
        beat_duration = (60.0 / bpm) * (4.0 / segment["measure_den"])
        beats.update(
            start + beat_idx * beat_duration
            for beat_idx in range(segment["measure_num"])
        )

    return {
        "title": title,
        "beats": sorted(beats),
        "downbeats": sorted(downbeats),
    }
|
|
|
|
|
|
|
|
# Replace the `oni` and `metadata` columns with the `title`, `beats` and
# `downbeats` fields returned by `build_beat_and_downbeat_labels`, then set the
# output format to "torch".
ds = ds.map(
    build_beat_and_downbeat_labels,
    num_proc=N_PROC,
    batch_size=32,
    writer_batch_size=32,
    remove_columns=["oni", "metadata"],
).with_format("torch")
|
|
|
|
|
if __name__ == "__main__":
    # When run as a script, print the processed dataset and then the feature
    # schema of its "train" split.
    print(ds)
    train_split = ds["train"]
    print(train_split.features)
|
|
|