mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 21:01:32 +08:00
transformer_lm: add --dataset enwik8 (#838)
* transformer_lm: add --dataset enwik8 * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:

committed by
GitHub

parent
df6bc09d74
commit
7979b84a9e
@@ -10,7 +10,9 @@ import numpy as np
|
||||
|
||||
|
||||
def load_dataset(dataname):
|
||||
if dataname == "ptb":
|
||||
if dataname == "enwik8":
|
||||
return enwik8()
|
||||
elif dataname == "ptb":
|
||||
return ptb()
|
||||
elif dataname == "wikitext2":
|
||||
return wikitext(dataset="2")
|
||||
@@ -87,7 +89,37 @@ def ptb(save_dir="/tmp"):
|
||||
return _load(save_dir, filenames)
|
||||
|
||||
|
||||
def enwik8(save_dir="/tmp"):
|
||||
"""
|
||||
Load the enwik8 language modeling dataset:
|
||||
https://mattmahoney.net/dc/textdata.html
|
||||
"""
|
||||
out_file = os.path.join(save_dir, "enwik8.zip")
|
||||
if not os.path.exists(out_file):
|
||||
request.urlretrieve("http://mattmahoney.net/dc/enwik8.zip", out_file)
|
||||
|
||||
with zipfile.ZipFile(out_file) as zf:
|
||||
data = zf.read("enwik8")
|
||||
|
||||
num_test_bytes = 5000000 # 90 + 5 + 5 split
|
||||
|
||||
train_data = data[: -2 * num_test_bytes]
|
||||
valid_data = data[-2 * num_test_bytes : -num_test_bytes]
|
||||
test_data = data[-num_test_bytes:]
|
||||
|
||||
vocab = set(c for c in train_data)
|
||||
vocab = {c: i for i, c in enumerate(vocab)}
|
||||
|
||||
def to_array(dataset):
|
||||
return np.array([vocab[c] for c in dataset], dtype=np.uint32)
|
||||
|
||||
return vocab, to_array(train_data), to_array(valid_data), to_array(test_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
vocab, train, val, test = enwik8()
|
||||
assert len(vocab) == 205, "enwik8: Wrong vocab size"
|
||||
|
||||
vocab, train, val, test = ptb()
|
||||
assert len(vocab) == 10000, "PTB: Wrong vocab size"
|
||||
|
||||
|
Reference in New Issue
Block a user