fix: Avoid copy list conversion

This commit is contained in:
Tibo De Peuter 2025-12-09 22:50:59 +01:00
parent f97c7c9130
commit 209bf2403d
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2

View file

@ -2,6 +2,7 @@ from abc import abstractmethod, ABC
from os.path import join, curdir
from typing import Callable
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
@ -57,7 +58,8 @@ class Dataset(TorchDataset, ABC):
self.chunk_offsets = self.get_offsets()
self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)
bytes_array = np.frombuffer(self.bytes, dtype=np.uint8) # Zero-copy
self.tensor = torch.from_numpy(bytes_array).to(torch.long)
def get_offsets(self):
"""