feat: uhm, i changed some things

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -190,7 +190,7 @@ def run(args, kwargs):
import models.Model as Model
model = Model.Model(args)
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args.DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.set_temperature(args.temperature)
model.enable_hard_round(args.hard_round)
@ -208,7 +208,7 @@ def run(args, kwargs):
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = torch.nn.DataParallel(model, dim=0)
model.to(args.device)
model.to(args.DEVICE)
def lr_lambda(epoch):
return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch)