feat: uhm, i changed some things
This commit is contained in:
parent
b58682cb49
commit
6de4db24cc
27 changed files with 1302 additions and 137 deletions
31
CNN-model/utils/utils.py
Normal file
31
CNN-model/utils/utils.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import torch
|
||||
from torch.utils.data import TensorDataset
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
|
||||
data = torch.tensor(list(data), dtype=torch.long)
|
||||
sample_count = data.shape[0] - context_length
|
||||
x = data.unfold(0, context_length, 1)[:sample_count]
|
||||
y = data[context_length:]
|
||||
return TensorDataset(x, y)
|
||||
|
||||
def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
|
||||
plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
|
||||
plt.show()
|
||||
|
||||
def print_losses(train_losses: list[float], validation_losses: list[float], show=False):
|
||||
plt.plot(train_losses, label="Training loss")
|
||||
plt.plot(validation_losses, label="Validation loss")
|
||||
plt.xlabel("Epoch")
|
||||
plt.ylabel("Loss (cross entropy)")
|
||||
plt.legend()
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
plt.savefig("losses.png")
|
||||
|
||||
|
||||
def load_data(path: str) -> bytes:
|
||||
with open(path, "rb") as f:
|
||||
return f.read()
|
||||
Reference in a new issue