fix: Avoid copy list conversion
This commit is contained in:
parent
f97c7c9130
commit
209bf2403d
1 changed files with 3 additions and 1 deletions
|
|
@ -2,6 +2,7 @@ from abc import abstractmethod, ABC
|
||||||
from os.path import join, curdir
|
from os.path import join, curdir
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data import Dataset as TorchDataset
|
from torch.utils.data import Dataset as TorchDataset
|
||||||
|
|
@ -57,7 +58,8 @@ class Dataset(TorchDataset, ABC):
|
||||||
self.chunk_offsets = self.get_offsets()
|
self.chunk_offsets = self.get_offsets()
|
||||||
self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
|
self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
|
||||||
|
|
||||||
self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)
|
bytes_array = np.frombuffer(self.bytes, dtype=np.uint8) # Zero-copy
|
||||||
|
self.tensor = torch.from_numpy(bytes_array).to(torch.long)
|
||||||
|
|
||||||
def get_offsets(self):
|
def get_offsets(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Reference in a new issue