40 lines
1.3 KiB
Python
40 lines
1.3 KiB
Python
import functools
|
|
import os, shutil
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
|
def logging(s, log_path, print_=True, log_=True):
    """Echo a message to stdout and/or append it to a log file.

    NOTE: this function has the same name as the standard-library
    ``logging`` module; do not ``import logging`` in this file or one will
    shadow the other. The name is kept for backward compatibility.

    Args:
        s: Message to emit. Non-string values are converted with ``str()``
            before being written to the file (``print`` already does this).
        log_path: Path of the log file to append to. Only opened when
            ``log_`` is true, so it may be ``None`` otherwise.
        print_: If true, print ``s`` to stdout.
        log_: If true, append ``s`` followed by a newline to ``log_path``.
    """
    if print_:
        print(s)
    if log_:
        # 'a' suffices (the file is never read back); explicit UTF-8 keeps
        # the log's encoding independent of the platform locale.
        with open(log_path, 'a', encoding='utf-8') as f_log:
            f_log.write(str(s) + '\n')
|
|
|
|
def get_logger(log_path, **kwargs):
    """Return the module's ``logging`` function with ``log_path`` pre-bound.

    Any extra keyword arguments (e.g. ``print_`` / ``log_``) are bound as
    well, so the returned callable only needs the message argument.
    """
    bound = {'log_path': log_path, **kwargs}
    return functools.partial(logging, **bound)
|
|
|
|
def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    """Create an experiment directory and return a logger bound to it.

    Args:
        dir_path: Directory to create (parents included) if missing.
        scripts_to_save: Optional iterable of file paths; each is copied into
            ``<dir_path>/scripts/`` under its base name.
        debug: If true, nothing is created on disk and the returned logger
            only prints (file logging disabled).

    Returns:
        A callable taking a message; unless ``debug`` is set it appends to
        ``<dir_path>/log.txt`` as well as printing.
    """
    if debug:
        print('Debug Mode : no experiment dir created')
        return get_logger(log_path=None, log_=False)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(dir_path, exist_ok=True)
    print('Experiment dir : {}'.format(dir_path))

    if scripts_to_save is not None:
        script_path = os.path.join(dir_path, 'scripts')
        os.makedirs(script_path, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(script_path, os.path.basename(script))
            shutil.copyfile(script, dst_file)

    return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
|
|
|
|
def save_checkpoint(model, optimizer, path, epoch):
    """Write model and optimizer checkpoints for ``epoch`` into ``path``.

    Files written (``path`` must already exist):
      * ``model_{epoch}.pt``     -- the whole ``model`` object via torch.save
      * ``optimizer_{epoch}.pt`` -- ``optimizer.state_dict()``

    NOTE(review): the model is pickled as a full object (not its
    state_dict), so loading it requires the model's class to be importable.
    Switching to ``model.state_dict()`` would break existing loaders.
    """
    model_file = os.path.join(path, 'model_{}.pt'.format(epoch))
    torch.save(model, model_file)

    optim_file = os.path.join(path, 'optimizer_{}.pt'.format(epoch))
    torch.save(optimizer.state_dict(), optim_file)
|