chore: More verbose dataloading
This commit is contained in:
parent
2de9e87470
commit
20cbd61a82
1 changed file with 19 additions and 8 deletions
|
|
@ -1,4 +1,5 @@
|
||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
|
from itertools import accumulate
|
||||||
from os.path import join, curdir
|
from os.path import join, curdir
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
|
@ -53,23 +54,33 @@ class Dataset(TorchDataset, ABC):
|
||||||
self.chunk_offsets = self.get_offsets()
|
self.chunk_offsets = self.get_offsets()
|
||||||
if self.size == -1:
|
if self.size == -1:
|
||||||
# Just use the whole dataset
|
# Just use the whole dataset
|
||||||
self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
|
self.bytes = ''.join(tqdm(self.data, desc="Encoding data", leave=False)).encode('utf-8', errors='replace')
|
||||||
else:
|
else:
|
||||||
# Use only partition, calculate offsets
|
# Use only partition, calculate offsets
|
||||||
self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
|
self.bytes = (''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data", leave=False))
|
||||||
|
.encode('utf-8', errors='replace'))
|
||||||
|
|
||||||
bytes_array = np.frombuffer(self.bytes, dtype=np.uint8) # Zero-copy
|
bytes_array = np.frombuffer(self.bytes, dtype=np.uint8) # Zero-copy
|
||||||
self.tensor = torch.from_numpy(bytes_array).to(torch.long)
|
self.tensor = torch.from_numpy(bytes_array).to(torch.long, non_blocking=True)
|
||||||
|
|
||||||
def get_offsets(self):
    """
    Calculate for each chunk how many bytes came before it.

    Returns a list of cumulative byte offsets starting with 0, so that
    offsets[i] is the number of bytes preceding chunk i.  When self.size
    is -1 the whole dataset is covered; otherwise offsets accumulate only
    until the running total reaches self.size (the last appended offset
    may exceed self.size, matching the previous while-loop behavior).
    """
    data = self.data
    size = self.size

    if size == -1:
        # Whole dataset: one C-speed cumulative sum over every chunk length.
        return [0, *accumulate(tqdm(map(len, data), desc="Calculating offsets", leave=False, total=len(data)))]

    offsets = [0]
    total = 0
    append = offsets.append  # hoist the bound-method lookup out of the hot loop
    # desc/leave=False/total added for consistency with the size == -1 branch
    # above (and the other progress bars in this file) — otherwise this path
    # shows an unlabeled bar that lingers after completion.
    for chunk in tqdm(data, desc="Calculating offsets", leave=False, total=len(data)):
        if total >= size:
            # Enough bytes gathered to cover the requested partition.
            break
        total += len(chunk)
        append(total)
    return offsets
|
||||||
|
|
||||||
def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
|
def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
|
||||||
|
|
|
||||||
Reference in a new issue