"""Build beat and downbeat labels for the ``JacobLinCool/taiko-1000-parsed`` dataset.

Importing this module loads the dataset, drops examples whose audio cannot be
read, casts the audio column to 16 kHz, and replaces the ``oni`` chart and
``metadata`` columns with ``title``, ``beats`` and ``downbeats`` columns.
The processed dataset is exposed as the module-level name ``ds``.
"""

from datasets import Audio, load_dataset

# Worker-process count forwarded to ``filter`` / ``map`` below.
N_PROC = None

# BPM used when a segment has no usable (positive) BPM of its own.
_FALLBACK_BPM = 120.0

ds = load_dataset("JacobLinCool/taiko-1000-parsed")
# Only the "oni" chart is used for labels; drop the raw TJA and other charts.
ds = ds.remove_columns(["tja", "hard", "normal", "easy", "ura"])


def filter_out_broken(example):
    """Return True if the example's audio array can be accessed, else False.

    Accessing ``example["audio"]["array"]`` is used as a probe; any ordinary
    exception raised by that access marks the example as broken.
    """
    try:
        example["audio"]["array"]
    except Exception:
        # Deliberately broad: any failure to read the audio means "skip it".
        # ``Exception`` (rather than a bare ``except``) lets KeyboardInterrupt
        # and SystemExit propagate so a long filter run can still be aborted.
        return False
    return True


ds = ds.filter(filter_out_broken, num_proc=N_PROC, batch_size=32, writer_batch_size=32)
ds = ds.cast_column("audio", Audio(sampling_rate=16000))


def _segment_bpm(segments, index):
    """Return the BPM to use for ``segments[index]``.

    Uses the BPM of the segment's first note. If the segment has no notes,
    uses the first note of the nearest later segment that has notes. Falls
    back to ``_FALLBACK_BPM`` when no BPM is found or it is not positive.
    """
    notes = segments[index]["notes"]
    if notes:
        bpm = notes[0]["bpm"]
    else:
        # Look ahead for the next segment that carries at least one note.
        bpm = next(
            (
                segments[later]["notes"][0]["bpm"]
                for later in range(index + 1, len(segments))
                if segments[later]["notes"]
            ),
            None,
        )
    if bpm is None or bpm <= 0:
        return _FALLBACK_BPM
    return bpm


def build_beat_and_downbeat_labels(example):
    """Extract beat and downbeat times from the "oni" chart segments.

    - Downbeats: the timestamp of each segment (start of each measure).
    - Beats: ``measure_num`` evenly spaced times per segment, starting at the
      segment timestamp, spaced by one beat of the segment's time signature.

    Args:
        example: A dataset row providing ``example["metadata"]["TITLE"]`` and
            ``example["oni"]["segments"]``; each segment provides
            ``timestamp``, ``measure_num``, ``measure_den`` and ``notes``
            (each note providing ``bpm``).

    Returns:
        A dict with ``title``, and ``beats`` / ``downbeats`` as sorted,
        de-duplicated lists of times (presumably seconds, given the
        60 / BPM arithmetic; confirm against the dataset's timestamp unit).
    """
    title = example["metadata"]["TITLE"]
    segments = example["oni"]["segments"]

    # Sets de-duplicate times that coincide (e.g. overlapping segments).
    beat_times = set()
    downbeat_times = set()

    for index, segment in enumerate(segments):
        start = segment["timestamp"]
        measure_num = segment["measure_num"]  # numerator (e.g., 4 in 4/4)
        measure_den = segment["measure_den"]  # denominator (e.g., 4 in 4/4)

        # Downbeat is the start of each measure.
        downbeat_times.add(start)

        bpm = _segment_bpm(segments, index)

        # One quarter note lasts 60 / BPM seconds; scale by the time-signature
        # denominator (4 = quarter, 8 = eighth, etc.).
        # NOTE(review): a measure_den of 0 raises ZeroDivisionError here.
        beat_duration = (60.0 / bpm) * (4.0 / measure_den)

        for beat_idx in range(measure_num):
            beat_times.add(start + beat_idx * beat_duration)

    return {
        "title": title,
        "beats": sorted(beat_times),
        "downbeats": sorted(downbeat_times),
    }


ds = ds.map(
    build_beat_and_downbeat_labels,
    num_proc=N_PROC,
    batch_size=32,
    writer_batch_size=32,
    remove_columns=["oni", "metadata"],
)
ds = ds.with_format("torch")

if __name__ == "__main__":
    print(ds)
    print(ds["train"].features)