Streamline datasets
parent 849bcd7b77
commit befb1a96a5
8 changed files with 222 additions and 64 deletions
@@ -10,6 +10,8 @@ from trainers import OptunaTrainer, Trainer, FullTrainer

def parse_arguments():
    parser = ArgumentParser(prog="NeuralCompression")
    parser.add_argument("--debug", "-d", action="store_true", required=False,
                        help="Enable debug mode: smaller datasets, more information")
    parser.add_argument("--verbose", "-v", action="store_true", required=False,
                        help="Enable verbose mode")
@@ -18,7 +20,7 @@ def parse_arguments():
    dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)

    modelparser = ArgumentParser(add_help=False)
-   modelparser.add_argument("--model-path", type=str, required=True,
+   modelparser.add_argument("--model-path", type=str, required=False,
                             help="Path to the model to load/save")

    fileparser = ArgumentParser(add_help=False)
@@ -33,6 +35,8 @@ def parse_arguments():
                            help="Only fetch the dataset, then exit")

    train_parser = subparsers.add_parser("train", parents=[dataparser, modelparser])
    train_parser.add_argument("--method", choices=["optuna", "full"], required=True,
                              help="Method to use for training")

    # TODO
    compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
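For orientation, a hypothetical invocation of the resulting CLI might look like the line below; the dataset name is a placeholder for whatever keys dataset_called actually exposes, and the entry-point filename is an assumption, not taken from this diff:

    python main.py train --dataset <dataset-name> --method optuna --debug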
@@ -44,7 +48,7 @@ def parse_arguments():

def main():
-   BATCH_SIZE = 64
+   BATCH_SIZE = 2

    # hyper parameters
    context_length = 128
@@ -57,9 +61,18 @@ def main():
        DEVICE = "cpu"
    print(f"Running on device: {DEVICE}...")

+   dataset_common_args = {
+       'root': args.data_root,
+       'transform': lambda x: x.to(DEVICE)
+   }
+
+   if args.debug:
+       dataset_common_args['size'] = 2**10
+
    print("Loading in the dataset...")
    if args.dataset in dataset_called:
-       dataset = dataset_called[args.dataset](root=args.data_root, transform=lambda x: x.to(DEVICE))
+       training_set = dataset_called[args.dataset](split='train', **dataset_common_args)
+       validate_set = dataset_called[args.dataset](split='validation', **dataset_common_args)
    else:
        # TODO Allow to import arbitrary files
        raise NotImplementedError(f"Importing external datasets is not implemented yet")
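For readers unfamiliar with the repository, here is a minimal sketch of the constructor signature this hunk assumes for the classes stored in dataset_called. The class name and the stand-in data are illustrative only; apart from the keyword names root, transform, split and size, everything here is an assumption:

    # Illustrative only: the real classes behind dataset_called are not part of this diff.
    import torch
    from torch.utils.data import Dataset

    class ExampleDataset(Dataset):
        """Toy dataset accepting the keyword arguments collected in dataset_common_args."""

        def __init__(self, root, transform=None, split='train', size=None):
            self.transform = transform                  # e.g. lambda x: x.to(DEVICE)
            n = size if size is not None else 2**14     # 'size' caps the dataset in debug mode
            # Stand-in data; a real implementation would read files under `root` for `split`.
            self.samples = torch.randint(0, 255, (n, 128))

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            item = self.samples[idx]
            return self.transform(item) if self.transform else item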
@@ -68,16 +81,10 @@ def main():
        # TODO More to earlier in chain, because now everything is converted into tensors as well?
        exit(0)

-   dataset_length = len(dataset)
-   print(f"Dataset size = {dataset_length}")
-
-   training_size = ceil(0.8 * dataset_length)
-
-   print(f"Training set size = {training_size}, Validation set size {dataset_length - training_size}")
-   train_set, validate_set = torch.utils.data.random_split(dataset, [training_size, dataset_length - training_size])
-   training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
+   print(f"Training set size = {len(training_set)}, Validation set size {len(validate_set)}")
+   training_loader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
    validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)

    loss_fn = torch.nn.CrossEntropyLoss()

    model = None
@@ -85,8 +92,9 @@ def main():
        print("Loading the model...")
        model = torch.load(args.model_path)

-   trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
+   trainer: Trainer = OptunaTrainer(n_trials=3 if args.debug else None) if args.method == "optuna" else FullTrainer()

    print("Training")
    trainer.execute(
        model=model,
        train_loader=training_loader,
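The call to trainer.execute is truncated in this hunk, so only part of its signature is visible. A hypothetical shape of the trainer hierarchy, consistent with the names that do appear (Trainer, OptunaTrainer, FullTrainer, execute, n_trials) but otherwise an assumption, could be:

    # Hypothetical interface sketch; not taken from the repository's trainers module.
    from abc import ABC, abstractmethod

    class Trainer(ABC):
        @abstractmethod
        def execute(self, model, train_loader, **kwargs):
            """Run training; further keyword arguments are cut off in the hunk above."""

    class FullTrainer(Trainer):
        def execute(self, model, train_loader, **kwargs):
            ...  # a single full training run

    class OptunaTrainer(Trainer):
        def __init__(self, n_trials=None):
            self.n_trials = n_trials  # 3 in debug mode, None otherwise

        def execute(self, model, train_loader, **kwargs):
            ...  # hyper-parameter search over n_trials trials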