chore: More verbose dataloading

Tibo De Peuter 2025-12-10 00:02:10 +01:00
parent 2de9e87470
commit 20cbd61a82
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2


@@ -1,4 +1,5 @@
 from abc import abstractmethod, ABC
+from itertools import accumulate
 from os.path import join, curdir
 from typing import Callable
@@ -53,23 +54,33 @@ class Dataset(TorchDataset, ABC):
         self.chunk_offsets = self.get_offsets()
         if self.size == -1:
             # Just use the whole dataset
-            self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
+            self.bytes = ''.join(tqdm(self.data, desc="Encoding data", leave=False)).encode('utf-8', errors='replace')
         else:
             # Use only partition, calculate offsets
-            self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
+            self.bytes = (''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data", leave=False))
+                          .encode('utf-8', errors='replace'))
         bytes_array = np.frombuffer(self.bytes, dtype=np.uint8)  # Zero-copy
-        self.tensor = torch.from_numpy(bytes_array).to(torch.long)
+        self.tensor = torch.from_numpy(bytes_array).to(torch.long, non_blocking=True)

     def get_offsets(self):
         """
         Calculate for each chunk how many bytes came before it
         """
+        data = self.data
+        size = self.size
+        if size == -1:
+            return [0, *accumulate(tqdm(map(len, data), desc="Calculating offsets", leave=False, total=len(data)))]
         offsets = [0]
-        while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size):
-            idx = len(offsets) - 1
-            offsets.append(offsets[idx] + len(self.data[idx]))
-        print(offsets)
+        total = 0
+        append = offsets.append
+        for chunk in tqdm(data):
+            if total >= size:
+                break
+            total += len(chunk)
+            append(total)
         return offsets

     def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
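
For reference, the fast path added to get_offsets is a plain prefix sum over chunk lengths: each offset records how much data precedes a chunk. A minimal sketch of the idiom, with hypothetical chunks standing in for self.data:

    from itertools import accumulate

    # Hypothetical stand-in for self.data; each offset is a prefix sum
    # of chunk lengths, i.e. how much data precedes that chunk.
    chunks = ["hello", "wide ", "world"]
    offsets = [0, *accumulate(map(len, chunks))]
    print(offsets)  # [0, 5, 10, 15]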
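
The byte-to-tensor path is unchanged in spirit: the encoded string is viewed as a uint8 array without copying, then widened to long. A sketch under the same assumptions (PyTorch may warn that the buffer from np.frombuffer is read-only; the subsequent .to(torch.long) copies anyway):

    import numpy as np
    import torch

    data = "some text".encode("utf-8", errors="replace")
    bytes_array = np.frombuffer(data, dtype=np.uint8)  # zero-copy view over the bytes
    tensor = torch.from_numpy(bytes_array).to(torch.long)  # one long value per byte
    print(tensor.shape)  # torch.Size([9])

Note that for a CPU-resident tensor the new non_blocking=True flag is effectively a no-op: it only affects copies from pinned host memory to a device.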