feat: Context CLI arg
This commit is contained in:
parent
cd74949b74
commit
a4583d402b
8 changed files with 38 additions and 31 deletions
|
|
@ -23,7 +23,8 @@ class Dataset(TorchDataset, ABC):
|
|||
root: str | None,
|
||||
split: str = 'train',
|
||||
transform: Callable = None,
|
||||
size: int = -1
|
||||
size: int = -1,
|
||||
context_length: int = 1024
|
||||
):
|
||||
"""
|
||||
:param root: Path to the dataset root directory
|
||||
|
|
@ -37,8 +38,11 @@ class Dataset(TorchDataset, ABC):
|
|||
self.split = split
|
||||
self.transform = transform
|
||||
self.size = size
|
||||
self.context_length = context_length
|
||||
self.data = None
|
||||
|
||||
print(f"Context length: {self.context_length}")
|
||||
|
||||
self.chunk_offsets: list[int] = []
|
||||
self.bytes: bytes = bytes()
|
||||
self.tensor: Tensor = torch.tensor([])
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@ class EnWik9DataSet(Dataset):
|
|||
root: str | None = None,
|
||||
split: str = 'train',
|
||||
transform: Callable | None = None,
|
||||
size: int = -1
|
||||
size: int = -1,
|
||||
context_length: int = 1024
|
||||
):
|
||||
super().__init__('enwik9', root, split, transform, size)
|
||||
super().__init__('enwik9', root, split, transform, size, context_length)
|
||||
|
||||
print(f"Loading from HuggingFace")
|
||||
ft = Features({'text': Value('string')})
|
||||
|
|
@ -26,9 +27,6 @@ class EnWik9DataSet(Dataset):
|
|||
self.data = text_chunks['text']
|
||||
self.size = size
|
||||
|
||||
# Model uses fixed 128-length context
|
||||
self.context_length = 128
|
||||
|
||||
self.process_data()
|
||||
|
||||
# Define splits manually, because they do not exist in the dataset
|
||||
|
|
|
|||
|
|
@ -18,18 +18,15 @@ class HumanReferenceGenomeDataset(Dataset):
|
|||
split: str = "train",
|
||||
transform: Callable = None,
|
||||
size: int = -1,
|
||||
context_length: int = 1024,
|
||||
config: str = "6kbp",
|
||||
):
|
||||
super().__init__("human_reference_genome", root, split, transform, size)
|
||||
super().__init__("human_reference_genome", root, split, transform, size, context_length)
|
||||
|
||||
print(f"Loading from HuggingFace (config: {config}, split: {split})")
|
||||
ds = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
|
||||
data = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
|
||||
cache_dir=self.root, trust_remote_code=True)
|
||||
|
||||
# Your Dataset.process_data() expects a list[str]; use the 'sequence' field
|
||||
self.data = ds["sequence"]
|
||||
|
||||
self.context_length = 2048
|
||||
self.data = data["sequence"]
|
||||
|
||||
self.process_data()
|
||||
|
||||
|
|
|
|||
|
|
@ -12,17 +12,16 @@ class LoremIpsumDataset(Dataset):
|
|||
root: str | None = None,
|
||||
split: str = 'train',
|
||||
transform: Callable = None,
|
||||
size: int = 2**30
|
||||
size: int = 2**30,
|
||||
context_length: int = 1024
|
||||
):
|
||||
super().__init__('lorem_ipsum', root, split, transform, size)
|
||||
super().__init__('lorem_ipsum', root, split, transform, size, context_length)
|
||||
|
||||
_lorem = TextLorem()
|
||||
|
||||
self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
|
||||
self.size = size
|
||||
|
||||
self.context_length = 128
|
||||
|
||||
self.process_data()
|
||||
|
||||
split_point = ceil(self.chunk_offsets[-1] * 0.8)
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@ class OpenGenomeDataset(Dataset):
|
|||
split: str = 'train',
|
||||
transform: Callable = None,
|
||||
size: int = -1,
|
||||
context_length: int = 1024,
|
||||
stage: str = 'stage2'
|
||||
):
|
||||
super().__init__('open_genome', root, split, transform, size)
|
||||
super().__init__('open_genome', root, split, transform, size, context_length)
|
||||
|
||||
print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
|
||||
ft = Features({'text': Value('string')})
|
||||
|
|
@ -29,9 +30,6 @@ class OpenGenomeDataset(Dataset):
|
|||
self.data = data['text']
|
||||
self.size = size
|
||||
|
||||
# Model uses fixed 128-length context
|
||||
self.context_length = 128
|
||||
|
||||
self.process_data()
|
||||
|
||||
print("Done initializing dataset")
|
||||
|
|
|
|||
Reference in a new issue