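# Dataset classes for the training pipeline. MyDataset serves several data
# formats (binidx, numpy, uint16, webdataset images, plain text, and a
# synthetic "dummy" corpus) behind a single torch Dataset interface. The
# trainer is expected to attach `global_rank`, `real_epoch` and `world_size`
# to the dataset instance before __getitem__ is called.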
import json, math, random, os, sys
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime


class MyDataset(Dataset):
    def __init__(self, args):
        self.args = args

        if args.data_type == "binidx":
            self.vocab_size = args.vocab_size
            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")

            if args.my_pile_version == 1:
                self.data = MMapIndexedDataset(args.data_file)
                self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
                rank_zero_info(f"Data has {self.data_size} tokens.")
            else:
                # data_file is a manifest listing binidx shards, one per line;
                # see the illustrative example after this block
                data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
                data_list = [i.strip().split(' ') for i in data_list]
                self.data = []
                self.data_size = int(data_list[-1][-1])
                rank_zero_info(f"Data has {self.data_size} chunks.")
                for d in data_list:
                    data = MMapIndexedDataset(d[0])
                    data_size = len(data._bin_buffer) // data._index._dtype_size
                    assert (data_size - args.ctx_len) == int(d[1])
                    # store [cumulative cutoff, usable positions, shard]
                    self.data += [[int(d[-1]), int(d[1]), data]]

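            # Illustrative manifest for the branch above (made-up paths and
            # sizes, assuming three whitespace-separated fields per line):
            #   /data/shard0_text_document 99999744 99999744
            #   /data/shard1_text_document 49999744 149999488
            # fields: shard path, usable positions (token count - ctx_len),
            # running cumulative total of usable positions.
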
            if args.my_qa_mask > 0:
                self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document')
                self.data_pile_size = len(self.data_pile._bin_buffer) // self.data_pile._index._dtype_size
            else:
                self.data_pile = None
                self.data_pile_size = 0

            if args.my_pile_stage > 0:
                # fixed mini-epoch size used by the Pile runs (40320 = 8!)
                self.samples_per_epoch = args.epoch_steps * args.real_bsz
                assert self.samples_per_epoch == 40320
                rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
                dataset_slot = self.data_size // args.ctx_len
                if args.my_pile_stage != 4:
                    assert MaybeIsPrime(args.magic_prime)
                    assert args.magic_prime % 3 == 2
                    assert 0.99 < args.magic_prime / dataset_slot <= 1
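                    # Why magic_prime must be a prime p with p % 3 == 2, just
                    # below data_size // ctx_len: sampling below maps step ii
                    # to slot (factor * ii**3) % p. Since gcd(3, p - 1) == 1,
                    # x -> x**3 (mod p) is a bijection on [0, p), so every
                    # context-sized slot is visited exactly once per p steps:
                    # shuffled order, no repeats, near-full coverage.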
| elif args.data_type == "numpy": |
| self.data = np.load(args.data_file).astype("int") |
| self.vocab_size = args.vocab_size |
| rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") |
| self.data_size = len(self.data) |
| rank_zero_info(f"Data has {self.data_size} tokens.") |
| elif args.data_type == "uint16": |
| self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len) |
| self.vocab_size = args.vocab_size |
| rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") |
| self.data_size = self.data.shape[0] |
| rank_zero_info(f"Data has {self.data_size} samples.") |
| elif args.data_type == "wds_img": |
| self.vocab_size = -1 |
| self.data_size = -1 |
| self.data = None |
| self.error_count = 0 |
| else: |
| if args.data_type == "dummy": |
| rank_zero_info("Building dummy data...") |
| self.data = "" |
| for i in range(100000): |
| aa = (i) % 10000 |
| bb = (i * i) % 10000 |
| cc = aa + bb |
| self.data += f".{aa}+{bb}={cc}." |
| else: |
| self.data = open(args.data_file, "r", encoding=args.data_type).read() |
| rank_zero_info("Building token list...") |
| unique = sorted(list(set(self.data))) |
| self.vocab_size = len(unique) |
| |
| |
| |
| |
            xxObj = {i: u for i, u in enumerate(unique)}
            # note: vocab.json is written as UTF-16LE; whatever reads it back
            # must use the same encoding
            with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
                vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
            self.data_size = len(self.data)
            rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
            self.stoi = {ch: i for i, ch in enumerate(unique)}
            self.itos = {i: ch for i, ch in enumerate(unique)}

    def __len__(self):
        # length is steps-per-epoch * micro-batch for this rank, not the
        # number of samples in the underlying data
        return self.args.epoch_steps * self.args.micro_bsz

    def __getitem__(self, idx):
        args = self.args
        # these three attributes are injected by the trainer before sampling
        rank = self.global_rank
        epoch = self.real_epoch
        world_size = self.world_size

        if args.data_type == "wds_img":
            def init_wds(self, bias=0):
                def identity(x):
                    return x
                import webdataset as wds
                import torchvision.transforms as transforms
                img_transform = transforms.Compose([
                    transforms.CenterCrop(512),
                    transforms.Resize(args.my_img_size)
                ])
                self.data_raw = (
                    wds.WebDataset(args.data_file, resampled=True)
                    .shuffle(10000, initial=1000, rng=random.Random(epoch * 100000 + rank + bias * 1e9))
                    .decode("torchrgb")
                    .to_tuple("jpg", "json", "txt")
                    .map_tuple(img_transform, identity, identity)
                )
                # make shard resampling deterministic per (rank, epoch, bias)
                for pp in self.data_raw.pipeline:
                    if 'Resampled' in str(pp):
                        pp.deterministic = True
                        def worker_seed():
                            return rank * 100000 + epoch + bias * 1e9
                        pp.worker_seed = worker_seed
                self.data = iter(self.data_raw)

            if self.data is None:
                init_wds(self)
            trial = 0
            while trial < 10:
                try:
                    dd = next(self.data)
                    break
                except Exception:
                    # bad sample/shard: reshuffle with a new bias and retry
                    print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
                    self.error_count += 1
                    init_wds(self, self.error_count)
                    trial += 1
            return dd[0], dd[2]  # the decoded image and its .txt entry
        else:
            if args.data_type == "uint16":
                # pre-chunked samples: pick one row at random
                i = np.random.randint(0, self.data_size - 1)
                dix = self.data[i]
                x = torch.tensor(dix[:-1], dtype=torch.long)
                y = torch.tensor(dix[1:], dtype=torch.long)
            else:
                ctx_len = args.ctx_len
                req_len = ctx_len + 1
                magic_prime = args.magic_prime
                data = self.data

                if args.my_pile_stage > 0:
                    # global sample index: unique across ranks and epochs
                    ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank

                    if args.my_qa_mask > 0:
                        # interleave: even indices draw from the pile,
                        # odd indices from the QA data
                        ii_orig = ii
                        if ii % 2 == 0:
                            ii = -1
                            data = self.data_pile
                        else:
                            ii = ii // 2
                    if data is self.data_pile:
                        i = np.random.randint(0, self.data_pile_size - req_len)
                    else:
                        if args.my_pile_stage == 4 or ii < args.my_random_steps:
                            # warm-up (or stage 4): sample a random spot
                            if args.my_pile_version == 1:
                                i = np.random.randint(0, self.data_size - req_len)
                            else:
                                i = np.random.randint(0, self.data_size)
                        else:
                            # deterministic shuffled walk: factor is the golden
                            # ratio conjugate of magic_prime, and cubing mod a
                            # prime p with p % 3 == 2 permutes [0, p)
                            ii = ii - args.my_random_steps
                            factor = (math.sqrt(5) - 1) / 2
                            factor = int(magic_prime * factor)
                            i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
                            i = i + args.my_pile_shift
                else:
                    # no pile staging: sample a random spot
                    i = np.random.randint(0, self.data_size - req_len)
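                # at this point i is a token offset (binidx v1 / numpy / text)
                # or a global position index into the shard list (binidx v2)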

                if args.data_type == "binidx":
                    if args.my_pile_version == 1:
                        dix = data.get(idx=0, offset=i, length=req_len).astype(int)
                    else:
                        # walk the shard list: entries are
                        # [cumulative cutoff, usable positions, shard]
                        for j in range(len(data)):
                            if i < data[j][0]:
                                ii = i
                                i = (i - (data[j - 1][0] if j > 0 else 0)) % data[j][1]
                                dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
                                break
                elif args.data_type == "numpy":
                    dix = data[i : i + req_len]
                else:
                    dix = [self.stoi[s] for s in data[i : i + req_len]]

                if args.my_qa_mask == 1:
                    # z masks which positions contribute to the loss
                    if data is self.data_pile:
                        z = [1] * ctx_len
                    else:
                        z = [0] * ctx_len
                        z_sum = 0
                        isGood = False
                        for i in range(3, ctx_len):
                            # hard-coded token-id pattern (tokenizer-specific) that
                            # appears to open an answer span; token 0 closes it
                            if dix[i] == 27 and dix[i - 1] == 34 and dix[i - 2] == 187 and dix[i - 3] == 187:
                                isGood = True
                            if dix[i] == 0:
                                isGood = False
                            if isGood:
                                z[i] = 1
                                z_sum += 1
                        if z_sum == 0:
                            # no answer span found: fall back to a pile sample
                            z = [1] * ctx_len
                            i = np.random.randint(0, self.data_pile_size - req_len)
                            dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
                    z = torch.tensor(z, dtype=torch.bfloat16)

                # next-token prediction: y is x shifted one token ahead
                x = torch.tensor(dix[:-1], dtype=torch.long)
                y = torch.tensor(dix[1:], dtype=torch.long)

                if args.my_qa_mask == 1:
                    return x, y, z

            return x, y
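
# A minimal smoke test for the "dummy" branch; the argument values below are
# illustrative, not the real training configuration (the trainer normally
# builds `args` from its CLI and injects rank/epoch/world info afterwards).
# Because of the relative imports above, run this as a module, e.g.
# `python -m src.dataset` if the file lives in a package named `src`.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        data_type="dummy", proj_dir=".",  # vocab.json is written to proj_dir
        ctx_len=128, epoch_steps=4, micro_bsz=2,
        my_pile_stage=0, my_qa_mask=0, magic_prime=0,
    )
    ds = MyDataset(args)
    # normally set by the trainer before sampling:
    ds.global_rank, ds.real_epoch, ds.world_size = 0, 0, 1
    x, y = ds[0]
    print(len(ds), x.shape, y.shape)  # 8 torch.Size([128]) torch.Size([128])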