From 20cbd61a82137ab52474f9d6d298f63622b983a0 Mon Sep 17 00:00:00 2001
From: Tibo De Peuter
Date: Wed, 10 Dec 2025 00:02:10 +0100
Subject: [PATCH] chore: More verbose dataloading

---
 src/dataset_loaders/Dataset.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/src/dataset_loaders/Dataset.py b/src/dataset_loaders/Dataset.py
index e37f0cc..f3c5786 100644
--- a/src/dataset_loaders/Dataset.py
+++ b/src/dataset_loaders/Dataset.py
@@ -1,4 +1,5 @@
 from abc import abstractmethod, ABC
+from itertools import accumulate
 from os.path import join, curdir
 from typing import Callable
 
@@ -53,23 +54,33 @@ class Dataset(TorchDataset, ABC):
         self.chunk_offsets = self.get_offsets()
 
         if self.size == -1:
             # Just use the whole dataset
-            self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
+            self.bytes = ''.join(tqdm(self.data, desc="Encoding data", leave=False)).encode('utf-8', errors='replace')
         else:
             # Use only partition, calculate offsets
-            self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
+            self.bytes = (''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data", leave=False))
+                          .encode('utf-8', errors='replace'))
 
-        bytes_array = np.frombuffer(self.bytes, dtype=np.uint8) # Zero-copy
-        self.tensor = torch.from_numpy(bytes_array).to(torch.long)
+        bytes_array = np.frombuffer(self.bytes, dtype=np.uint8)  # Zero-copy
+        self.tensor = torch.from_numpy(bytes_array).to(torch.long, non_blocking=True)
 
     def get_offsets(self):
         """ Calculate for each chunk how many bytes came before it """
+        data = self.data
+        size = self.size
+
+        if size == -1:
+            return [0, *accumulate(tqdm(map(len, data), desc="Calculating offsets", leave=False, total=len(data)))]
+
         offsets = [0]
-        while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size):
-            idx = len(offsets) - 1
-            offsets.append(offsets[idx] + len(self.data[idx]))
-        print(offsets)
+        total = 0
+        append = offsets.append
+        for chunk in tqdm(data):
+            if total >= size:
+                break
+            total += len(chunk)
+            append(total)
         return offsets
 
     def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):