transformer_lm: add --dataset enwik8 (#838)

* transformer_lm: add --dataset enwik8

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Authored by Volodymyr Kyrylov on 2024-06-26 20:59:01 +02:00; committed by GitHub
parent df6bc09d74
commit 7979b84a9e
2 changed files with 34 additions and 2 deletions


@@ -10,7 +10,9 @@ import numpy as np
 
 
 def load_dataset(dataname):
-    if dataname == "ptb":
+    if dataname == "enwik8":
+        return enwik8()
+    elif dataname == "ptb":
         return ptb()
     elif dataname == "wikitext2":
         return wikitext(dataset="2")
@@ -87,7 +89,37 @@ def ptb(save_dir="/tmp"):
     return _load(save_dir, filenames)
 
 
+def enwik8(save_dir="/tmp"):
+    """
+    Load the enwik8 language modeling dataset:
+    https://mattmahoney.net/dc/textdata.html
+    """
+    out_file = os.path.join(save_dir, "enwik8.zip")
+    if not os.path.exists(out_file):
+        request.urlretrieve("http://mattmahoney.net/dc/enwik8.zip", out_file)
+    with zipfile.ZipFile(out_file) as zf:
+        data = zf.read("enwik8")
+
+    num_test_bytes = 5000000  # 90 + 5 + 5 split
+    train_data = data[: -2 * num_test_bytes]
+    valid_data = data[-2 * num_test_bytes : -num_test_bytes]
+    test_data = data[-num_test_bytes:]
+
+    vocab = set(c for c in train_data)
+    vocab = {c: i for i, c in enumerate(vocab)}
+
+    def to_array(dataset):
+        return np.array([vocab[c] for c in dataset], dtype=np.uint32)
+
+    return vocab, to_array(train_data), to_array(valid_data), to_array(test_data)
+
+
 if __name__ == "__main__":
+    vocab, train, val, test = enwik8()
+    assert len(vocab) == 205, "enwik8: Wrong vocab size"
+
     vocab, train, val, test = ptb()
     assert len(vocab) == 10000, "PTB: Wrong vocab size"
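
For reference, a minimal sketch of what the new loader returns, assuming the patched file is importable as a datasets module (the module name is an assumption; the vocab size of 205 is the value asserted above). enwik8 is 10^8 bytes, so the two 5,000,000-byte held-out slices leave a 90,000,000-byte training split:

from datasets import enwik8  # assumed module name for the file patched above

vocab, train, valid, test = enwik8()         # downloads enwik8.zip to /tmp on first call
print(len(vocab))                            # 205 distinct byte values in the training split
print(train.shape, valid.shape, test.shape)  # expect (90000000,), (5000000,), (5000000,)
print(train.dtype)                           # uint32 token ids, one per byte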


@@ -157,7 +157,7 @@ if __name__ == "__main__":
         "--dataset",
         type=str,
         default="ptb",
-        choices=["ptb", "wikitext2", "wikitext103"],
+        choices=["enwik8", "ptb", "wikitext2", "wikitext103"],
         help="Dataset to train and evaluate on.",
     )
     parser.add_argument(
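
A short sketch of how the new choice flows into data loading, mirroring the two hunks above; the argparse fragment is copied from the diff, while the datasets module name and the explicit parse_args list are illustrative assumptions:

import argparse

import datasets  # assumed name of the loader module patched above

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    default="ptb",
    choices=["enwik8", "ptb", "wikitext2", "wikitext103"],
    help="Dataset to train and evaluate on.",
)
args = parser.parse_args(["--dataset", "enwik8"])

# load_dataset() now dispatches "enwik8" to the new enwik8() loader.
vocab, train, valid, test = datasets.load_dataset(args.dataset)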