diff --git a/main.py b/main.py index 9796bf1..ae4e709 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,7 @@ from src.utils import determine_device def main(): args, print_help = parse_arguments() - device = determine_device() + device = args.device or determine_device() print(f"Running on device: {device}...") match args.mode: diff --git a/src/args.py b/src/args.py index 336dc58..af295af 100644 --- a/src/args.py +++ b/src/args.py @@ -10,6 +10,7 @@ def parse_arguments(): parser.add_argument("--verbose", "-v", action="store_true", required=False, help="Enable verbose mode") parser.add_argument("--results", type=str, required=True, help="path to save graphs to") + parser.add_argument("--device", required=False, help="Override the device to use") dataparser = ArgumentParser(add_help=False) dataparser.add_argument("--data-root", type=str, required=False)