Merge branch 'ml-explore:main' into completion_only

Chime Ogbuji 2024-06-28 18:55:15 -04:00 committed by GitHub
commit 1929f5351c
14 changed files with 448 additions and 59 deletions

View File

@@ -32,7 +32,7 @@ jobs:
pip install --upgrade pip
pip install unittest-xml-reporting
cd llms/
pip install -e .
pip install -e ".[testing]"
- run:
name: Run Python tests
command: |

View File

@@ -151,6 +151,11 @@ Examples GitHub repo has an [example of the WikiSQL
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
correct format.
Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
Face.
### Local Datasets
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory.
@@ -199,7 +204,34 @@ Currently, `*.jsonl` files support three data formats: `chat`,
Note that the format is determined automatically from the dataset. Also, keys
in each line that the loader does not expect are ignored.
For the `chat` and `completions` formats, Hugging Face [chat
### Hugging Face Datasets
To use Hugging Face datasets, first install the `datasets` package:
```
pip install datasets
```
Specify the Hugging Face dataset arguments in a YAML config. For example:
```
hf_dataset:
name: "billsum"
prompt_feature: "text"
completion_feature: "summary"
```
- Use `prompt_feature` and `completion_feature` to specify keys for a
`completions` dataset. Use `text_feature` to specify the key for a `text`
dataset.
- To specify the train, valid, or test splits, set the corresponding
`{train,valid,test}_split` argument.
- Arguments specified in `config` will be passed as keyword arguments to
  [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset),
  as in the example below.
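A config combining these options, slicing `billsum` for training and
validation and passing a keyword argument through to `datasets.load_dataset`,
might look like this (the `config` entry is an illustrative sketch):
```
hf_dataset:
  name: "billsum"
  train_split: "train[:1000]"
  valid_split: "train[-100:]"
  prompt_feature: "text"
  completion_feature: "summary"
  config:
    revision: "main"
```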
In general, for the `chat` and `completions` formats, Hugging Face [chat
templates](https://huggingface.co/blog/chat-templates) are used. This applies
the model's chat template by default. If the model does not have a chat
template, then Hugging Face will use a default. For example, the final text in

View File

@@ -69,3 +69,11 @@ lora_parameters:
# warmup: 100 # 0 for no warmup
# warmup_init: 1e-7 # 0 if not specified
# arguments: [1e-5, 1000, 1e-7] # passed to scheduler
#hf_dataset:
# name: "billsum"
# train_split: "train[:1000]"
# valid_split: "train[-100:]"
# prompt_feature: "text"
# completion_feature: "summary"

View File

@@ -0,0 +1,190 @@
from dataclasses import dataclass
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
head_dim: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int
rope_theta: float = 10000
rope_traditional: bool = False
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
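        # Gemma-family models store the RMSNorm weight zero-centered, hence the 1.0 offset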
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.pre_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
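        # Gemma 2 normalizes both the input and the output of each sub-block
        # ("sandwich" norms), rather than only pre-normalizing as in Llama-style blocks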
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + self.post_attention_layernorm(r)
r = self.mlp(self.pre_feedforward_layernorm(h))
out = h + self.post_feedforward_layernorm(r)
return out
class GemmaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
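        # Gemma scales the token embeddings by sqrt(hidden_size)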
h = h * (self.args.hidden_size**0.5)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = GemmaModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
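        # The output projection is tied to the embedding weights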
out = self.model.embed_tokens.as_linear(out)
return out
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads
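With this file in place, a gemma2 checkpoint loads through the usual `mlx_lm`
entry points. A minimal sketch (the model path is a placeholder for any
converted Gemma 2 checkpoint):
```
from mlx_lm import load, generate

# Placeholder path; substitute any converted Gemma 2 checkpoint.
model, tokenizer = load("mlx-community/gemma-2-9b-it-4bit")
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))
```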

View File

@@ -8,6 +8,7 @@ import uuid
import warnings
from functools import lru_cache
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
import mlx.core as mx
@@ -81,14 +82,68 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip()
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
self.cli_args = cli_args
self.model_key = None
self.model = None
self.tokenizer = None
# Preload the default model if it is provided
if self.cli_args.model is not None:
self.load("default_model")
def _validate_model_path(self, model_path: str):
model_path = Path(model_path)
if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
raise RuntimeError(
"Local models must be relative to the current working dir."
)
def load(self, model_path):
if self.model_key == model_path:
return self.model, self.tokenizer
# Remove the old model if it exists.
self.model = None
self.tokenizer = None
# Building tokenizer_config
tokenizer_config = {
"trust_remote_code": True if self.cli_args.trust_remote_code else None
}
if self.cli_args.chat_template:
tokenizer_config["chat_template"] = self.cli_args.chat_template
if model_path == "default_model" and self.cli_args.model is not None:
model, tokenizer = load(
self.cli_args.model,
adapter_path=self.cli_args.adapter_path,
tokenizer_config=tokenizer_config,
)
else:
self._validate_model_path(model_path)
model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
if self.cli_args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
self.model_key = model_path
self.model = model
self.tokenizer = tokenizer
return self.model, self.tokenizer
class APIHandler(BaseHTTPRequestHandler):
def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
"""
Create static request specific metadata
"""
self.model = model
self.tokenizer = tokenizer
self.created = int(time.time())
self.model_provider = model_provider
super().__init__(*args, **kwargs)
def _set_cors_headers(self):
@@ -148,6 +203,15 @@ class APIHandler(BaseHTTPRequestHandler):
self.logprobs = self.body.get("logprobs", -1)
self.validate_model_parameters()
# Load the model if needed
try:
self.model, self.tokenizer = self.model_provider.load(self.requested_model)
except Exception:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
return
# Get stop id sequences, if provided
stop_words = self.body.get("stop")
stop_words = stop_words or []
@@ -513,15 +577,14 @@ class APIHandler(BaseHTTPRequestHandler):
def run(
host: str,
port: int,
model: nn.Module,
tokenizer: TokenizerWrapper,
model_provider: ModelProvider,
server_class=HTTPServer,
handler_class=APIHandler,
):
server_address = (host, port)
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "
@@ -536,7 +599,6 @@ def main():
parser.add_argument(
"--model",
type=str,
required=True,
help="The path to the MLX model weights, tokenizer, and config",
)
parser.add_argument(
@@ -598,20 +660,7 @@
logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.chat_template:
tokenizer_config["chat_template"] = args.chat_template
model, tokenizer = load(
args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
)
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
run(args.host, args.port, model, tokenizer)
run(args.host, args.port, ModelProvider(args))
if __name__ == "__main__":
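Since `--model` is no longer required, the server loads whichever model a
request names. A minimal sketch of per-request model selection, assuming the
default `localhost:8080` and using the test model path as a placeholder:
```
curl localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mlx-community/Qwen1.5-0.5B-Chat-4bit",
    "messages": [{"role": "user", "content": "Hello"}]
  }'
```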

View File

@@ -120,7 +120,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
self.trim_space = trim_space
# Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab)
self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
for value, tokenid in tokenizer.vocab.items():
self.tokenmap[tokenid] = value
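The old sizing assumed token ids are dense in `[0, len(vocab))`; sizing by the
maximum id (with `""` placeholders) keeps sparse ids addressable. A minimal
illustration of the failure mode:
```
vocab = {"<pad>": 0, "a": 1, "z": 5}    # token ids need not be contiguous
old = [None] * len(vocab)               # length 3, so old[5] raises IndexError
new = [""] * (max(vocab.values()) + 1)  # length 6, so new[5] = "z" works
```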

View File

@@ -1,20 +1,21 @@
import json
from pathlib import Path
from typing import Dict, List
from transformers import PreTrainedTokenizer
class Dataset:
"""
Light-weight wrapper to hold lines from a jsonl file
Light-weight wrapper to hold a dataset.
"""
def __init__(self, path: Path):
with open(path, "r") as fid:
self._data = [json.loads(l) for l in fid]
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
self._text_key = text_key
self._data = data
def __getitem__(self, idx: int):
return self._data[idx]["text"]
return self._data[idx][self._text_key]
def __len__(self):
if self._data is None:
@@ -28,8 +29,8 @@ class ChatDataset(Dataset):
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
super().__init__(path)
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
super().__init__(data)
self._tokenizer = tokenizer
def __getitem__(self, idx: int):
@@ -43,19 +44,28 @@
class CompletionsDataset(Dataset):
"""
A dataset of prompt-completion pairs, {"prompt": ..., "completion": ...},
optionally with user-provided keys for the prompt and completion values:
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
super().__init__(path)
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
completion_key: str = "completion",
):
super().__init__(data)
self._tokenizer = tokenizer
self._prompt_key = prompt_key
self._completion_key = completion_key
def __getitem__(self, idx: int):
data = self._data[idx]
text = self._tokenizer.apply_chat_template(
[
{"role": "user", "content": data["prompt"]},
{"role": "assistant", "content": data["completion"]},
{"role": "user", "content": data[self._prompt_key]},
{"role": "assistant", "content": data[self._completion_key]},
],
tokenize=False,
add_generation_prompt=True,
@@ -68,14 +78,13 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
if not path.exists():
return []
with open(path, "r") as fid:
first_line = next(fid)
first_obj = json.loads(first_line)
if "messages" in first_obj:
return ChatDataset(path, tokenizer)
elif "prompt" in first_obj and "completion" in first_obj:
return CompletionsDataset(path, tokenizer)
elif "text" in first_obj:
return Dataset(path)
data = [json.loads(l) for l in fid]
if "messages" in data[0]:
return ChatDataset(data, tokenizer)
elif "prompt" in data[0] and "completion" in data[0]:
return CompletionsDataset(data, tokenizer)
elif "text" in data[0]:
return Dataset(data)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
@@ -84,11 +93,53 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
def load_dataset(args, tokenizer: PreTrainedTokenizer):
names = ("train", "valid", "test")
data_path = Path(args.data)
train, valid, test = [
create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
]
if getattr(args, "hf_dataset", None) is not None:
import datasets
hf_args = args.hf_dataset
dataset_name = hf_args["name"]
print(f"Loading Hugging Face dataset {dataset_name}.")
text_feature = hf_args.get("text_feature")
prompt_feature = hf_args.get("prompt_feature")
completion_feature = hf_args.get("completion_feature")
def create_hf_dataset(split: str = None):
ds = datasets.load_dataset(
dataset_name,
split=split,
**hf_args.get("config", {}),
)
if prompt_feature and completion_feature:
return CompletionsDataset(
ds, tokenizer, prompt_feature, completion_feature
)
elif text_feature:
return Dataset(ds, text_key=text_feature)
else:
raise ValueError(
"Specify either a prompt and completion feature or a text "
"feature for the Hugging Face dataset."
)
if args.train:
train_split = hf_args.get("train_split", "train[:80%]")
valid_split = hf_args.get("valid_split", "train[-10%:]")
train = create_hf_dataset(split=train_split)
valid = create_hf_dataset(split=valid_split)
else:
train, valid = [], []
if args.test:
test = create_hf_dataset(split=hf_args.get("test_split"))
else:
test = []
else:
names = ("train", "valid", "test")
data_path = Path(args.data)
train, valid, test = [
create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
]
if args.train and len(train) == 0:
raise ValueError(
"Training set not found or empty. Must provide training set for fine-tuning."

View File

@@ -95,6 +95,7 @@ def linear_to_lora_layers(
"qwen2",
"qwen2_moe",
"gemma",
"gemma2",
"starcoder2",
"cohere",
"minicpm",

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.15.0"
__version__ = "0.16.0"

View File

@@ -26,6 +26,9 @@ setup(
install_requires=requirements,
packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
python_requires=">=3.8",
extras_require={
"testing": ["datasets"],
},
entry_points={
"console_scripts": [
"mlx_lm.convert = mlx_lm.convert:main",

View File

@@ -76,6 +76,24 @@ class TestDatasets(unittest.TestCase):
self.assertTrue(len(valid[0]) > 0)
self.assertTrue(isinstance(train, datasets.ChatDataset))
def test_hf(self):
args = types.SimpleNamespace(
hf_dataset={
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
},
test=False,
train=True,
)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
train, valid, test = datasets.load_dataset(args, tokenizer)
self.assertTrue(len(train) > 0)
self.assertTrue(len(train[0]) > 0)
self.assertTrue(len(valid) > 0)
self.assertTrue(len(valid[0]) > 0)
self.assertEqual(len(test), 0)
if __name__ == "__main__":
unittest.main()

View File

@@ -7,19 +7,24 @@ from mlx_lm.server import APIHandler
from mlx_lm.utils import load
class DummyModelProvider:
def __init__(self):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
self.model, self.tokenizer = load(HF_MODEL_PATH)
def load(self, model):
assert model in ["default_model", "chat_model"]
return self.model, self.tokenizer
class TestServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
cls.model_provider = DummyModelProvider()
cls.server_address = ("localhost", 0)
cls.httpd = http.server.HTTPServer(
cls.server_address,
lambda *args, **kwargs: APIHandler(
cls.model, cls.tokenizer, *args, **kwargs
),
lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
)
cls.port = cls.httpd.server_port
cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)

View File

@@ -10,7 +10,9 @@ import numpy as np
def load_dataset(dataname):
if dataname == "ptb":
if dataname == "enwik8":
return enwik8()
elif dataname == "ptb":
return ptb()
elif dataname == "wikitext2":
return wikitext(dataset="2")
@@ -87,7 +89,37 @@ def ptb(save_dir="/tmp"):
return _load(save_dir, filenames)
def enwik8(save_dir="/tmp"):
"""
Load the enwik8 language modeling dataset:
https://mattmahoney.net/dc/textdata.html
"""
out_file = os.path.join(save_dir, "enwik8.zip")
if not os.path.exists(out_file):
request.urlretrieve("http://mattmahoney.net/dc/enwik8.zip", out_file)
with zipfile.ZipFile(out_file) as zf:
data = zf.read("enwik8")
num_test_bytes = 5000000  # 90/5/5 train/valid/test split of the 100M bytes
train_data = data[: -2 * num_test_bytes]
valid_data = data[-2 * num_test_bytes : -num_test_bytes]
test_data = data[-num_test_bytes:]
vocab = set(c for c in train_data)
vocab = {c: i for i, c in enumerate(vocab)}
def to_array(dataset):
return np.array([vocab[c] for c in dataset], dtype=np.uint32)
return vocab, to_array(train_data), to_array(valid_data), to_array(test_data)
if __name__ == "__main__":
vocab, train, val, test = enwik8()
assert len(vocab) == 205, "enwik8: Wrong vocab size"
vocab, train, val, test = ptb()
assert len(vocab) == 10000, "PTB: Wrong vocab size"

View File

@@ -157,7 +157,7 @@ if __name__ == "__main__":
"--dataset",
type=str,
default="ptb",
choices=["ptb", "wikitext2", "wikitext103"],
choices=["enwik8", "ptb", "wikitext2", "wikitext103"],
help="Dataset to train and evaluate on.",
)
parser.add_argument(