Merge branch 'ml-explore:main' into completion_only

This commit is contained in:
Chime Ogbuji
2024-06-28 18:55:15 -04:00
committed by GitHub
14 changed files with 448 additions and 59 deletions

View File

@@ -32,7 +32,7 @@ jobs:
            pip install --upgrade pip
            pip install unittest-xml-reporting
            cd llms/
-           pip install -e .
            pip install -e ".[testing]"
      - run:
          name: Run Python tests
          command: |

View File

@@ -151,9 +151,14 @@ Examples GitHub repo has an [example of the WikiSQL
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
correct format.

Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
Face.

### Local Datasets

For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory.

Currently, `*.jsonl` files support three data formats: `chat`,
`completions`, and `text`. Here are three examples of these formats:

@@ -199,7 +204,34 @@ Currently, `*.jsonl` files support three data formats: `chat`,

Note, the format is automatically determined by the dataset. Note also, keys in
each line not expected by the loader will be ignored.

-For the `chat` and `completions` formats, Hugging Face [chat
### Hugging Face Datasets

To use Hugging Face datasets, first install the `datasets` package:

```
pip install datasets
```

Specify the Hugging Face dataset arguments in a YAML config. For example:

```
hf_dataset:
  name: "billsum"
  prompt_feature: "text"
  completion_feature: "summary"
```

- Use `prompt_feature` and `completion_feature` to specify keys for a
  `completions` dataset. Use `text_feature` to specify the key for a `text`
  dataset.

- To specify the train, valid, or test splits, set the corresponding
  `{train,valid,test}_split` argument.

- Arguments specified in `config` will be passed as keyword arguments to
  [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).

In general, for the `chat` and `completions` formats, Hugging Face [chat
templates](https://huggingface.co/blog/chat-templates) are used. This applies
the model's chat template by default. If the model does not have a chat
template, then Hugging Face will use a default. For example, the final text in

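The `hf_dataset` options documented above map onto `datasets.load_dataset` in a straightforward way. A minimal sketch of that mapping, reusing the illustrative `billsum` values from the example config; the split string here is only an assumption to show the slicing syntax:

```
# Sketch: how the hf_dataset YAML options are forwarded to datasets.load_dataset.
# The dataset name and feature keys mirror the example config above; the split
# value is illustrative.
from datasets import load_dataset

hf_args = {
    "name": "billsum",
    "prompt_feature": "text",
    "completion_feature": "summary",
    "train_split": "train[:1000]",  # optional split slicing
    "config": {},  # extra kwargs passed through to load_dataset unchanged
}

ds = load_dataset(
    hf_args["name"],
    split=hf_args.get("train_split", "train[:80%]"),
    **hf_args.get("config", {}),
)

# Each record is then read with the configured keys.
record = ds[0]
prompt = record[hf_args["prompt_feature"]]
completion = record[hf_args["completion_feature"]]
```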
View File

@@ -69,3 +69,11 @@ lora_parameters:
# warmup: 100 # 0 for no warmup
# warmup_init: 1e-7 # 0 if not specified
# arguments: [1e-5, 1000, 1e-7] # passed to scheduler
#hf_dataset:
# name: "billsum"
# train_split: "train[:1000]"
# valid_split: "train[-100:]"
# prompt_feature: "text"
# completion_feature: "summary"

View File

@@ -0,0 +1,190 @@
from dataclasses import dataclass
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
head_dim: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int
rope_theta: float = 10000
rope_traditional: bool = False
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.pre_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + self.post_attention_layernorm(r)
r = self.mlp(self.pre_feedforward_layernorm(h))
out = h + self.post_feedforward_layernorm(r)
return out
class GemmaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = GemmaModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
return out
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

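The new `gemma2` model file above plugs into the standard `mlx_lm` loading path, so it can be exercised with the usual high-level API. A minimal usage sketch; the Hugging Face repo name is illustrative and not part of this change:

```
# Sketch: loading a converted Gemma 2 checkpoint through mlx_lm and generating.
# The repo name below is an assumption for illustration only.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/gemma-2-9b-it-4bit")
text = generate(model, tokenizer, prompt="Write a haiku about the ocean.", max_tokens=64)
print(text)
```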
View File

@@ -8,6 +8,7 @@ import uuid
import warnings
from functools import lru_cache
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union

import mlx.core as mx
@@ -81,14 +82,68 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
    return prompt.rstrip()
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
self.cli_args = cli_args
self.model_key = None
self.model = None
self.tokenizer = None
# Preload the default model if it is provided
if self.cli_args.model is not None:
self.load("default_model")
def _validate_model_path(self, model_path: str):
model_path = Path(model_path)
if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
raise RuntimeError(
"Local models must be relative to the current working dir."
)
def load(self, model_path):
if self.model_key == model_path:
return self.model, self.tokenizer
# Remove the old model if it exists.
self.model = None
self.tokenizer = None
# Building tokenizer_config
tokenizer_config = {
"trust_remote_code": True if self.cli_args.trust_remote_code else None
}
if self.cli_args.chat_template:
tokenizer_config["chat_template"] = self.cli_args.chat_template
if model_path == "default_model" and self.cli_args.model is not None:
model, tokenizer = load(
self.cli_args.model,
adapter_path=self.cli_args.adapter_path,
tokenizer_config=tokenizer_config,
)
else:
self._validate_model_path(model_path)
model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
if self.cli_args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
self.model_key = model_path
self.model = model
self.tokenizer = tokenizer
return self.model, self.tokenizer
class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
    def __init__(self, model_provider: ModelProvider, *args, **kwargs):
        """
        Create static request specific metadata
        """
-        self.model = model
-        self.tokenizer = tokenizer
        self.created = int(time.time())
        self.model_provider = model_provider
        super().__init__(*args, **kwargs)

    def _set_cors_headers(self):
@@ -148,6 +203,15 @@ class APIHandler(BaseHTTPRequestHandler):
        self.logprobs = self.body.get("logprobs", -1)
        self.validate_model_parameters()
# Load the model if needed
try:
self.model, self.tokenizer = self.model_provider.load(self.requested_model)
except:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
return
        # Get stop id sequences, if provided
        stop_words = self.body.get("stop")
        stop_words = stop_words or []
@@ -513,15 +577,14 @@ class APIHandler(BaseHTTPRequestHandler):
def run(
    host: str,
    port: int,
-    model: nn.Module,
-    tokenizer: TokenizerWrapper,
    model_provider: ModelProvider,
    server_class=HTTPServer,
    handler_class=APIHandler,
):
    server_address = (host, port)
    httpd = server_class(
        server_address,
-        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
        lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
    )
    warnings.warn(
        "mlx_lm.server is not recommended for production as "
@@ -536,7 +599,6 @@ def main():
    parser.add_argument(
        "--model",
        type=str,
-        required=True,
        help="The path to the MLX model weights, tokenizer, and config",
    )
    parser.add_argument(
@@ -598,20 +660,7 @@ def main():
    logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
    mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

-    # Building tokenizer_config
-    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
-    if args.chat_template:
-        tokenizer_config["chat_template"] = args.chat_template
-    model, tokenizer = load(
-        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
-    )
-    if args.use_default_chat_template:
-        if tokenizer.chat_template is None:
-            tokenizer.chat_template = tokenizer.default_chat_template
-    run(args.host, args.port, model, tokenizer)
    run(args.host, args.port, ModelProvider(args))


if __name__ == "__main__":

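With `--model` now optional, the server resolves the model named in each request through `ModelProvider`. A minimal client sketch against a locally running `mlx_lm.server`, assuming the usual OpenAI-style chat endpoint on the default host and port; the model name reuses the one from the test suite:

```
# Sketch: per-request model selection against a running mlx_lm.server instance.
# Assumes the server listens on 127.0.0.1:8080 and exposes /v1/chat/completions.
import json
from urllib import request

payload = {
    "model": "mlx-community/Qwen1.5-0.5B-Chat-4bit",  # loaded on demand by ModelProvider
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 64,
}
req = request.Request(
    "http://127.0.0.1:8080/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as response:
    print(json.load(response))
```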
View File

@@ -120,7 +120,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
        self.trim_space = trim_space

        # Extract the tokens in a list from id to text
-        self.tokenmap = [None] * len(tokenizer.vocab)
        self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
        for value, tokenid in tokenizer.vocab.items():
            self.tokenmap[tokenid] = value

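The sizing change above appears to guard against vocabularies whose token ids are not contiguous, where `len(vocab)` can be smaller than the largest id. A small sketch with a hypothetical vocabulary:

```
# Sketch: why the token map is sized by max id + 1 rather than len(vocab).
# Hypothetical vocabulary with a gap in the id space.
vocab = {"<s>": 0, "hello": 5, "world": 7}

tokenmap = [""] * (max(vocab.values()) + 1)  # 8 slots, ids 0..7 all addressable
for token, token_id in vocab.items():
    tokenmap[token_id] = token

assert tokenmap[7] == "world"
# A [None] * len(vocab) map would have only 3 slots and raise an IndexError here.
```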
View File

@@ -1,20 +1,21 @@
import json
from pathlib import Path
from typing import Dict, List

from transformers import PreTrainedTokenizer


class Dataset:
    """
-    Light-weight wrapper to hold lines from a jsonl file
    Light-weight wrapper to hold a dataset.
    """

-    def __init__(self, path: Path):
-        with open(path, "r") as fid:
-            self._data = [json.loads(l) for l in fid]
    def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
        self._text_key = text_key
        self._data = data

    def __getitem__(self, idx: int):
-        return self._data[idx]["text"]
        return self._data[idx][self._text_key]

    def __len__(self):
        if self._data is None:
@@ -28,8 +29,8 @@ class ChatDataset(Dataset):
    https://platform.openai.com/docs/guides/fine-tuning/example-format
    """

-    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
-        super().__init__(path)
    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
        super().__init__(data)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
@@ -43,19 +44,28 @@ class ChatDataset(Dataset):
class CompletionsDataset(Dataset):
    """
    A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
    or using user-provided keys for prompt and completion values
    https://platform.openai.com/docs/guides/fine-tuning/example-format
    """

-    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
-        super().__init__(path)
    def __init__(
        self,
        data: List[Dict[str, str]],
        tokenizer: PreTrainedTokenizer,
        prompt_key: str = "prompt",
        completion_key: str = "completion",
    ):
        super().__init__(data)
        self._tokenizer = tokenizer
        self._prompt_key = prompt_key
        self._completion_key = completion_key

    def __getitem__(self, idx: int):
        data = self._data[idx]
        text = self._tokenizer.apply_chat_template(
            [
-                {"role": "user", "content": data["prompt"]},
-                {"role": "assistant", "content": data["completion"]},
                {"role": "user", "content": data[self._prompt_key]},
                {"role": "assistant", "content": data[self._completion_key]},
            ],
            tokenize=False,
            add_generation_prompt=True,
@@ -68,14 +78,13 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
    if not path.exists():
        return []
    with open(path, "r") as fid:
-        first_line = next(fid)
-        first_obj = json.loads(first_line)
-    if "messages" in first_obj:
-        return ChatDataset(path, tokenizer)
-    elif "prompt" in first_obj and "completion" in first_obj:
-        return CompletionsDataset(path, tokenizer)
-    elif "text" in first_obj:
-        return Dataset(path)
        data = [json.loads(l) for l in fid]
    if "messages" in data[0]:
        return ChatDataset(data, tokenizer)
    elif "prompt" in data[0] and "completion" in data[0]:
        return CompletionsDataset(data, tokenizer)
    elif "text" in data[0]:
        return Dataset(data)
    else:
        raise ValueError(
            "Unsupported data format, check the supported formats here:\n"
@@ -84,11 +93,53 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    names = ("train", "valid", "test")
-    data_path = Path(args.data)
-    train, valid, test = [
-        create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
-    ]
    if getattr(args, "hf_dataset", None) is not None:
        import datasets

        hf_args = args.hf_dataset
        dataset_name = hf_args["name"]
print(f"Loading Hugging Face dataset {dataset_name}.")
text_feature = hf_args.get("text_feature")
prompt_feature = hf_args.get("prompt_feature")
completion_feature = hf_args.get("completion_feature")
def create_hf_dataset(split: str = None):
ds = datasets.load_dataset(
dataset_name,
split=split,
**hf_args.get("config", {}),
)
if prompt_feature and completion_feature:
return CompletionsDataset(
ds, tokenizer, prompt_feature, completion_feature
)
elif text_feature:
                return Dataset(ds, text_key=text_feature)
else:
raise ValueError(
"Specify either a prompt and completion feature or a text "
"feature for the Hugging Face dataset."
)
if args.train:
train_split = hf_args.get("train_split", "train[:80%]")
valid_split = hf_args.get("valid_split", "train[-10%:]")
train = create_hf_dataset(split=train_split)
valid = create_hf_dataset(split=valid_split)
else:
train, valid = [], []
if args.test:
test = create_hf_dataset(split=hf_args.get("test_split"))
else:
test = []
else:
names = ("train", "valid", "test")
data_path = Path(args.data)
train, valid, test = [
create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
]
    if args.train and len(train) == 0:
        raise ValueError(
            "Training set not found or empty. Must provide training set for fine-tuning."

View File

@@ -95,6 +95,7 @@ def linear_to_lora_layers(
        "qwen2",
        "qwen2_moe",
        "gemma",
        "gemma2",
        "starcoder2",
        "cohere",
        "minicpm",

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

-__version__ = "0.15.0"
__version__ = "0.16.0"

View File

@@ -26,6 +26,9 @@ setup(
    install_requires=requirements,
    packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
    python_requires=">=3.8",
extras_require={
"testing": ["datasets"],
},
    entry_points={
        "console_scripts": [
            "mlx_lm.convert = mlx_lm.convert:main",

View File

@@ -76,6 +76,24 @@ class TestDatasets(unittest.TestCase):
        self.assertTrue(len(valid[0]) > 0)
        self.assertTrue(isinstance(train, datasets.ChatDataset))
def test_hf(self):
args = types.SimpleNamespace(
hf_dataset={
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
},
test=False,
train=True,
)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
train, valid, test = datasets.load_dataset(args, tokenizer)
self.assertTrue(len(train) > 0)
self.assertTrue(len(train[0]) > 0)
self.assertTrue(len(valid) > 0)
self.assertTrue(len(valid[0]) > 0)
self.assertEqual(len(test), 0)
if __name__ == "__main__":
    unittest.main()

View File

@@ -7,19 +7,24 @@ from mlx_lm.server import APIHandler
from mlx_lm.utils import load
class DummyModelProvider:
def __init__(self):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
self.model, self.tokenizer = load(HF_MODEL_PATH)
def load(self, model):
assert model in ["default_model", "chat_model"]
return self.model, self.tokenizer
class TestServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
        cls.model_provider = DummyModelProvider()
        cls.server_address = ("localhost", 0)
        cls.httpd = http.server.HTTPServer(
            cls.server_address,
-            lambda *args, **kwargs: APIHandler(
-                cls.model, cls.tokenizer, *args, **kwargs
-            ),
            lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
        )
        cls.port = cls.httpd.server_port
        cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)

View File

@@ -10,7 +10,9 @@ import numpy as np
def load_dataset(dataname):
-    if dataname == "ptb":
    if dataname == "enwik8":
        return enwik8()
    elif dataname == "ptb":
        return ptb()
    elif dataname == "wikitext2":
        return wikitext(dataset="2")
@@ -87,7 +89,37 @@ def ptb(save_dir="/tmp"):
    return _load(save_dir, filenames)
def enwik8(save_dir="/tmp"):
"""
Load the enwik8 language modeling dataset:
https://mattmahoney.net/dc/textdata.html
"""
out_file = os.path.join(save_dir, "enwik8.zip")
if not os.path.exists(out_file):
request.urlretrieve("http://mattmahoney.net/dc/enwik8.zip", out_file)
with zipfile.ZipFile(out_file) as zf:
data = zf.read("enwik8")
num_test_bytes = 5000000 # 90 + 5 + 5 split
train_data = data[: -2 * num_test_bytes]
valid_data = data[-2 * num_test_bytes : -num_test_bytes]
test_data = data[-num_test_bytes:]
vocab = set(c for c in train_data)
vocab = {c: i for i, c in enumerate(vocab)}
def to_array(dataset):
return np.array([vocab[c] for c in dataset], dtype=np.uint32)
return vocab, to_array(train_data), to_array(valid_data), to_array(test_data)
if __name__ == "__main__":
vocab, train, val, test = enwik8()
assert len(vocab) == 205, "enwik8: Wrong vocab size"
    vocab, train, val, test = ptb()
    assert len(vocab) == 10000, "PTB: Wrong vocab size"

View File

@@ -157,7 +157,7 @@ if __name__ == "__main__":
        "--dataset",
        type=str,
        default="ptb",
-        choices=["ptb", "wikitext2", "wikitext103"],
        choices=["enwik8", "ptb", "wikitext2", "wikitext103"],
        help="Dataset to train and evaluate on.",
    )
    parser.add_argument(