feat: optuna optimization performed

parent 2ab4abdf93
commit fe207962de

5 changed files with 15 additions and 18 deletions
@@ -17,19 +17,18 @@ class LoremIpsumDataset(Dataset):
         path = join(curdir, "data")
         self._root = path
         # Convert text to bytes (UTF-8 encoded)
-        self.dataset = torch.tensor([ord(c) for c in list(_text)], dtype=torch.long)
-        sequence_count = self.dataset.shape[0] // 128  # how many vectors of 128 elements can we make
-        self.dataset = self.dataset[:sequence_count * 128]
-        self.dataset = self.dataset.view(-1, 128)
-
-        print(self.dataset.shape)
+        self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
+        self.context_length = 128

     def __len__(self):
         # Number of possible sequences of length sequence_length
-        return self.dataset.size(0)
+        return self.dataset.size(0) - self.context_length

     def __getitem__(self, idx):
+        x = self.dataset[idx: idx + self.context_length]
+        y = self.dataset[idx + self.context_length]

         if self.transform is not None:
-            return self.transform(self.dataset[idx])
-        return self.dataset[idx]
+            x = self.transform(x)
+
+        return x, y
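Taken together, the dataset moves from fixed, non-overlapping 128-byte rows to a sliding window: every index yields 128 bytes of context plus the following byte as the prediction target, and `% 256` keeps any code point inside a 256-entry vocabulary. A minimal sketch of the class as it reads after this commit; the constructor parameters `_text` and `transform` and the imports are assumptions, since the full file is not shown:

from os import curdir
from os.path import join

import torch
from torch.utils.data import Dataset


class LoremIpsumDataset(Dataset):
    def __init__(self, _text: str, transform=None):
        path = join(curdir, "data")
        self._root = path
        self.transform = transform
        # Convert text to byte codes; % 256 keeps every code point inside the
        # 256-entry vocabulary, so non-ASCII characters cannot index past it
        self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
        self.context_length = 128

    def __len__(self):
        # One sample per position that still has a full window plus a target byte
        return self.dataset.size(0) - self.context_length

    def __getitem__(self, idx):
        x = self.dataset[idx: idx + self.context_length]  # 128 bytes of context
        y = self.dataset[idx + self.context_length]       # the byte to predict

        if self.transform is not None:
            x = self.transform(x)

        return x, y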
@@ -6,7 +6,7 @@ class CNNPredictor(nn.Module):
         self,
         vocab_size=256,
         embed_dim=64,
-        hidden_channels=128,
+        hidden_dim=128,
     ):
         super().__init__()

@@ -15,11 +15,11 @@ class CNNPredictor(nn.Module):

         # 2. Convolutional feature extractor
         self.conv_layers = nn.Sequential(
-            nn.Conv1d(embed_dim, hidden_channels, kernel_size=5, padding=2),
+            nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
             nn.ReLU(),
-            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=5, padding=2),
+            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
             nn.ReLU(),
-            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=5, padding=2),
+            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
             nn.ReLU(),
         )

@@ -27,7 +27,7 @@ class CNNPredictor(nn.Module):
         self.pool = nn.AdaptiveAvgPool1d(1)  # → (B, hidden_channels, 1)

         # 4. Final classifier
-        self.fc = nn.Linear(hidden_channels, vocab_size)  # → (B, 256)
+        self.fc = nn.Linear(hidden_dim, vocab_size)  # → (B, 256)

     def forward(self, x):
         """
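The model file is a pure rename of `hidden_channels` to `hidden_dim`, which lines the constructor keyword up with the `"hidden_dim"` name in the Optuna search space below (only the shape comment on the pooling line still says `hidden_channels`). A hedged sketch of the class after the rename; the embedding layer and the `forward` body are reconstructed from the shape comments and are not shown in this diff:

import torch
import torch.nn as nn


class CNNPredictor(nn.Module):
    def __init__(
        self,
        vocab_size=256,
        embed_dim=64,
        hidden_dim=128,
    ):
        super().__init__()

        # 1. Byte embedding (assumed; the diff starts below this point)
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # 2. Convolutional feature extractor
        self.conv_layers = nn.Sequential(
            nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # 3. Global average pooling over the sequence dimension
        self.pool = nn.AdaptiveAvgPool1d(1)  # → (B, hidden_dim, 1)

        # 4. Final classifier
        self.fc = nn.Linear(hidden_dim, vocab_size)  # → (B, 256)

    def forward(self, x):
        x = self.embedding(x)         # (B, L) → (B, L, embed_dim)
        x = x.transpose(1, 2)         # → (B, embed_dim, L) for Conv1d
        x = self.conv_layers(x)       # → (B, hidden_dim, L)
        x = self.pool(x).squeeze(-1)  # → (B, hidden_dim)
        return self.fc(x)             # → (B, vocab_size)

With these shapes, `CNNPredictor()(torch.randint(0, 256, (4, 128)))` returns `(4, 256)` logits, one score per possible next byte.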
BIN  CNN-model/models/final_model.pt (new file)
Binary file not shown.
@@ -13,7 +13,7 @@ from .train import train

 def create_model(trial: tr.Trial, vocab_size: int = 256):
     hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
-    embedding_dim = trial.suggest_int("embedding_dim", 64, 512, log=True)
+    embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)

     return CNNPredictor(
         vocab_size=vocab_size,
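In `create_model`, the local variable keeps the name `embedding_dim`, but the Optuna parameter is now registered as `"embed_dim"`, matching the model's constructor keyword. A hedged sketch of how this factory typically plugs into a study: the `CNNPredictor` keyword arguments after `vocab_size` are cut off in the diff and guessed from the constructor signature, and `objective` plus the `train(model)` call are illustrative wiring, not project code:

import optuna
import optuna.trial as tr

from .model import CNNPredictor  # hypothetical module path; the import is not shown
from .train import train         # shown in the hunk's context line


def create_model(trial: tr.Trial, vocab_size: int = 256):
    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
    embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)

    return CNNPredictor(
        vocab_size=vocab_size,
        embed_dim=embedding_dim,  # guessed keywords; the call is truncated in the diff
        hidden_dim=hidden_dim,
    )


def objective(trial: tr.Trial) -> float:
    model = create_model(trial)
    # Assumed wiring: train returns (avg_training_losses, avg_validation_losses)
    # per the last hunk; its real parameters (loaders, epochs, ...) are not shown
    _, avg_validation_losses = train(model)
    return avg_validation_losses[-1]  # minimize the final average validation loss


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)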
@@ -56,6 +56,4 @@ def train(
         avg_loss = sum(losses) / len(losses)
         avg_validation_losses.append(avg_loss)

-        tqdm.write(f"epoch: {epoch + 1}, avg val loss = {avg_loss:.4f}")
-
     return avg_training_losses, avg_validation_losses
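The last hunk drops the per-epoch `tqdm.write` from `train`, so console output during a study is left to Optuna's own per-trial logging; the history still comes back through the return value. A small hedged sketch of recovering the per-epoch view after a run (the `train(model)` call signature is an assumption, as above):

# Hypothetical post-run reporting; `train(model)` stands in for the real call
avg_training_losses, avg_validation_losses = train(model)

for epoch, (train_loss, val_loss) in enumerate(
    zip(avg_training_losses, avg_validation_losses), start=1
):
    print(f"epoch: {epoch}, avg train loss = {train_loss:.4f}, avg val loss = {val_loss:.4f}")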