| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | All utilities related to data handling. |
| | """ |
| |
|
| | from collections.abc import Callable |
| | from functools import partial |
| |
|
| | import datasets |
| | import numpy as np |
| | from datasets import Dataset, load_dataset |
| |
|
| |
|
| | |
| | |
# Maximum combined character length of a query + response pair; longer pairs are
# dropped by get_filtered_dataset. Characters are used as a tokenizer-independent
# proxy for token count (see get_filtered_dataset docstring).
CHAR_LIMIT = 1300

# Number of GSM8K train examples held out as the validation split in
# get_train_valid_test_datasets.
VALID_SIZE = 50
| |
|
| |
|
def get_filtered_dataset(*, ds: datasets.Dataset, print_fn: Callable[..., None]) -> Dataset:
    """Return the dataset with over-long query/response pairs removed.

    An example is kept when the combined "query response" text has at most CHAR_LIMIT
    characters. Character counts roughly track token counts, so this is a good proxy.
    We cannot filter on tokens directly, as that depends on the tokenizer, which can
    differ per model, but we want the identical filter for every model.
    """
    keep = [
        i
        for i, (q, r) in enumerate(zip(ds["query"], ds["response"]))
        if len(f"{q} {r}") <= CHAR_LIMIT
    ]
    print_fn(f"Filtered dataset: {100 * len(keep) / len(ds):.1f}% of the original dataset")
    return ds.select(keep)
| |
|
| |
|
def get_train_valid_test_datasets(
    *, tokenizer, query_template: str, print_fn: Callable[..., None]
) -> tuple[Dataset, Dataset, Dataset]:
    """Return the tokenized (train, valid, test) datasets.

    Train is the length-filtered MetaMathQA train split, tokenized with the answer
    appended (for causal-LM training). Valid is a fixed random subset of VALID_SIZE
    GSM8K train examples and test is the full GSM8K test split; both are tokenized
    without the answer, for generation-based evaluation.

    We cannot use ds.train_test_split(..., stratify_by_column="type") as it gives:

    > ValueError: Stratifying by column is only supported for ClassLabel column, and column type is Value.

    even after calling ds_filtered.class_encode_column("type").
    """
    train_ds = get_filtered_dataset(ds=load_dataset("meta-math/MetaMathQA")["train"], print_fn=print_fn)

    gsm8k = load_dataset("openai/gsm8k", "main").rename_columns({"question": "query", "answer": "response"})

    # Fixed seed + shuffle => the same validation subset on every run.
    np.random.seed(0)
    perm = np.arange(len(gsm8k["train"]))
    np.random.shuffle(perm)
    valid_ds = gsm8k["train"].select(perm[:VALID_SIZE])
    test_ds = gsm8k["test"]

    print_fn(f"Train size: {len(train_ds)}")
    print_fn(f"Valid size: {len(valid_ds)}")
    print_fn(f"Test size: {len(test_ds)}")

    with_answer = partial(tokenize_with_answer, tokenizer=tokenizer, template=query_template)
    without_answer = partial(tokenize_wo_answer, tokenizer=tokenizer, template=query_template)
    train_ds = train_ds.map(with_answer, batched=True).remove_columns(["type", "query", "original_question"])
    valid_ds = valid_ds.map(without_answer, batched=True).remove_columns(["query"])
    test_ds = test_ds.map(without_answer, batched=True).remove_columns(["query"])

    return train_ds, valid_ds, test_ds
| |
|
| |
|
def tokenize_with_answer(samples, tokenizer, template):
    """Tokenize query+answer pairs, truncating each sequence to the tokenizer's max length.

    `samples` is a batch dict with "query" and "response" columns; each query is
    rendered through `template` (which must contain a `{query}` placeholder) and the
    answer is appended verbatim.
    """
    texts = [
        template.format(query=query) + answer
        for query, answer in zip(samples["query"], samples["response"])
    ]
    encoded = tokenizer(texts)
    limit = tokenizer.model_max_length
    encoded["input_ids"] = [ids[:limit] for ids in encoded["input_ids"]]
    encoded["attention_mask"] = [mask[:limit] for mask in encoded["attention_mask"]]
    return encoded
| |
|
| |
|
def tokenize_wo_answer(samples, tokenizer, template):
    """Tokenize queries only (no answer), truncating to the tokenizer's max length.

    Counterpart of `tokenize_with_answer` for evaluation batches, where the model
    must generate the answer itself.
    """
    texts = [template.format(query=query) for query in samples["query"]]
    encoded = tokenizer(texts)
    limit = tokenizer.model_max_length
    encoded["input_ids"] = [ids[:limit] for ids in encoded["input_ids"]]
    encoded["attention_mask"] = [mask[:limit] for mask in encoded["attention_mask"]]
    return encoded
| |
|
| |
|
def get_wiki_small(num_samples: int = 100) -> list[str]:
    """Return the text of the first `num_samples` FineWiki train articles.

    The dataset is opened in streaming mode so only the requested prefix is
    downloaded instead of the full corpus.
    """
    stream = load_dataset("HuggingFaceFW/finewiki", split="train", streaming=True)
    return [sample["text"] for sample in stream.take(num_samples)]
| |
|