diff --git a/transformer_lm/datasets.py b/transformer_lm/datasets.py
index 7b077ef3..7d6ddc0f 100644
--- a/transformer_lm/datasets.py
+++ b/transformer_lm/datasets.py
@@ -10,7 +10,9 @@ import numpy as np
 
 
 def load_dataset(dataname):
-    if dataname == "ptb":
+    if dataname == "enwik8":
+        return enwik8()
+    elif dataname == "ptb":
         return ptb()
     elif dataname == "wikitext2":
         return wikitext(dataset="2")
@@ -87,7 +89,37 @@ def ptb(save_dir="/tmp"):
     return _load(save_dir, filenames)
 
 
+def enwik8(save_dir="/tmp"):
+    """
+    Load the enwik8 language modeling dataset:
+    https://mattmahoney.net/dc/textdata.html
+    """
+    out_file = os.path.join(save_dir, "enwik8.zip")
+    if not os.path.exists(out_file):
+        request.urlretrieve("http://mattmahoney.net/dc/enwik8.zip", out_file)
+
+    with zipfile.ZipFile(out_file) as zf:
+        data = zf.read("enwik8")
+
+    num_test_bytes = 5000000  # 90/5/5 MB split of the 100 MB corpus
+
+    train_data = data[: -2 * num_test_bytes]
+    valid_data = data[-2 * num_test_bytes : -num_test_bytes]
+    test_data = data[-num_test_bytes:]
+
+    vocab = set(c for c in train_data)
+    vocab = {c: i for i, c in enumerate(vocab)}
+
+    def to_array(dataset):
+        return np.array([vocab[c] for c in dataset], dtype=np.uint32)
+
+    return vocab, to_array(train_data), to_array(valid_data), to_array(test_data)
+
+
 if __name__ == "__main__":
+    vocab, train, val, test = enwik8()
+    assert len(vocab) == 205, "enwik8: Wrong vocab size"
+
     vocab, train, val, test = ptb()
     assert len(vocab) == 10000, "PTB: Wrong vocab size"
 
diff --git a/transformer_lm/main.py b/transformer_lm/main.py
index 044af58c..dc725cbe 100644
--- a/transformer_lm/main.py
+++ b/transformer_lm/main.py
@@ -157,7 +157,7 @@ if __name__ == "__main__":
         "--dataset",
         type=str,
         default="ptb",
-        choices=["ptb", "wikitext2", "wikitext103"],
+        choices=["enwik8", "ptb", "wikitext2", "wikitext103"],
         help="Dataset to train and evaluate on.",
     )
     parser.add_argument(
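
The new loader returns the byte-level vocabulary together with the three splits as uint32 index arrays. Below is a minimal usage sketch, assuming the script is run from inside the transformer_lm directory so that `datasets` is importable as a sibling module; the inverse vocabulary and decoding helper are illustrative and not part of the patch.

    # Minimal sketch: load enwik8 and round-trip a slice back to text.
    # `enwik8` and the 205-entry vocab come from the patch above; the
    # decoding below is only an illustration of the index format.
    from datasets import enwik8

    vocab, train, valid, test = enwik8()

    # vocab maps raw byte values (ints) to contiguous token ids, so
    # inverting it recovers the original bytes from the index arrays.
    inv_vocab = {i: b for b, i in vocab.items()}
    text = bytes(inv_vocab[int(t)] for t in train[:200])
    print(text.decode("utf-8", errors="replace"))

Since the patch also registers the dataset in main.py's --dataset choices, training on enwik8 would presumably be launched as `python main.py --dataset enwik8`, with any other flags main.py already supports.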