fix: conflicts

Commit b178c097d8: 90 changed files with 2034 additions and 11145 deletions
src/args.py (Normal file, 51 lines)
@@ -0,0 +1,51 @@
from argparse import ArgumentParser

from src.dataset_loaders import dataset_called


def parse_arguments():
    parser = ArgumentParser(prog="NeuralCompression")
    parser.add_argument("--debug", "-d", action="store_true", required=False,
                        help="Enable debug mode: smaller datasets, more information")
    parser.add_argument("--verbose", "-v", action="store_true", required=False,
                        help="Enable verbose mode")
    parser.add_argument("--results", type=str, required=True, help="Path to save graphs to")
    parser.add_argument("--device", required=False, help="Override the device to use")

    dataparser = ArgumentParser(add_help=False)
    dataparser.add_argument("--data-root", type=str, required=False)
    dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
    dataparser.add_argument("--size", "-s", type=int, required=False,
                            help="Size of the subset of the dataset to use")

    modelparser = ArgumentParser(add_help=False)
    modelparser.add_argument("--model", "-m", type=str, required=False,
                             help="Which model to use")
    modelparser.add_argument("--model-load-path", type=str, required=False,
                             help="Filepath to load the model from")
    modelparser.add_argument("--model-save-path", type=str, required=False,
                             help="Filepath to save the model to")
    modelparser.add_argument("--context", type=int, required=False,
                             help="Context length to use")

    fileparser = ArgumentParser(add_help=False)
    fileparser.add_argument("--input-file", "-i", required=False, type=str)
    fileparser.add_argument("--output-file", "-o", required=False, type=str)

    subparsers = parser.add_subparsers(dest="mode", required=True,
                                       help="Mode to run in")

    train_parser = subparsers.add_parser("train",
                                         parents=[dataparser, modelparser],
                                         help="Do a full training")
    train_parser.add_argument("--method",
                              choices=["fetch", "optuna", "full"], required=True,
                              help="Method to use for training")

    subparsers.add_parser("compress", parents=[modelparser, fileparser],
                          help="Compress a file")

    subparsers.add_parser("decompress", parents=[modelparser, fileparser],
                          help="Decompress a file")

    return parser.parse_args(), parser.print_help
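Usage sketch (not part of this commit; the entry-point wiring below is an assumption based on the flags defined above):

# Sketch only: how parse_arguments() could be consumed by a main script.
from src.args import parse_arguments
from src.utils import determine_device

args, print_help = parse_arguments()
device = args.device or determine_device()

if args.mode == "train":
    print(f"Training {args.model} on {args.dataset} with method {args.method}")
elif args.mode in ("compress", "decompress"):
    print(f"{args.mode}: {args.input_file} -> {args.output_file}")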
src/dataset_loaders/Dataset.py (Normal file, 132 lines)
@@ -0,0 +1,132 @@
from abc import abstractmethod, ABC
from itertools import accumulate
from os.path import join, curdir
from typing import Callable

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm

"""
Author: Tibo De Peuter
"""


class Dataset(TorchDataset, ABC):
    """Abstract base class for datasets."""

    @abstractmethod
    def __init__(self,
                 name: str,
                 root: str | None,
                 split: str = 'train',
                 transform: Callable = None,
                 size: int = -1,
                 context_length: int = 1024
                 ):
        """
        :param root: Path to the dataset root directory
        :param split: The dataset split, e.g. 'train', 'validation', 'test'
        :param size: Override the maximum size of the dataset, useful for debugging
        """
        if root is None:
            root = join(curdir, 'data')

        self._root = join(root, name)
        self.split = split
        self.transform = transform
        self.size = size
        self.context_length = context_length
        self.data = None

        print(f"Context length: {self.context_length}")

        self.chunk_offsets: list[int] = []
        self.bytes: bytes = bytes()
        self.tensor: Tensor = torch.tensor([])

    @property
    def root(self):
        return self._root

    def __len__(self):
        # Number of raw chunks; subclasses override this with byte-level lengths
        return len(self.data)

    def process_data(self):
        self.chunk_offsets = self.get_offsets()
        if self.size == -1:
            # Just use the whole dataset
            self.bytes = ''.join(tqdm(self.data, desc="Encoding data", leave=False)).encode('utf-8', errors='replace')
        else:
            # Use only a partition, calculate offsets
            self.bytes = (''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data", leave=False))
                          .encode('utf-8', errors='replace'))

        bytes_array = np.frombuffer(self.bytes, dtype=np.uint8)  # Zero-copy
        self.tensor = torch.from_numpy(bytes_array).to(torch.long, non_blocking=True)

    def get_offsets(self):
        """
        Calculate for each chunk how many bytes came before it
        """
        data = self.data
        size = self.size

        if size == -1:
            return [0, *accumulate(tqdm(map(len, data), desc="Calculating offsets", leave=False, total=len(data)))]

        offsets = [0]
        total = 0
        append = offsets.append
        for chunk in tqdm(data):
            if total >= size:
                break
            total += len(chunk)
            append(total)
        return offsets

    def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
        item = ''

        # Determine first chunk in which item is located
        chunk_idx = 0
        while idx >= offsets[chunk_idx]:
            chunk_idx += 1
        chunk_idx -= 1

        # Extract item from chunks
        chunk = str(self.data[chunk_idx])
        chunk_start = offsets[chunk_idx]

        chunk_item_start = idx - chunk_start
        item_len_remaining = context_length + 1

        assert len(item) + item_len_remaining == context_length + 1

        while chunk_item_start + item_len_remaining > len(chunk):
            adding_now_len = len(chunk) - chunk_item_start
            item += chunk[chunk_item_start:]

            chunk_idx += 1
            chunk = str(self.data[chunk_idx])

            chunk_item_start = 0
            item_len_remaining -= adding_now_len

            assert len(item) + item_len_remaining == context_length + 1

        item += chunk[chunk_item_start: chunk_item_start + item_len_remaining]

        assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"

        # Transform to tensor (item is already a str)
        data = item.encode('utf-8', errors='replace')
        t = torch.tensor(list(data), dtype=torch.long)
        x, y = t[:-1], t[-1]

        if self.transform:
            x = self.transform(x)

        return x, y
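Illustration (not part of the diff): what get_offsets() computes for a toy chunk list. Each entry is the number of bytes that precede a chunk, so offsets[i] is the flat byte index where chunk i starts.

# Toy example of the offset bookkeeping used by get_chunked_item()
from itertools import accumulate

chunks = ["abc", "de", "fghi"]            # stand-in for self.data
offsets = [0, *accumulate(map(len, chunks))]
assert offsets == [0, 3, 5, 9]            # chunk 1 starts at flat byte 3, chunk 2 at byte 5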
src/dataset_loaders/EnWik9.py (Normal file, 57 lines)
@@ -0,0 +1,57 @@
from math import ceil
from typing import Callable

from datasets import load_dataset, Features, Value

from .Dataset import Dataset


class EnWik9DataSet(Dataset):
    """
    Hugging Face: https://huggingface.co/datasets/haukur/enwik9
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = -1,
                 context_length: int = 1024
                 ):
        super().__init__('enwik9', root, split, transform, size, context_length)

        print("Loading from HuggingFace")
        ft = Features({'text': Value('string')})
        # Don't pass split here, the dataset only contains a training split
        text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
        self.data = text_chunks['text']
        self.size = size

        self.process_data()

        # Define splits manually, because they do not exist in the dataset
        split_point = ceil(self.chunk_offsets[-1] * 0.8)

        if self.split == 'train':
            self.start_byte = 0
            self.end_byte = split_point
        elif self.split == 'validation':
            self.start_byte = split_point
            self.end_byte = self.chunk_offsets[-1]
        else:
            raise ValueError("split must be 'train' or 'validation'")

        print("Done initializing dataset")

    def __len__(self):
        return self.end_byte - self.start_byte - self.context_length

    def __getitem__(self, idx):
        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
        y = self.tensor[self.start_byte + idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
src/dataset_loaders/HumanReferenceGenomeDataset.py (Normal file, 45 lines)
@@ -0,0 +1,45 @@
from typing import Callable

from datasets import load_dataset

from .Dataset import Dataset


class HumanReferenceGenomeDataset(Dataset):
    """
    Hugging Face: https://huggingface.co/datasets/InstaDeepAI/human_reference_genome

    :param split: 'train' | 'validation' | 'test'
    :param config: '6kbp' | '12kbp' (chunk length in the HF builder config)
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = "train",
                 transform: Callable = None,
                 size: int = -1,
                 context_length: int = 1024,
                 config: str = "6kbp",
                 ):
        super().__init__("human_reference_genome", root, split, transform, size, context_length)

        print(f"Loading from HuggingFace (config: {config}, split: {split})")
        data = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
                            cache_dir=self.root, trust_remote_code=True)
        self.data = data["sequence"]

        self.process_data()

        print("Done initializing dataset")

    def __len__(self):
        return self.chunk_offsets[-1] - self.context_length

    def __getitem__(self, idx):
        x = self.tensor[idx: idx + self.context_length]
        y = self.tensor[idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
src/dataset_loaders/LoremIpsumDataset.py (Normal file, 62 lines)
@@ -0,0 +1,62 @@
from math import ceil
from typing import Callable

from lorem.text import TextLorem
from tqdm import tqdm

from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable = None,
                 size: int = 2**30,
                 context_length: int = 1024
                 ):
        super().__init__('lorem_ipsum', root, split, transform, size, context_length)

        _lorem = TextLorem()

        self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
        self.size = size

        self.process_data()

        split_point = ceil(self.chunk_offsets[-1] * 0.8)

        if self.split == 'train':
            self.start_byte = 0
            self.end_byte = split_point
        elif self.split == 'validation':
            self.start_byte = split_point
            self.end_byte = self.chunk_offsets[-1]
        else:
            raise ValueError("split must be 'train' or 'validation'")

        print("Done initializing dataset")

    def __len__(self):
        return self.end_byte - self.start_byte - self.context_length

    def __getitem__(self, idx):
        # Get sequence of characters
        # x_str = self.text[idx: idx + self.context_length]
        # y_char = self.text[idx + self.context_length]
        #
        # # Convert to tensors
        # x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
        # y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
        #
        # if self.transform is not None:
        #     x = self.transform(x)
        #
        # return x, y
        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
        y = self.tensor[self.start_byte + idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
src/dataset_loaders/OpenGenomeDataset.py (Normal file, 49 lines)
@@ -0,0 +1,49 @@
from typing import Callable

from datasets import load_dataset, Value, Features

from .Dataset import Dataset


class OpenGenomeDataset(Dataset):
    """
    Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome

    :param split: Either 'train', 'test' or 'validation'
    :param stage: Either 'sample', 'stage1' or 'stage2'.
        'sample' only provides a 'validation' split
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable = None,
                 size: int = -1,
                 context_length: int = 1024,
                 stage: str = 'stage2'
                 ):
        super().__init__('open_genome', root, split, transform, size, context_length)

        print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
        ft = Features({'text': Value('string')})
        data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
        self.data = data['text']
        self.size = size

        self.process_data()

        print("Done initializing dataset")

    def __len__(self):
        # return len(self.data) - self.context_length
        return self.chunk_offsets[-1] - self.context_length

    def __getitem__(self, idx):
        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
        x = self.tensor[idx:idx + self.context_length]
        y = self.tensor[idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
src/dataset_loaders/__init__.py (Normal file, 12 lines)
@@ -0,0 +1,12 @@
from .Dataset import Dataset
from .EnWik9 import EnWik9DataSet
from .HumanReferenceGenomeDataset import HumanReferenceGenomeDataset
from .LoremIpsumDataset import LoremIpsumDataset
from .OpenGenomeDataset import OpenGenomeDataset

dataset_called: dict[str, type[Dataset]] = {
    'enwik9': EnWik9DataSet,
    'lorem_ipsum': LoremIpsumDataset,
    'opengenome': OpenGenomeDataset,
    'humanreference': HumanReferenceGenomeDataset
}
src/models/Model.py (Normal file, 14 lines)
@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod

from torch import nn


class Model(nn.Module, ABC):
    @abstractmethod
    def __init__(self, loss_function=None):
        super().__init__()
        self._loss_function = loss_function

    @property
    def loss_function(self):
        return self._loss_function
src/models/__init__.py (Normal file, 8 lines)
@@ -0,0 +1,8 @@
from .Model import Model
from .cnn import CNNPredictor
from .transformer import ByteTransformer

model_called: dict[str, type[Model]] = {
    'cnn': CNNPredictor,
    'transformer': ByteTransformer
}
src/models/autoencoder.py (Normal file, 18 lines)
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Encoder, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass
src/models/cnn/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
from .cnn import CNNPredictor
src/models/cnn/cnn.py (Normal file, 54 lines)
@@ -0,0 +1,54 @@
import torch.nn as nn

from src.models import Model


class CNNPredictor(Model):
    def __init__(
            self,
            vocab_size=256,
            embed_dim=64,
            hidden_dim=128,
    ):
        super().__init__(nn.CrossEntropyLoss())

        # 1. Embedding: maps bytes (0–255) → vectors
        self.embed = nn.Embedding(vocab_size, embed_dim)

        # 2. Convolutional feature extractor
        self.conv_layers = nn.Sequential(
            nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # 3. Global pooling to collapse sequence length
        self.pool = nn.AdaptiveAvgPool1d(1)  # → (B, hidden_channels, 1)

        # 4. Final classifier
        self.fc = nn.Linear(hidden_dim, vocab_size)  # → (B, 256)

    def forward(self, x):
        """
        x: LongTensor of shape (B, 128), values 0-255
        """
        # embed: (B, 128, embed_dim)
        x = self.embed(x)

        # conv1d expects (B, C_in, L) → swap dims
        x = x.transpose(1, 2)  # (B, embed_dim, 128)

        # apply CNN
        x = self.conv_layers(x)  # (B, hidden_channels, 128)

        # global average pooling over sequence
        x = self.pool(x).squeeze(-1)  # (B, hidden_channels)

        # final classifier
        logits = self.fc(x)  # (B, 256)
        return logits
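Shape check (illustration, not part of the commit): the predictor maps a batch of byte IDs to one 256-way logit vector per sample, matching the (x, y) pairs produced by the dataset loaders.

# Minimal sketch of the expected input/output shapes
import torch
from src.models.cnn import CNNPredictor

model = CNNPredictor(vocab_size=256, embed_dim=64, hidden_dim=128)
x = torch.randint(0, 256, (4, 128))   # (B, context_length) byte IDs
logits = model(x)
assert logits.shape == (4, 256)       # one next-byte distribution per sample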
src/models/transformer/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
from .transformer import ByteTransformer
src/models/transformer/transformer.py (Normal file, 70 lines)
@@ -0,0 +1,70 @@
from typing import Optional

import torch.nn as nn
from torch import Tensor, arange

from src.models import Model


class LearnedPositionalEncoding(Model):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # x: [seq, batch, d_model]
        seq_len = x.size(0)
        positions = arange(seq_len, device=x.device).unsqueeze(1)  # [seq, 1]
        return x + self.pos_emb(positions)  # broadcast over batch


class ByteTransformer(nn.Module):
    def __init__(
            self,
            d_model=512,
            nhead=8,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=2048,
            dropout=0.1,
            activation="relu",
            layer_norm_eps=1e-05,
            max_len=128
    ):
        super().__init__()
        self.src_embedding = nn.Embedding(256, d_model)
        self.tgt_embedding = nn.Embedding(256, d_model)

        self.src_pos = LearnedPositionalEncoding(max_len, d_model)
        self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=False,
            norm_first=False,
            device=None,
            dtype=None,
        )

        self.output_proj = nn.Linear(d_model, 256)

        self.loss_function = nn.CrossEntropyLoss()

    def forward(
            self,
            src: Tensor,
            tgt: Tensor,
    ) -> Tensor:
        src_embeds = self.src_embedding(src)
        tgt_embeds = self.tgt_embedding(tgt)

        src_pos = self.src_pos(src_embeds)
        tgt_pos = self.tgt_pos(tgt_embeds)

        return self.output_proj(self.transformer(src_pos, tgt_pos))
src/process.py (Normal file, 30 lines)
@@ -0,0 +1,30 @@
import torch


def compress(
        device,
        model_path: str,
        output_file: str,
        input_file: str | None = None
):
    # Get input to compress
    if input_file:
        with open(input_file, "rb") as file:
            byte_data = file.read()
    else:
        # Read from stdin
        text = input()
        byte_data = text.encode('utf-8', errors='replace')

    tensor = torch.tensor(list(byte_data), dtype=torch.long)
    print(tensor)

    # Get model
    model = torch.load(model_path, weights_only=False)

    # TODO Feed to model for compression, store result
    return


def decompress():
    raise NotImplementedError("Decompression is not implemented yet")
src/train.py (Normal file, 79 lines)
@@ -0,0 +1,79 @@
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from src.dataset_loaders import dataset_called
from src.models import model_called
from src.trainers import OptunaTrainer, Trainer, FullTrainer


def train(
        device,
        dataset: str,
        data_root: str,
        n_trials: int | None = None,
        size: int | None = None,
        context_length: int | None = None,
        method: str = 'optuna',
        model_name: str | None = None,
        model_path: str | None = None,
        model_out: str | None = None,
        results_dir: str = 'results'
):
    batch_size = 64

    assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"

    if model_name:
        print("Creating model")
        model = model_called[model_name]
    else:
        print("Loading model from disk")
        model = torch.load(model_path, weights_only=False)

    dataset_common_args = {
        'root': data_root,
        'transform': lambda x: x.to(device),
    }

    if size:
        dataset_common_args['size'] = size

    if context_length:
        dataset_common_args['context_length'] = context_length

    print("Loading in the dataset...")
    if dataset in dataset_called:
        training_set = dataset_called[dataset](split='train', **dataset_common_args)
        validate_set = dataset_called[dataset](split='validation', **dataset_common_args)
    else:
        # TODO Allow importing arbitrary files
        raise NotImplementedError("Importing external datasets is not implemented yet")

    if method == 'fetch':
        # TODO Move this earlier in the chain, because now everything is converted into tensors as well?
        exit(0)

    print(f"Training set size = {len(training_set)}, validation set size = {len(validate_set)}")
    training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)

    trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)

    print("Training")
    best_model = trainer.execute(
        model=model,
        train_loader=training_loader,
        validation_loader=validation_loader,
        n_epochs=n_trials,
        device=device
    )

    print("Saving model...")
    f = model_out or f"saved_models/{model.__class__.__name__}.pt"
    # Make sure the path exists
    Path(f).parent.mkdir(parents=True, exist_ok=True)
    torch.save(best_model, f)
    print(f"Saved model to '{f}'")
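Sketch (not part of the diff): a plausible call into train() given the CLI flags in src/args.py. The dataset key, paths, and sizes below are assumptions for a quick smoke test; it follows the Optuna path, which expects `model` to be a class from model_called.

# Hypothetical smoke-test invocation, assumptions marked in comments
from src.train import train
from src.utils import determine_device

train(
    device=determine_device(),
    dataset="lorem_ipsum",            # key from dataset_called (assumed choice)
    data_root="./data",               # assumed local data directory
    n_trials=1,
    size=10_000,                      # small subset for a quick run
    context_length=128,
    method="optuna",
    model_name="cnn",                 # key from model_called
    model_out="saved_models/cnn.pt",  # assumed output path
    results_dir="results",
)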
src/trainers/FullTrainer.py (Normal file, 29 lines)
@@ -0,0 +1,29 @@
from torch import nn
from torch.utils.data import DataLoader

from .train import train
from .trainer import Trainer
from ..models import Model
from ..utils import print_losses


class FullTrainer(Trainer):
    def __init__(self, results_dir: str = 'results'):
        self.results_dir = results_dir

    def execute(
            self,
            model: Model,
            train_loader: DataLoader,
            validation_loader: DataLoader,
            n_epochs: int | None,
            device: str
    ) -> nn.Module:
        if model is None:
            raise ValueError("Model must be provided: run optuna optimizations first")

        model.to(device)
        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,
                                     device=device)
        print_losses(train_loss, val_loss, filename=f"{self.results_dir}/{model.__class__.__name__}-losses.png")

        return model
src/trainers/OptunaTrainer.py (Normal file, 72 lines)
@@ -0,0 +1,72 @@
import optuna
import optuna.trial as tr
from torch import nn
from torch.utils.data import DataLoader

from .train import train
from .trainer import Trainer
from ..models import Model, CNNPredictor, ByteTransformer


def create_model(trial: tr.Trial, model: type[Model]):
    # `model` is the model class selected in src/train.py (e.g. CNNPredictor), not an instance
    if model is CNNPredictor:
        return model(
            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
            vocab_size=256,
        )
    if model is ByteTransformer:
        nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
        # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
        return model(
            d_model=128,  # hard-coded for now as the data loaders provide fixed (B, 128) tensors
            nhead=nhead,
            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
            dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
            dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
        )
    return None


def objective_function(
        trial: tr.Trial,
        training_loader: DataLoader,
        validation_loader: DataLoader,
        model: Model,
        device: str
):
    model = create_model(trial, model).to(device)
    _, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
    return min(validation_loss)


class OptunaTrainer(Trainer):
    def __init__(self, n_trials: int | None = None):
        super().__init__()
        self.n_trials = n_trials if n_trials else 20
        print(f"Creating Optuna trainer (n_trials = {self.n_trials})")

    def execute(
            self,
            model: Model,
            train_loader: DataLoader,
            validation_loader: DataLoader,
            n_epochs: int,
            device: str
    ) -> nn.Module:
        study = optuna.create_study(direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
            n_trials=self.n_trials
        )

        best_params = study.best_trial.params
        best_model = model(
            **best_params
        )

        return best_model
src/trainers/__init__.py (Normal file, 3 lines)
@@ -0,0 +1,3 @@
from .OptunaTrainer import OptunaTrainer
from .FullTrainer import FullTrainer
from .trainer import Trainer
src/trainers/train.py (Normal file, 81 lines)
@@ -0,0 +1,81 @@
from typing import Callable

import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from ..models import ByteTransformer, Model


def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
    if isinstance(model, ByteTransformer):
        tgt_in = torch.cat([
            torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
            x[:, :-1]
        ], dim=1)
        logits = model(x, tgt_in)

        # only consider the last time step of the model where the full context
        # is available
        return logits[:, -1, :]
    return model(x)


def train(
        model: Model,
        training_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable,
        epochs: int | None = None,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-8,
        device="cuda"
) -> tuple[list[float], list[float]]:
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    avg_training_losses = []
    avg_validation_losses = []

    if epochs is None:
        epochs = 100

    for epoch in range(epochs):

        model.train()
        total_loss = []

        for x, y in tqdm(training_loader):
            # size (B, 128)
            x = x.long().to(device)

            # size (B)
            y = y.long().to(device)

            optimizer.zero_grad()
            logits = _forward(model, x, device)

            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()

            total_loss.append(loss.item())

        avg_training_losses.append(sum(total_loss) / len(total_loss))

        # ----- validation -----
        model.eval()
        with torch.no_grad():
            losses = []
            for x, y in validation_loader:
                x = x.long().to(device)
                y = y.long().to(device)

                logits = _forward(model, x, device)
                loss = loss_fn(logits, y)
                losses.append(loss.item())

            avg_loss = sum(losses) / len(losses)
            avg_validation_losses.append(avg_loss)

    return avg_training_losses, avg_validation_losses
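Illustration (not part of the diff) of the decoder-input shift used in _forward for ByteTransformer: the decoder input is the source sequence shifted right by one position, with a zero byte acting as the start token.

# Toy example of the teacher-forcing shift
import torch

x = torch.tensor([[10, 11, 12, 13]])                                        # (B=1, context=4)
tgt_in = torch.cat([torch.zeros(1, 1, dtype=torch.long), x[:, :-1]], dim=1)
assert tgt_in.tolist() == [[0, 10, 11, 12]]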
src/trainers/trainer.py (Normal file, 19 lines)
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod

import torch.nn as nn
from torch.utils.data import DataLoader


class Trainer(ABC):
    """Abstract class for trainers."""

    @abstractmethod
    def execute(
            self,
            model: nn.Module | None,
            train_loader: DataLoader,
            validation_loader: DataLoader,
            n_epochs: int | None,
            device: str
    ) -> nn.Module:
        pass
src/utils/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
from .utils import *
src/utils/benchmark.py (Normal file, 175 lines)
@@ -0,0 +1,175 @@
"""Utility functions for benchmarking."""
import json
import string
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from os import getpid, path
from pathlib import Path
from random import choices
from subprocess import DEVNULL, PIPE, CalledProcessError, TimeoutExpired, run
from timeit import timeit
from typing import Callable

from memray import Tracker

from ..utils.benchmark_dataclasses import BenchmarkItem, BenchmarkResult

log = getLogger(__name__)


def get_commit_hash() -> str:
    """
    Get the commit hash of the current git repository.

    If not working in a git repository, return a random string that looks like a commit hash.
    """
    try:
        return run(
            ["git", "rev-parse", "--short", "HEAD"],
            check=True,
            stdout=PIPE,
            stderr=DEVNULL,
            text=True,
        ).stdout.strip()
    except CalledProcessError as e:
        log.error(
            "Could not determine the commit hash. Are you using a git repository?:\n%s",
            e,
        )
        log.error("Using a random string as commit hash.")
        return "".join(choices(string.hexdigits[:-6], k=40))


def init_stat_file(stat_file: Path, header: str) -> int:
    """Initialize a statistics file with a header."""
    # Check if the parent directory exists
    stat_file.parent.mkdir(parents=True, exist_ok=True)

    # Check if the file exists
    if stat_file.exists():
        # Nothing left to do
        return 0

    # Initialize the file by writing the header to it.
    log.debug("Initializing statistics file %s", stat_file)
    stat_file.touch()
    stat_file.write_text(f"{header}\n", encoding="utf-8")
    return 1


def track_time_memory(task: Callable, result: BenchmarkResult, mem_file: Path, mem_json_file: Path):
    """Track the time and memory consumption of a task."""

    def task_with_result():
        result.value = task()

    # Measure memory consumption
    with Tracker(file_name=mem_file, native_traces=True, follow_fork=True, memory_interval_ms=1):
        try:
            # Measure runtime
            result.runtime = timeit(task_with_result, number=1, globals=globals())
        except BaseException as e:
            log.error("Error while timing the program:\n%s", e, exc_info=True)
            return None

    # Convert binary memory file into JSON.
    try:
        run(
            [
                "python",
                "-m",
                "memray",
                "stats",
                "--json",
                "--num-largest",
                "1",
                "--output",
                mem_json_file,
                mem_file,
            ],
            check=True,
            timeout=100,
            stdout=DEVNULL,
        )
        # Parse JSON to get peak_memory
        mem_results = json.loads(mem_json_file.read_text(encoding="utf-8"))
        result.peak_memory = mem_results["metadata"]["peak_memory"]

    except CalledProcessError as e:
        log.error(
            "Something went wrong while processing the memray memory file %s:\n%s",
            mem_file,
            e,
        )
    except TimeoutExpired as e:
        log.error(
            "Timeout expired while processing the memray memory file %s:\n%s",
            mem_file,
            e,
        )

    return result


def execute_benchmark(
        benchmark_item: BenchmarkItem,
        results_dir: str | Path,
        timeout: int = 100,
) -> BenchmarkResult:
    """Execute a benchmark and track its runtime and peak memory consumption."""
    mem_file = Path(path.join(results_dir, f"memray-{benchmark_item.task.__name__}.mem"))
    mem_json_file = Path(path.join(results_dir, f"memray-{benchmark_item.task.__name__}.json"))

    result = BenchmarkResult(benchmark_item)

    try:
        # Time and track memory usage
        # Kill after timeout in seconds
        with ThreadPoolExecutor() as executor:
            future = executor.submit(
                lambda: track_time_memory(
                    lambda: benchmark_item.task(**benchmark_item.arguments), result, mem_file, mem_json_file
                )
            )
            executed_result = future.result(timeout=timeout)

        if executed_result is not None:
            result = executed_result

        log.info(
            "PID %d: %s finished [%.6f seconds, %d bytes]",
            getpid(),
            benchmark_item.get_method(),
            result.runtime,
            result.peak_memory,
        )
    except TimeoutError:
        log.error("Timeout expired while running the benchmark_suite, cleaning up now.")

        log.info(
            "PID %d: %s failed after timeout (%d seconds)",
            getpid(),
            benchmark_item.get_method(),
            timeout,
        )
    finally:
        # Clean up memory dump file to save disk space.
        mem_file.unlink()

    return result


if __name__ == "__main__":
    import hydra

    # Dummy example, read the contents of the dataset
    def _read_contents(filename):
        with open(filename, encoding="utf-8") as f:
            log.info("Dataset content: %s", f.read())

    def _read_contents_wrapper(cfg):
        return _read_contents(cfg.dataset.path)

    hydra_wrapped = hydra.main(config_path="../../config", config_name="config", version_base="1.2")(
        _read_contents_wrapper
    )()
src/utils/benchmark_dataclasses.py (Normal file, 79 lines)
@@ -0,0 +1,79 @@
"""
Benchmark data classes.

This module contains the BenchmarkResult class which is used to store and print the results of a
benchmark_suite.
"""

from dataclasses import dataclass
from typing import Any, Callable


@dataclass(init=True)
class BenchmarkItem:
    """A class used to represent a benchmark_suite (iteration)."""

    task: Callable
    arguments: dict

    def __str__(self) -> str:
        """String representation of the BenchmarkItem object."""
        return self.get_in_data_format()

    def get_method(self) -> str:
        """
        Format the method as if it were a function call.
        """
        method_name = self.task.__name__
        arguments = ", ".join(
            f'{key}={str(value)[:15]}'
            for key, value in self.arguments.items()
        )
        return f"{method_name}({arguments})"

    def get_in_data_format(self) -> str:
        """
        Format the benchmark_suite item to be printed to a .dat file.
        """
        # Flatten out arguments
        values = list(self.__dict__.values())
        values[1:2] = values[1].values()

        return " ".join(map(str, values))

    def get_header(self) -> str:
        """
        Returns the header which is just the names of the fields separated by spaces.
        """
        return " ".join(self.__dict__.keys())


@dataclass(init=True)
class BenchmarkResult:
    """A class used to represent the result of a benchmark_suite."""

    benchmark_item: BenchmarkItem
    runtime: float = 0
    peak_memory: int = 0
    value: Any = None

    def __str__(self) -> str:
        """String representation of the BenchmarkResult object."""
        return self.get_in_data_format()

    def get_in_data_format(self) -> str:
        """
        Format the benchmark_suite result to be printed to a .dat file.
        """
        return " ".join(map(str, self.__dict__.values()))

    def get_header(self) -> str:
        """
        Returns the header which is just the names of the fields separated by spaces.
        """
        # Get header of the BenchmarkItem
        keys = list(self.__annotations__.keys())
        keys[0:1] = self.benchmark_item.__annotations__.keys()
        keys[1:2] = self.benchmark_item.arguments.keys()

        return " ".join(keys)
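Illustration (not part of the diff): how BenchmarkItem and BenchmarkResult serialize to the space-separated .dat format. The task and argument values are made-up placeholders.

# Hypothetical benchmark item built around time.sleep
from time import sleep
from src.utils.benchmark_dataclasses import BenchmarkItem, BenchmarkResult

item = BenchmarkItem(task=sleep, arguments={"secs": 0.1})
result = BenchmarkResult(item, runtime=0.1003, peak_memory=2048)

print(item.get_method())    # sleep(secs=0.1)
print(result.get_header())  # task secs runtime peak_memory value
print(result)               # <flattened item fields> 0.1003 2048 None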
src/utils/utils.py (Normal file, 64 lines)
@@ -0,0 +1,64 @@
import csv
from os import path

import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset


def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
    data = torch.tensor(list(data), dtype=torch.long)
    sample_count = data.shape[0] - context_length
    x = data.unfold(0, context_length, 1)[:sample_count]
    y = data[context_length:]
    return TensorDataset(x, y)


def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
    plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
    plt.show()


def print_losses(train_losses: list[float], validation_losses: list[float], filename: str | None = None, show=False):
    plt.plot(train_losses, label="Training loss")
    plt.plot(validation_losses, label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss (cross entropy)")
    plt.legend()

    if show:
        plt.show()

    if filename is None:
        filename = path.join("results", "losses.png")

    print(f"Saving losses to {filename}...")
    plt.savefig(filename)

    # Also write to CSV file
    with open(filename.replace(".png", ".csv"), "w") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "train_loss", "validation_loss"])
        for i in range(len(train_losses)):
            writer.writerow([i, train_losses[i], validation_losses[i]])

    print("Done")


def determine_device():
    # NVIDIA GPUs (most HPC clusters)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Apple Silicon (macOS)
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    # Intel GPUs (oneAPI)
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    else:
        return torch.device("cpu")


def load_data(path: str) -> bytes:
    with open(path, "rb") as f:
        return f.read()
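Illustration (not part of the diff): make_context_pairs slides a window of context_length bytes over the data, and each window predicts the byte that follows it.

# Toy example of the sliding-window (x, y) pairs
from src.utils import make_context_pairs

pairs = make_context_pairs(b"abcdef", context_length=3)
x0, y0 = pairs[0]
assert x0.tolist() == [97, 98, 99]   # "abc"
assert y0.item() == 100              # "d"
assert len(pairs) == 3               # "abc"->d, "bcd"->e, "cde"->f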