feat: Context CLI arg
parent cd74949b74
commit a4583d402b

8 changed files with 38 additions and 31 deletions
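The new --context flag threads a context length from the CLI down through train() into the dataset classes, replacing the per-dataset hard-coded values. A possible invocation after this change (a sketch only: the entry point, the dataset key "enwik9", and "<model-name>" are assumptions drawn from this diff, not verified against the repository):

    python main.py train --dataset enwik9 --method full --model <model-name> --size 4096 --context 2048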
main.py | 13

@@ -14,11 +14,13 @@ def main():
         case 'train':
             size = int(args.size) if args.size else None
             if args.method == 'optuna':
-                size = 2 ** 12
-                print(f"Using size {size} for optuna (was {args.size})")
+                size = min(size, 2 ** 12) if size else 2 ** 12
+                if size != args.size:
+                    print(f"Using size {size} for optuna (was {args.size})")
             if args.debug:
-                size = 2 ** 10
-                print(f"Using size {size} for debug (was {args.size})")
+                size = min(size, 2 ** 10) if size else 2 ** 10
+                if size != args.size:
+                    print(f"Using size {size} for debug (was {args.size})")
 
             train(
                 device=device,
@@ -29,7 +31,8 @@ def main():
                 method=args.method,
                 model_name=args.model,
                 model_path=args.model_load_path,
-                model_out=args.model_save_path
+                model_out=args.model_save_path,
+                context_length=args.context
             )
 
         case 'compress':
src/args.py | 12

@@ -13,6 +13,8 @@ def parse_arguments():
     dataparser = ArgumentParser(add_help=False)
     dataparser.add_argument("--data-root", type=str, required=False)
     dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
+    dataparser.add_argument("--size", "-s", type=int, required=False,
+                            help="Size of the subset of the dataset to use")
 
     modelparser = ArgumentParser(add_help=False)
     modelparser.add_argument("--model", "-m", type=str, required=False,
@@ -21,6 +23,8 @@ def parse_arguments():
                             help="Filepath to the model to load")
     modelparser.add_argument("--model-save-path", type=str, required=False,
                             help="Filepath to the model to save")
+    modelparser.add_argument("--context", type=int, required=False,
+                            help="Context length to use")
 
     fileparser = ArgumentParser(add_help=False)
     fileparser.add_argument("--input-file", "-i", required=False, type=str)
@@ -35,11 +39,11 @@ def parse_arguments():
     train_parser.add_argument("--method",
                               choices=["fetch", "optuna", "full"], required=True,
                               help="Method to use for training")
-    train_parser.add_argument("--size", "-s", type=int, required=False,
-                              help="Size of the subset of the dataset to use")
 
-    compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
+    subparsers.add_parser("compress", parents=[modelparser, fileparser],
+                          help="Compress a file")
 
-    decompress_parser = subparsers.add_parser("decompress", parents=[modelparser, fileparser])
+    subparsers.add_parser("decompress", parents=[modelparser, fileparser],
+                          help="Decompress a file")
 
     return parser.parse_args(), parser.print_help
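Because --context is registered on the shared modelparser parent, every subcommand that lists modelparser in its parents picks the flag up without redefining it. Below is a standalone sketch of that argparse pattern (not the repository's src/args.py; the example values are illustrative):

    from argparse import ArgumentParser

    # Parent parser: add_help=False so its -h does not clash with the subparsers'.
    modelparser = ArgumentParser(add_help=False)
    modelparser.add_argument("--context", type=int, required=False,
                             help="Context length to use")

    parser = ArgumentParser()
    subparsers = parser.add_subparsers(dest="command")
    # Each subcommand inherits --context from the parent.
    subparsers.add_parser("train", parents=[modelparser])
    subparsers.add_parser("compress", parents=[modelparser])

    args = parser.parse_args(["train", "--context", "2048"])
    print(args.command, args.context)  # prints: train 2048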
@@ -23,7 +23,8 @@ class Dataset(TorchDataset, ABC):
         root: str | None,
         split: str = 'train',
         transform: Callable = None,
-        size: int = -1
+        size: int = -1,
+        context_length: int = 1024
     ):
         """
         :param root: Path to the dataset root directory
@@ -37,8 +38,11 @@ class Dataset(TorchDataset, ABC):
         self.split = split
         self.transform = transform
         self.size = size
+        self.context_length = context_length
         self.data = None
 
+        print(f"Context length: {self.context_length}")
+
         self.chunk_offsets: list[int] = []
         self.bytes: bytes = bytes()
         self.tensor: Tensor = torch.tensor([])
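The base class now stores context_length alongside chunk_offsets, bytes and tensor. How process_data() consumes it is not shown in this diff; the sketch below only illustrates the common pattern of windowing a byte stream into fixed-length training examples (make_windows and raw are hypothetical names, not repository code):

    import torch

    def make_windows(raw: bytes, context_length: int) -> torch.Tensor:
        # Drop the trailing remainder so every window is exactly context_length bytes.
        usable = (len(raw) // context_length) * context_length
        flat = torch.tensor(list(raw[:usable]), dtype=torch.uint8)
        return flat.view(-1, context_length)  # shape: (num_windows, context_length)

    windows = make_windows(b"example byte stream " * 256, context_length=128)
    print(windows.shape)  # torch.Size([40, 128])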
@@ -15,9 +15,10 @@ class EnWik9DataSet(Dataset):
         root: str | None = None,
         split: str = 'train',
         transform: Callable | None = None,
-        size: int = -1
+        size: int = -1,
+        context_length: int = 1024
     ):
-        super().__init__('enwik9', root, split, transform, size)
+        super().__init__('enwik9', root, split, transform, size, context_length)
 
         print(f"Loading from HuggingFace")
         ft = Features({'text': Value('string')})
@@ -26,9 +27,6 @@ class EnWik9DataSet(Dataset):
         self.data = text_chunks['text']
         self.size = size
 
-        # Model uses fixed 128-length context
-        self.context_length = 128
-
         self.process_data()
 
         # Define splits manually, because they do not exist in the dataset
@@ -18,18 +18,15 @@ class HumanReferenceGenomeDataset(Dataset):
         split: str = "train",
         transform: Callable = None,
         size: int = -1,
+        context_length: int = 1024,
         config: str = "6kbp",
     ):
-        super().__init__("human_reference_genome", root, split, transform, size)
+        super().__init__("human_reference_genome", root, split, transform, size, context_length)
 
         print(f"Loading from HuggingFace (config: {config}, split: {split})")
-        ds = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
+        data = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
                           cache_dir=self.root, trust_remote_code=True)
-
-        # Your Dataset.process_data() expects a list[str]; use the 'sequence' field
-        self.data = ds["sequence"]
-
-        self.context_length = 2048
-
+        self.data = data["sequence"]
+
         self.process_data()
 
@@ -12,17 +12,16 @@ class LoremIpsumDataset(Dataset):
         root: str | None = None,
         split: str = 'train',
         transform: Callable = None,
-        size: int = 2**30
+        size: int = 2**30,
+        context_length: int = 1024
     ):
-        super().__init__('lorem_ipsum', root, split, transform, size)
+        super().__init__('lorem_ipsum', root, split, transform, size, context_length)
 
         _lorem = TextLorem()
 
         self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
         self.size = size
 
-        self.context_length = 128
-
         self.process_data()
 
         split_point = ceil(self.chunk_offsets[-1] * 0.8)
@@ -19,9 +19,10 @@ class OpenGenomeDataset(Dataset):
         split: str = 'train',
         transform: Callable = None,
         size: int = -1,
+        context_length: int = 1024,
         stage: str = 'stage2'
     ):
-        super().__init__('open_genome', root, split, transform, size)
+        super().__init__('open_genome', root, split, transform, size, context_length)
 
         print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
         ft = Features({'text': Value('string')})
@@ -29,9 +30,6 @@ class OpenGenomeDataset(Dataset):
         self.data = data['text']
         self.size = size
 
-        # Model uses fixed 128-length context
-        self.context_length = 128
-
         self.process_data()
 
         print("Done initializing dataset")
@@ -14,12 +14,13 @@ def train(
         data_root: str,
         n_trials: int | None = None,
         size: int | None = None,
+        context_length: int | None = None,
         method: str = 'optuna',
         model_name: str | None = None,
         model_path: str | None = None,
         model_out: str | None = None
 ):
-    batch_size = 2
+    batch_size = 64
 
     assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
 
@@ -38,6 +39,9 @@ def train(
     if size:
         dataset_common_args['size'] = size
 
+    if context_length:
+        dataset_common_args['context_length'] = context_length
+
     print("Loading in the dataset...")
     if dataset in dataset_called:
         training_set = dataset_called[dataset](split='train', **dataset_common_args)
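Since the guard only copies context_length into dataset_common_args when it is truthy, omitting --context leaves the keyword unset and the dataset constructors fall back to their new default of 1024. A small self-contained illustration of that behaviour (effective_context is a hypothetical helper, not repository code):

    def effective_context(cli_context: int | None, default: int = 1024) -> int:
        dataset_common_args = {}
        if cli_context:
            dataset_common_args['context_length'] = cli_context
        return dataset_common_args.get('context_length', default)

    print(effective_context(None))   # 1024 -- --context omitted, dataset default applies
    print(effective_context(2048))   # 2048 -- --context 2048 overrides the default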