feat: optuna optimization performed
This commit is contained in:
parent
2ab4abdf93
commit
fe207962de
5 changed files with 15 additions and 18 deletions
|
|
@ -6,7 +6,7 @@ class CNNPredictor(nn.Module):
|
|||
self,
|
||||
vocab_size=256,
|
||||
embed_dim=64,
|
||||
hidden_channels=128,
|
||||
hidden_dim=128,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -15,11 +15,11 @@ class CNNPredictor(nn.Module):
|
|||
|
||||
# 2. Convolutional feature extractor
|
||||
self.conv_layers = nn.Sequential(
|
||||
nn.Conv1d(embed_dim, hidden_channels, kernel_size=5, padding=2),
|
||||
nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(hidden_channels, hidden_channels, kernel_size=5, padding=2),
|
||||
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(hidden_channels, hidden_channels, kernel_size=5, padding=2),
|
||||
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ class CNNPredictor(nn.Module):
|
|||
self.pool = nn.AdaptiveAvgPool1d(1) # → (B, hidden_channels, 1)
|
||||
|
||||
# 4. Final classifier
|
||||
self.fc = nn.Linear(hidden_channels, vocab_size) # → (B, 256)
|
||||
self.fc = nn.Linear(hidden_dim, vocab_size) # → (B, 256)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
|
|
|
|||
Reference in a new issue