mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
transformer_lm: add --dataset enwik8 (#838)
* transformer_lm: add --dataset enwik8 * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
df6bc09d74
commit
7979b84a9e
@ -10,7 +10,9 @@ import numpy as np
|
||||
|
||||
|
||||
def load_dataset(dataname):
|
||||
if dataname == "ptb":
|
||||
if dataname == "enwik8":
|
||||
return enwik8()
|
||||
elif dataname == "ptb":
|
||||
return ptb()
|
||||
elif dataname == "wikitext2":
|
||||
return wikitext(dataset="2")
|
||||
@ -87,7 +89,37 @@ def ptb(save_dir="/tmp"):
|
||||
return _load(save_dir, filenames)
|
||||
|
||||
|
||||
def enwik8(save_dir="/tmp"):
|
||||
"""
|
||||
Load the enwik8 language modeling dataset:
|
||||
https://mattmahoney.net/dc/textdata.html
|
||||
"""
|
||||
out_file = os.path.join(save_dir, "enwik8.zip")
|
||||
if not os.path.exists(out_file):
|
||||
request.urlretrieve("http://mattmahoney.net/dc/enwik8.zip", out_file)
|
||||
|
||||
with zipfile.ZipFile(out_file) as zf:
|
||||
data = zf.read("enwik8")
|
||||
|
||||
num_test_bytes = 5000000 # 90 + 5 + 5 split
|
||||
|
||||
train_data = data[: -2 * num_test_bytes]
|
||||
valid_data = data[-2 * num_test_bytes : -num_test_bytes]
|
||||
test_data = data[-num_test_bytes:]
|
||||
|
||||
vocab = set(c for c in train_data)
|
||||
vocab = {c: i for i, c in enumerate(vocab)}
|
||||
|
||||
def to_array(dataset):
|
||||
return np.array([vocab[c] for c in dataset], dtype=np.uint32)
|
||||
|
||||
return vocab, to_array(train_data), to_array(valid_data), to_array(test_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
vocab, train, val, test = enwik8()
|
||||
assert len(vocab) == 205, "enwik8: Wrong vocab size"
|
||||
|
||||
vocab, train, val, test = ptb()
|
||||
assert len(vocab) == 10000, "PTB: Wrong vocab size"
|
||||
|
||||
|
@ -157,7 +157,7 @@ if __name__ == "__main__":
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="ptb",
|
||||
choices=["ptb", "wikitext2", "wikitext103"],
|
||||
choices=["enwik8", "ptb", "wikitext2", "wikitext103"],
|
||||
help="Dataset to train and evaluate on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
Loading…
Reference in New Issue
Block a user