feat: uhm, i changed some things
This commit is contained in:
parent
b58682cb49
commit
6de4db24cc
27 changed files with 1302 additions and 137 deletions
|
|
@ -190,7 +190,7 @@ def run(args, kwargs):
|
|||
import models.Model as Model
|
||||
|
||||
model = Model.Model(args)
|
||||
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
args.DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
model.set_temperature(args.temperature)
|
||||
model.enable_hard_round(args.hard_round)
|
||||
|
||||
|
|
@ -208,7 +208,7 @@ def run(args, kwargs):
|
|||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||
model = torch.nn.DataParallel(model, dim=0)
|
||||
|
||||
model.to(args.device)
|
||||
model.to(args.DEVICE)
|
||||
|
||||
def lr_lambda(epoch):
|
||||
return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch)
|
||||
|
|
|
|||
Reference in a new issue