fix: Avoid copying bytes into a list when building the tensor

This commit is contained in:
Tibo De Peuter 2025-12-09 22:50:59 +01:00
parent f97c7c9130
commit 209bf2403d
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2

View file

@ -2,6 +2,7 @@ from abc import abstractmethod, ABC
from os.path import join, curdir from os.path import join, curdir
from typing import Callable from typing import Callable
import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset as TorchDataset from torch.utils.data import Dataset as TorchDataset
@ -57,7 +58,8 @@ class Dataset(TorchDataset, ABC):
self.chunk_offsets = self.get_offsets() self.chunk_offsets = self.get_offsets()
self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace') self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
self.tensor = torch.tensor(list(self.bytes), dtype=torch.long) bytes_array = np.frombuffer(self.bytes, dtype=np.uint8) # Zero-copy
self.tensor = torch.from_numpy(bytes_array).to(torch.long)
def get_offsets(self): def get_offsets(self):
""" """