feat: i think i set up the encoder
This commit is contained in:
parent
947aba31ee
commit
63d1b6f5ae
2 changed files with 51 additions and 6 deletions
|
|
@ -3,7 +3,7 @@ from torch.utils.data import TensorDataset
|
||||||
|
|
||||||
|
|
||||||
def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
|
def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
|
||||||
data = torch.tensor(data, dtype=torch.uint8)
|
data = torch.tensor(list(data), dtype=torch.uint8)
|
||||||
sample_count = data.shape[0] - context_length
|
sample_count = data.shape[0] - context_length
|
||||||
x = data.unfold(0, context_length, 1)[:sample_count]
|
x = data.unfold(0, context_length, 1)[:sample_count]
|
||||||
y = data[context_length:]
|
y = data[context_length:]
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.functional as F
|
import torch.nn.functional as F
|
||||||
import optuna.trial as tr
|
import optuna.trial as tr
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from optuna_trial import create_model
|
from optuna_trial import create_model
|
||||||
from data_utils import make_context_pairs
|
from data_utils import make_context_pairs
|
||||||
|
import optuna
|
||||||
|
|
||||||
# hyper parameters
|
# hyper parameters
|
||||||
context_length = 128
|
context_length = 128
|
||||||
|
|
@ -18,18 +20,61 @@ def train_and_eval(
|
||||||
epochs: int = 100,
|
epochs: int = 100,
|
||||||
learning_rate: float = 1e-3,
|
learning_rate: float = 1e-3,
|
||||||
device: torch.device = torch.device("cpu")
|
device: torch.device = torch.device("cpu")
|
||||||
):
|
) -> dict:
|
||||||
model.to(device)
|
model.to(device)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
||||||
training_loader = DataLoader(make_context_pairs(training_data, context_length=context_length))
|
training_loader = DataLoader(make_context_pairs(training_data, context_length=context_length), batch_size=batch_size)
|
||||||
validation_loader= DataLoader(make_context_pairs(validation_data, context_length=context_length))
|
validation_loader= DataLoader(make_context_pairs(validation_data, context_length=context_length), batch_size=batch_size)
|
||||||
|
|
||||||
|
training_losses = []
|
||||||
|
validation_losses = []
|
||||||
|
best_val_loss = float("inf")
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
model.train()
|
model.train()
|
||||||
|
train_loss = 0
|
||||||
|
for x, y in tqdm(training_loader, desc=f"Epoch {epoch}"):
|
||||||
|
x, y = x.to(device), y.to(device)
|
||||||
|
prediction = model(x)
|
||||||
|
loss = F.cross_entropy(prediction, y)
|
||||||
|
train_loss += loss.item()
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
training_losses.append(train_loss / len(training_loader))
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
val_loss = 0
|
||||||
|
for x, y in validation_loader:
|
||||||
|
x, y = x.to(device), y.to(device)
|
||||||
|
prediction = model(x)
|
||||||
|
loss = F.cross_entropy(prediction, y)
|
||||||
|
val_loss += loss.item()
|
||||||
|
validation_losses.append(val_loss / len(validation_loader))
|
||||||
|
if validation_losses[-1] < best_val_loss:
|
||||||
|
best_val_loss = validation_losses[-1]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"training_losses": training_losses,
|
||||||
|
"validation_losses": validation_losses,
|
||||||
|
"best_validation_loss": best_val_loss
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def objective_function(trial: tr.Trial):
|
|
||||||
|
def objective_function(trial: tr.Trial, train_data: bytes, validation_data: bytes, batch_size: int):
|
||||||
model = create_model(trial)
|
model = create_model(trial)
|
||||||
|
result = train_and_eval(model, train_data, validation_data, batch_size)
|
||||||
|
return result["best_validation_loss"]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
train_data = b""
|
||||||
|
validation_data = b""
|
||||||
|
batch_size = 0
|
||||||
|
|
||||||
|
study = optuna.create_study(study_name="CNN network",direction="minimize")
|
||||||
|
study.optimize(lambda trial: objective_function(trial, train_data, validation_data, batch_size), n_trials=10)
|
||||||
Reference in a new issue