From df6bc09d7471f05f7aec69b2bfa54290a60b22af Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Wed, 26 Jun 2024 13:20:50 -0400
Subject: [PATCH 1/5] Configuration-based use of HF hub-hosted datasets for
 training (#701)

* Add hf_dataset configuration for using HF hub-hosted datasets for (Q)LoRA
  training

* Pre-commit formatting

* Fix YAML config example

* Print DS info

* Include name

* Add hf_dataset parameter default

* Remove TextHFDataset and CompletionsHFDataset and use Dataset and
  CompletionsDataset instead, adding a text_key constructor argument to the
  former (and changing it to work with a provided data structure instead of
  just from a JSON file), and prompt_key and completion_key arguments to the
  latter with defaults for backwards compatibility.

* nits

* update docs

---------

Co-authored-by: Awni Hannun
---
 .circleci/config.yml                  |  2 +-
 llms/mlx_lm/LORA.md                   | 36 +++++++++-
 llms/mlx_lm/examples/lora_config.yaml |  8 +++
 llms/mlx_lm/tuner/datasets.py         | 99 ++++++++++++++++++++-------
 llms/mlx_lm/version.py                |  2 +-
 llms/setup.py                         |  3 +
 llms/tests/test_datsets.py            | 18 +++++
 7 files changed, 140 insertions(+), 28 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 556f209e..02fa1de8 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -32,7 +32,7 @@ jobs:
             pip install --upgrade pip
             pip install unittest-xml-reporting
             cd llms/
-            pip install -e .
+            pip install -e ".[testing]"
       - run:
           name: Run Python tests
           command: |
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 3d65f213..2e739d0f 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -151,9 +151,14 @@ Examples GitHub repo has an [example of the WikiSQL
 data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
 correct format.
 
+Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
+Face.
+
+### Local Datasets
+
 For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
 `valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
-loader expects a `test.jsonl` in the data directory. 
+loader expects a `test.jsonl` in the data directory.
 
 Currently, `*.jsonl` files support three data formats: `chat`,
 `completions`, and `text`. Here are three examples of these formats:
@@ -199,7 +204,34 @@ Currently, `*.jsonl` files support three data formats: `chat`,
 Note, the format is automatically determined by the dataset. Note also, keys
 in each line not expected by the loader will be ignored.
 
-For the `chat` and `completions` formats, Hugging Face [chat
+### Hugging Face Datasets
+
+To use Hugging Face datasets, first install the `datasets` package:
+
+```
+pip install datasets
+```
+
+Specify the Hugging Face dataset arguments in a YAML config. For example:
+
+```
+hf_dataset:
+  name: "billsum"
+  prompt_feature: "text"
+  completion_feature: "summary"
+```
+
+- Use `prompt_feature` and `completion_feature` to specify keys for a
+  `completions` dataset. Use `text_feature` to specify the key for a `text`
+  dataset.
+
+- To specify the train, valid, or test splits, set the corresponding
+  `{train,valid,test}_split` argument.
+
+- Arguments specified in `config` will be passed as keyword arguments to
+  [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
+
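+For example, assuming the configuration above is saved as `config.yaml` (the
+file name is illustrative), a fine-tuning run on the Hugging Face dataset can
+be launched with:
+
+```
+mlx_lm.lora --config config.yaml --train
+```
+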
+In general, for the `chat` and `completions` formats, Hugging Face [chat
 templates](https://huggingface.co/blog/chat-templates) are used. This applies
 the model's chat template by default. If the model does not have a chat
 template, then Hugging Face will use a default. For example, the final text in
diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml
index d3c0d22a..073a5b6f 100644
--- a/llms/mlx_lm/examples/lora_config.yaml
+++ b/llms/mlx_lm/examples/lora_config.yaml
@@ -69,3 +69,11 @@ lora_parameters:
 #  warmup: 100 # 0 for no warmup
 #  warmup_init: 1e-7 # 0 if not specified
 #  arguments: [1e-5, 1000, 1e-7] # passed to scheduler
+
+#hf_dataset:
+#  name: "billsum"
+#  train_split: "train[:1000]"
+#  valid_split: "train[-100:]"
+#  prompt_feature: "text"
+#  completion_feature: "summary"
+
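+# A dataset with a single text field would instead set `text_feature` in
+# place of the prompt/completion features; the dataset name below is a
+# placeholder rather than a tested example:
+#hf_dataset:
+#  name: "my-org/my-text-dataset"
+#  text_feature: "text"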
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index e5776160..3d99894c 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -1,20 +1,21 @@
 import json
 from pathlib import Path
+from typing import Dict, List
 
 from transformers import PreTrainedTokenizer
 
 
 class Dataset:
     """
-    Light-weight wrapper to hold lines from a jsonl file
+    Light-weight wrapper to hold a dataset.
     """
 
-    def __init__(self, path: Path):
-        with open(path, "r") as fid:
-            self._data = [json.loads(l) for l in fid]
+    def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
+        self._text_key = text_key
+        self._data = data
 
     def __getitem__(self, idx: int):
-        return self._data[idx]["text"]
+        return self._data[idx][self._text_key]
 
     def __len__(self):
         if self._data is None:
@@ -28,8 +29,8 @@ class ChatDataset(Dataset):
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
-        super().__init__(path)
+    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
+        super().__init__(data)
         self._tokenizer = tokenizer
 
     def __getitem__(self, idx: int):
@@ -43,19 +44,28 @@ class ChatDataset(Dataset):
 class CompletionsDataset(Dataset):
     """
     A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
+    or using user-provided keys for prompt and completion values
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
-        super().__init__(path)
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        prompt_key: str = "prompt",
+        completion_key: str = "completion",
+    ):
+        super().__init__(data)
         self._tokenizer = tokenizer
+        self._prompt_key = prompt_key
+        self._completion_key = completion_key
 
     def __getitem__(self, idx: int):
         data = self._data[idx]
         text = self._tokenizer.apply_chat_template(
             [
-                {"role": "user", "content": data["prompt"]},
-                {"role": "assistant", "content": data["completion"]},
+                {"role": "user", "content": data[self._prompt_key]},
+                {"role": "assistant", "content": data[self._completion_key]},
             ],
             tokenize=False,
             add_generation_prompt=True,
         )
@@ -68,14 +78,13 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
     if not path.exists():
         return []
     with open(path, "r") as fid:
-        first_line = next(fid)
-        first_obj = json.loads(first_line)
-        if "messages" in first_obj:
-            return ChatDataset(path, tokenizer)
-        elif "prompt" in first_obj and "completion" in first_obj:
-            return CompletionsDataset(path, tokenizer)
-        elif "text" in first_obj:
-            return Dataset(path)
+        data = [json.loads(l) for l in fid]
+        if "messages" in data[0]:
+            return ChatDataset(data, tokenizer)
+        elif "prompt" in data[0] and "completion" in data[0]:
+            return CompletionsDataset(data, tokenizer)
+        elif "text" in data[0]:
+            return Dataset(data)
         else:
             raise ValueError(
                 "Unsupported data format, check the supported formats here:\n"
@@ -84,11 +93,53 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
 
 
 def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    names = ("train", "valid", "test")
-    data_path = Path(args.data)
-    train, valid, test = [
-        create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
-    ]
+    if getattr(args, "hf_dataset", None) is not None:
+        import datasets
+
+        hf_args = args.hf_dataset
+        dataset_name = hf_args["name"]
+        print(f"Loading Hugging Face dataset {dataset_name}.")
+        text_feature = hf_args.get("text_feature")
+        prompt_feature = hf_args.get("prompt_feature")
+        completion_feature = hf_args.get("completion_feature")
+
+        def create_hf_dataset(split: str = None):
+            ds = datasets.load_dataset(
+                dataset_name,
+                split=split,
+                **hf_args.get("config", {}),
+            )
+            if prompt_feature and completion_feature:
+                return CompletionsDataset(
+                    ds, tokenizer, prompt_feature, completion_feature
+                )
+            elif text_feature:
+                return Dataset(ds, text_key=text_feature)
+            else:
+                raise ValueError(
+                    "Specify either a prompt and completion feature or a text "
+                    "feature for the Hugging Face dataset."
+                )
+
+        if args.train:
+            train_split = hf_args.get("train_split", "train[:80%]")
+            valid_split = hf_args.get("valid_split", "train[-10%:]")
+            train = create_hf_dataset(split=train_split)
+            valid = create_hf_dataset(split=valid_split)
+        else:
+            train, valid = [], []
+        if args.test:
+            test = create_hf_dataset(split=hf_args.get("test_split"))
+        else:
+            test = []
+
+    else:
+        names = ("train", "valid", "test")
+        data_path = Path(args.data)
+
+        train, valid, test = [
+            create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
+        ]
     if args.train and len(train) == 0:
         raise ValueError(
             "Training set not found or empty. Must provide training set for fine-tuning."
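As a quick illustration of the new code path (mirroring the `test_hf` unit
test added below), the Hugging Face branch of `load_dataset` can be driven
directly with the `datasets` package installed; the tokenizer checkpoint is
just the one the tests use:

```
import types

from transformers import AutoTokenizer

from mlx_lm.tuner import datasets

# Mimic the parsed CLI/config arguments that load_dataset expects.
args = types.SimpleNamespace(
    hf_dataset={
        "name": "billsum",
        "prompt_feature": "text",
        "completion_feature": "summary",
    },
    train=True,
    test=False,
)
tokenizer = AutoTokenizer.from_pretrained("mlx-community/Qwen1.5-0.5B-Chat-4bit")
# With train=True and test=False, the default train[:80%] / train[-10%:]
# splits are used and the returned test set is empty.
train, valid, test = datasets.load_dataset(args, tokenizer)
```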
-__version__ = "0.15.0" +__version__ = "0.16.0" diff --git a/llms/setup.py b/llms/setup.py index 648e1e04..88deed17 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -26,6 +26,9 @@ setup( install_requires=requirements, packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"], python_requires=">=3.8", + extras_require={ + "testing": ["datasets"], + }, entry_points={ "console_scripts": [ "mlx_lm.convert = mlx_lm.convert:main", diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py index 8d8c01a5..240bfb4a 100644 --- a/llms/tests/test_datsets.py +++ b/llms/tests/test_datsets.py @@ -76,6 +76,24 @@ class TestDatasets(unittest.TestCase): self.assertTrue(len(valid[0]) > 0) self.assertTrue(isinstance(train, datasets.ChatDataset)) + def test_hf(self): + args = types.SimpleNamespace( + hf_dataset={ + "name": "billsum", + "prompt_feature": "text", + "completion_feature": "summary", + }, + test=False, + train=True, + ) + tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) + train, valid, test = datasets.load_dataset(args, tokenizer) + self.assertTrue(len(train) > 0) + self.assertTrue(len(train[0]) > 0) + self.assertTrue(len(valid) > 0) + self.assertTrue(len(valid[0]) > 0) + self.assertEqual(len(test), 0) + if __name__ == "__main__": unittest.main() From 7979b84a9e46145be3874cfaa06c99fc14361dd0 Mon Sep 17 00:00:00 2001 From: Volodymyr Kyrylov Date: Wed, 26 Jun 2024 20:59:01 +0200 Subject: [PATCH 2/5] transformer_lm: add --dataset enwik8 (#838) * transformer_lm: add --dataset enwik8 * nits --------- Co-authored-by: Awni Hannun --- transformer_lm/datasets.py | 34 +++++++++++++++++++++++++++++++++- transformer_lm/main.py | 2 +- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/transformer_lm/datasets.py b/transformer_lm/datasets.py index 7b077ef3..7d6ddc0f 100644 --- a/transformer_lm/datasets.py +++ b/transformer_lm/datasets.py @@ -10,7 +10,9 @@ import numpy as np def load_dataset(dataname): - if dataname == "ptb": + if dataname == "enwik8": + return enwik8() + elif dataname == "ptb": return ptb() elif dataname == "wikitext2": return wikitext(dataset="2") @@ -87,7 +89,37 @@ def ptb(save_dir="/tmp"): return _load(save_dir, filenames) +def enwik8(save_dir="/tmp"): + """ + Load the enwik8 language modeling dataset: + https://mattmahoney.net/dc/textdata.html + """ + out_file = os.path.join(save_dir, "enwik8.zip") + if not os.path.exists(out_file): + request.urlretrieve("http://mattmahoney.net/dc/enwik8.zip", out_file) + + with zipfile.ZipFile(out_file) as zf: + data = zf.read("enwik8") + + num_test_bytes = 5000000 # 90 + 5 + 5 split + + train_data = data[: -2 * num_test_bytes] + valid_data = data[-2 * num_test_bytes : -num_test_bytes] + test_data = data[-num_test_bytes:] + + vocab = set(c for c in train_data) + vocab = {c: i for i, c in enumerate(vocab)} + + def to_array(dataset): + return np.array([vocab[c] for c in dataset], dtype=np.uint32) + + return vocab, to_array(train_data), to_array(valid_data), to_array(test_data) + + if __name__ == "__main__": + vocab, train, val, test = enwik8() + assert len(vocab) == 205, "enwik8: Wrong vocab size" + vocab, train, val, test = ptb() assert len(vocab) == 10000, "PTB: Wrong vocab size" diff --git a/transformer_lm/main.py b/transformer_lm/main.py index 044af58c..dc725cbe 100644 --- a/transformer_lm/main.py +++ b/transformer_lm/main.py @@ -157,7 +157,7 @@ if __name__ == "__main__": "--dataset", type=str, default="ptb", - choices=["ptb", "wikitext2", "wikitext103"], + choices=["enwik8", "ptb", "wikitext2", "wikitext103"], help="Dataset to 
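As a quick check that the new option is wired through (using only the flag
added here, with every other argument left at its default), training can be
run from the `transformer_lm` directory with:

```
python main.py --dataset enwik8
```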
diff --git a/transformer_lm/main.py b/transformer_lm/main.py
index 044af58c..dc725cbe 100644
--- a/transformer_lm/main.py
+++ b/transformer_lm/main.py
@@ -157,7 +157,7 @@ if __name__ == "__main__":
         "--dataset",
         type=str,
         default="ptb",
-        choices=["ptb", "wikitext2", "wikitext103"],
+        choices=["enwik8", "ptb", "wikitext2", "wikitext103"],
         help="Dataset to train and evaluate on.",
     )
     parser.add_argument(

From 9f10728145828fba08d797b43e77b5ee7e63729f Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 27 Jun 2024 06:38:19 -0700
Subject: [PATCH 3/5] fix yi (#852)

---
 llms/mlx_lm/tokenizer_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 7e251a09..6caad629 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -120,7 +120,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
         self.trim_space = trim_space
 
         # Extract the tokens in a list from id to text
-        self.tokenmap = [None] * len(tokenizer.vocab)
+        self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
         for value, tokenid in tokenizer.vocab.items():
             self.tokenmap[tokenid] = value
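The sizing change matters when token ids are not contiguous, in which case
`len(vocab)` undercounts the largest id. A toy illustration (the vocabulary
here is made up, not taken from any real tokenizer):

```
# Hypothetical vocabulary with a gap in its token ids:
vocab = {"<s>": 0, "</s>": 1, "hello": 5}
# len(vocab) == 3, but ids run up to 5, so a map sized by len() would
# raise IndexError below; sizing by the max id leaves harmless "" gaps.
tokenmap = [""] * (max(vocab.values()) + 1)
for token, tokenid in vocab.items():
    tokenmap[tokenid] = token
assert tokenmap[5] == "hello"
```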
From 538339b599b0fef12de0df0bfa5d3f0e85519642 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 27 Jun 2024 10:06:28 -0700
Subject: [PATCH 4/5] gemma2 (#855)

---
 llms/mlx_lm/models/gemma2.py | 190 +++++++++++++++++++++++++++++++++++
 llms/mlx_lm/tuner/utils.py   |   1 +
 2 files changed, 191 insertions(+)
 create mode 100644 llms/mlx_lm/models/gemma2.py

diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py
new file mode 100644
index 00000000..bd531c5d
--- /dev/null
+++ b/llms/mlx_lm/models/gemma2.py
@@ -0,0 +1,190 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    head_dim: int
+    rms_norm_eps: float
+    vocab_size: int
+    num_key_value_heads: int
+    rope_theta: float = 10000
+    rope_traditional: bool = False
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
+
+    def __call__(self, x):
+        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.head_dim = head_dim = args.head_dim
+
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+        self.rope = nn.RoPE(
+            head_dim,
+            traditional=args.rope_traditional,
+            base=args.rope_theta,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_attention_heads = args.num_attention_heads
+        self.hidden_size = args.hidden_size
+        self.self_attn = Attention(args)
+        self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.pre_feedforward_layernorm = RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.post_feedforward_layernorm = RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.args = args
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + self.post_attention_layernorm(r)
+        r = self.mlp(self.pre_feedforward_layernorm(h))
+        out = h + self.post_feedforward_layernorm(r)
+        return out
+
+
+class GemmaModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+        ]
+        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        h = self.embed_tokens(inputs)
+        h = h * (self.args.hidden_size**0.5)
+
+        mask = None
+        if h.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+            mask = mask.astype(h.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.model_type = args.model_type
+        self.model = GemmaModel(args)
+        self.args = args
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        out = self.model(inputs, cache)
+        out = self.model.embed_tokens.as_linear(out)
+        return out
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.head_dim
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 2614c7a5..fe9740f5 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -95,6 +95,7 @@ def linear_to_lora_layers(
         "qwen2",
         "qwen2_moe",
         "gemma",
+        "gemma2",
         "starcoder2",
         "cohere",
         "minicpm",
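Since `mlx_lm` routes to this file via the `model_type` field in a
checkpoint's config, converted Gemma 2 weights should load without further
changes. A quick smoke test might look like the following, where the
checkpoint name is an assumed community conversion rather than one referenced
by this patch:

```
mlx_lm.generate --model mlx-community/gemma-2-9b-it-4bit \
  --prompt "Why is the sky blue?" --max-tokens 100
```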
From f212b770d8b5143e23102eda20400ae43340f844 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 27 Jun 2024 11:37:57 -0700
Subject: [PATCH 5/5] Server loads the model on demand from the request (#851)

---
 llms/mlx_lm/server.py     | 91 ++++++++++++++++++++++++++++++---------
 llms/tests/test_server.py | 19 +++++---
 2 files changed, 82 insertions(+), 28 deletions(-)

diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 19f3f46a..b53971a3 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -8,6 +8,7 @@ import uuid
 import warnings
 from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
+from pathlib import Path
 from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
 
 import mlx.core as mx
@@ -81,14 +82,68 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+class ModelProvider:
+    def __init__(self, cli_args: argparse.Namespace):
+        """Load models on demand and persist them across the whole process."""
+        self.cli_args = cli_args
+        self.model_key = None
+        self.model = None
+        self.tokenizer = None
+
+        # Preload the default model if it is provided
+        if self.cli_args.model is not None:
+            self.load("default_model")
+
+    def _validate_model_path(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
+            raise RuntimeError(
+                "Local models must be relative to the current working dir."
+            )
+
+    def load(self, model_path):
+        if self.model_key == model_path:
+            return self.model, self.tokenizer
+
+        # Remove the old model if it exists.
+        self.model = None
+        self.tokenizer = None
+
+        # Building tokenizer_config
+        tokenizer_config = {
+            "trust_remote_code": True if self.cli_args.trust_remote_code else None
+        }
+        if self.cli_args.chat_template:
+            tokenizer_config["chat_template"] = self.cli_args.chat_template
+
+        if model_path == "default_model" and self.cli_args.model is not None:
+            model, tokenizer = load(
+                self.cli_args.model,
+                adapter_path=self.cli_args.adapter_path,
+                tokenizer_config=tokenizer_config,
+            )
+        else:
+            self._validate_model_path(model_path)
+            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+
+        if self.cli_args.use_default_chat_template:
+            if tokenizer.chat_template is None:
+                tokenizer.chat_template = tokenizer.default_chat_template
+
+        self.model_key = model_path
+        self.model = model
+        self.tokenizer = tokenizer
+
+        return self.model, self.tokenizer
+
+
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
+    def __init__(self, model_provider: ModelProvider, *args, **kwargs):
         """
         Create static request specific metadata
         """
-        self.model = model
-        self.tokenizer = tokenizer
         self.created = int(time.time())
+        self.model_provider = model_provider
         super().__init__(*args, **kwargs)
 
     def _set_cors_headers(self):
@@ -148,6 +203,15 @@ class APIHandler(BaseHTTPRequestHandler):
         self.logprobs = self.body.get("logprobs", -1)
         self.validate_model_parameters()
 
+        # Load the model if needed
+        try:
+            self.model, self.tokenizer = self.model_provider.load(self.requested_model)
+        except:
+            self._set_completion_headers(404)
+            self.end_headers()
+            self.wfile.write(b"Not Found")
+            return
+
         # Get stop id sequences, if provided
         stop_words = self.body.get("stop")
         stop_words = stop_words or []
@@ -513,15 +577,14 @@ class APIHandler(BaseHTTPRequestHandler):
 def run(
     host: str,
     port: int,
-    model: nn.Module,
-    tokenizer: TokenizerWrapper,
+    model_provider: ModelProvider,
     server_class=HTTPServer,
     handler_class=APIHandler,
 ):
     server_address = (host, port)
     httpd = server_class(
         server_address,
-        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+        lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
     )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
@@ -536,7 +599,6 @@ def main():
     parser.add_argument(
         "--model",
         type=str,
-        required=True,
        help="The path to the MLX model weights, tokenizer, and config",
     )
     parser.add_argument(
@@ -598,20 +660,7 @@ def main():
     logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
     mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
 
-    # Building tokenizer_config
-    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
-    if args.chat_template:
-        tokenizer_config["chat_template"] = args.chat_template
-
-    model, tokenizer = load(
-        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
-    )
-
-    if args.use_default_chat_template:
-        if tokenizer.chat_template is None:
-            tokenizer.chat_template = tokenizer.default_chat_template
-
-    run(args.host, args.port, model, tokenizer)
+    run(args.host, args.port, ModelProvider(args))
 
 
 if __name__ == "__main__":
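With the provider in place, a client selects a model per request through the
standard `model` field of the request body, which triggers an on-demand load
when it names something other than the preloaded default. Against a locally
running server on its default port, that might look like (the model id here
is just an example):

```
curl localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "mlx-community/Qwen1.5-0.5B-Chat-4bit",
       "messages": [{"role": "user", "content": "Hello"}]}'
```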
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index 998ad1c7..4d71a5a3 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -7,19 +7,24 @@ from mlx_lm.server import APIHandler
 from mlx_lm.utils import load
 
 
+class DummyModelProvider:
+    def __init__(self):
+        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        self.model, self.tokenizer = load(HF_MODEL_PATH)
+
+    def load(self, model):
+        assert model in ["default_model", "chat_model"]
+        return self.model, self.tokenizer
+
+
 class TestServer(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
-
+        cls.model_provider = DummyModelProvider()
         cls.server_address = ("localhost", 0)
         cls.httpd = http.server.HTTPServer(
             cls.server_address,
-            lambda *args, **kwargs: APIHandler(
-                cls.model, cls.tokenizer, *args, **kwargs
-            ),
+            lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
         )
         cls.port = cls.httpd.server_port
         cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)