From 8fd953ee2b8109cb1764d4df38f9a1479f9f0a0a Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 19 Feb 2024 20:37:15 -0800
Subject: [PATCH] Support for slerp merging models (#455)

* support for slerp merging models

* docs

* update docs

* format
---
 llms/README.md                         | 16 +--
 llms/mlx_lm/LORA.md                    | 2 +
 llms/mlx_lm/MERGE.md                   | 50 ++++++++
 llms/mlx_lm/SERVER.md                  | 63 ++++++++++
 llms/mlx_lm/examples/merge_config.yaml | 11 ++
 llms/mlx_lm/merge.py                   | 158 +++++++++++++++++++++++++
 llms/mlx_lm/models/llama.py            | 4 +
 llms/mlx_lm/models/mixtral.py          | 5 +-
 llms/mlx_lm/models/olmo.py             | 4 +
 llms/mlx_lm/models/phi.py              | 4 +
 llms/mlx_lm/models/phixtral.py         | 4 +
 llms/mlx_lm/models/qwen2.py            | 4 +
 llms/mlx_lm/models/stablelm_epoch.py   | 4 +
 llms/mlx_lm/requirements.txt           | 1 +
 llms/setup.py                          | 2 +-
 lora/lora.py                           | 12 +-
 16 files changed, 329 insertions(+), 15 deletions(-)
 create mode 100644 llms/mlx_lm/MERGE.md
 create mode 100644 llms/mlx_lm/SERVER.md
 create mode 100644 llms/mlx_lm/examples/merge_config.yaml
 create mode 100644 llms/mlx_lm/merge.py

diff --git a/llms/README.md b/llms/README.md
index 4a9e0831..a15d00c8 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -14,9 +14,11 @@ pip install mlx-lm
 conda install -c conda-forge mlx-lm
 ```

-The `mlx-lm` package also supports LoRA and QLoRA fine-tuning. For more details
-on this see the [LoRA
-documentation](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md).
+The `mlx-lm` package also has:
+
+- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
+- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
+- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)

 ### Python API

@@ -25,7 +27,7 @@ You can use `mlx-lm` as a module:
 ```python
 from mlx_lm import load, generate

-model, tokenizer = load("mistralai/Mistral-7B-v0.1")
+model, tokenizer = load("mistralai/Mistral-7B-Instruct-v0.1")

 response = generate(model, tokenizer, prompt="hello", verbose=True)
 ```
@@ -44,7 +46,7 @@ You can convert models in the Python API with:
 ```python
 from mlx_lm import convert

-upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit"
+upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.1-4bit"

 convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo)
 ```
@@ -64,7 +66,7 @@ To see a description of all the arguments you can do:
 You can also use `mlx-lm` from the command line with:

 ```
-python -m mlx_lm.generate --model mistralai/Mistral-7B-v0.1 --prompt "hello"
+python -m mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.1 --prompt "hello"
 ```

 This will download a Mistral 7B model from the Hugging Face Hub and generate
@@ -79,7 +81,7 @@ python -m mlx_lm.generate --help
 To quantize a model from the command line run:

 ```
-python -m mlx_lm.convert --hf-path mistralai/Mistral-7B-v0.1 -q
+python -m mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.1 -q
 ```

 For more options run:
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index b8377f88..445b929a 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -8,6 +8,8 @@ LoRA (QLoRA).[^qlora]
 LoRA fine-tuning works with the following model families:
 - Llama
 - Phi2
 - Mixtral
+- Qwen2
+- OLMo

 ## Contents
diff --git a/llms/mlx_lm/MERGE.md b/llms/mlx_lm/MERGE.md
new file mode 100644
index 00000000..2ee2414c
--- /dev/null
+++ b/llms/mlx_lm/MERGE.md
@@ -0,0 +1,50 @@
+# Model Merging
+
+You can use `mlx-lm` to merge models and upload them to the Hugging
+Face hub or save them locally for LoRA fine-tuning.
+
+The main command is `mlx_lm.merge`:
+
+```shell
+python -m mlx_lm.merge --config config.yaml
+```
+
+The merged model will be saved by default in `mlx_merged_model`. To see a
+full list of options run:
+
+```shell
+python -m mlx_lm.merge --help
+```
+
+Here is an example `config.yaml`:
+
+```yaml
+models:
+  - OpenPipe/mistral-ft-optimized-1218
+  - mlabonne/NeuralHermes-2.5-Mistral-7B
+method: slerp
+parameters:
+  t:
+    - filter: self_attn
+      value: [0, 0.5, 0.3, 0.7, 1]
+    - filter: mlp
+      value: [1, 0.5, 0.7, 0.3, 0]
+    - value: 0.5
+```
+
+The `models` field is a list of Hugging Face repo ids. The first model in the
+list is treated as the base model into which the remaining models are merged.
+
+The `method` field is the merging method. Right now `slerp` is the only
+supported method.
+
+The `parameters` are the corresponding parameters for the given `method`.
+Each parameter is a list with `filter` determining which layer the parameter
+applies to and `value` determining the actual value used. The last item in
+the list without a `filter` field is the default.
+
+If `value` is a list, it specifies the start and end values for the
+corresponding segment of blocks. In the example above, the models have 32
+blocks. For blocks 1-8, the layers with `self_attn` in the name will use the
+values `np.linspace(0, 0.5, 8)`, the same layers in the next 8 blocks (9-16)
+will use `np.linspace(0.5, 0.3, 8)`, and so on.
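As an illustration of the block-wise schedule described in MERGE.md above (not part of the patch itself), the following numpy sketch expands a `value` list into one `t` per block, mirroring the `unpack_values` helper added in `merge.py` further down:

```python
import numpy as np

def unpack_values(vals, num_layers=32):
    # Expand a list of anchor values into one interpolation weight per block,
    # mirroring the helper of the same name in merge.py.
    if isinstance(vals, (int, float)):
        return np.full(num_layers, vals)
    bins = len(vals) - 1
    sizes = [num_layers // bins] * bins
    sizes[-1] = num_layers - sum(sizes[:-1])
    return np.concatenate(
        [np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
    )

# The self_attn schedule from the example config: 0 -> 0.5 over blocks 1-8,
# 0.5 -> 0.3 over blocks 9-16, 0.3 -> 0.7 over 17-24, and 0.7 -> 1 over 25-32.
print(unpack_values([0, 0.5, 0.3, 0.7, 1]))
```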
diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
new file mode 100644
index 00000000..1176951d
--- /dev/null
+++ b/llms/mlx_lm/SERVER.md
@@ -0,0 +1,63 @@
+# HTTP Model Server
+
+You can use `mlx-lm` to make an HTTP API for generating text with any
+supported model. The HTTP API is intended to be similar to the [OpenAI chat
+API](https://platform.openai.com/docs/api-reference).
+
+Start the server with:
+
+```shell
+python -m mlx_lm.server --model <path_to_model_or_hf_repo>
+```
+
+For example:
+
+```shell
+python -m mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
+```
+
+This will start a text generation server on port `8080` of the `localhost`
+using Mistral 7B instruct. The model will be downloaded from the provided
+Hugging Face repo if it is not already in the local cache.
+
+To see a full list of options run:
+
+```shell
+python -m mlx_lm.server --help
+```
+
+You can make a request to the model by running:
+
+```shell
+curl localhost:8080/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+     "messages": [{"role": "user", "content": "Say this is a test!"}],
+     "temperature": 0.7
+   }'
+```
+
+### Request Fields
+
+- `messages`: An array of message objects representing the conversation
+  history. Each message object should have a role (e.g. user, assistant) and
+  content (the message text).
+
+- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
+  the generated prompt. If not provided, the default mappings are used.
+
+- `stop`: (Optional) An array of strings or a single string. These are
+  sequences of tokens on which the generation should stop.
+
+- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
+  to generate. Defaults to `100`.
+
+- `stream`: (Optional) A boolean indicating if the response should be
+  streamed. If true, responses are sent as they are generated. Defaults to
+  false.
+
+- `temperature`: (Optional) A float specifying the sampling temperature.
+  Defaults to `1.0`.
+
+- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
+  Defaults to `1.0`.
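As a companion to the `curl` example in SERVER.md above, the same request can be made from Python. This is an illustrative sketch rather than part of the patch; it assumes the server is already running on the default `localhost:8080` and uses the third-party `requests` package:

```python
import requests

# Same payload as the curl example; "stream" is left at its default (false),
# so the full completion comes back in a single JSON response.
response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say this is a test!"}],
        "temperature": 0.7,
        "max_tokens": 100,
    },
)
print(response.json())
```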
diff --git a/llms/mlx_lm/examples/merge_config.yaml b/llms/mlx_lm/examples/merge_config.yaml
new file mode 100644
index 00000000..98701e55
--- /dev/null
+++ b/llms/mlx_lm/examples/merge_config.yaml
@@ -0,0 +1,11 @@
+models:
+  - OpenPipe/mistral-ft-optimized-1218
+  - mlabonne/NeuralHermes-2.5-Mistral-7B
+method: slerp
+parameters:
+  t:
+    - filter: self_attn
+      value: [0, 0.5, 0.3, 0.7, 1]
+    - filter: mlp
+      value: [1, 0.5, 0.7, 0.3, 0]
+    - value: 0.5
diff --git a/llms/mlx_lm/merge.py b/llms/mlx_lm/merge.py
new file mode 100644
index 00000000..999d081e
--- /dev/null
+++ b/llms/mlx_lm/merge.py
@@ -0,0 +1,158 @@
+import argparse
+import glob
+import json
+import shutil
+from pathlib import Path
+
+import mlx.core as mx
+import numpy as np
+import yaml
+from mlx.utils import tree_flatten, tree_map
+
+from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
+
+
+def configure_parser() -> argparse.ArgumentParser:
+    """
+    Configures and returns the argument parser for the script.
+
+    Returns:
+        argparse.ArgumentParser: Configured argument parser.
+    """
+    parser = argparse.ArgumentParser(description="Merge multiple models.")
+
+    parser.add_argument("--config", type=str, help="Path to the YAML config.")
+    parser.add_argument(
+        "--mlx-path",
+        type=str,
+        default="mlx_merged_model",
+        help="Path to save the MLX model.",
+    )
+    parser.add_argument(
+        "--upload-repo",
+        help="The Hugging Face repo to upload the model to.",
+        type=str,
+        default=None,
+    )
+    return parser
+
+
+def slerp(t, w1, w2, eps=1e-5):
+    """
+    Spherical linear interpolation
+
+    Args:
+        t (float): Interpolation weight in [0.0, 1.0]
+        w1 (mx.array): First input
+        w2 (mx.array): Second input
+        eps (float): Constant for numerical stability
+    Returns:
+        mx.array: Interpolated result
+    """
+    t = float(t)
+    if t == 0:
+        return w1
+    elif t == 1:
+        return w2
+    # Normalize
+    v1 = w1 / mx.linalg.norm(w1)
+    v2 = w2 / mx.linalg.norm(w2)
+    # Angle
+    dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
+    theta = mx.arccos(dot)
+    sin_theta = mx.sin(theta + eps)
+    s1 = mx.sin(theta * (1 - t)) / sin_theta
+    s2 = mx.sin(theta * t) / sin_theta
+    return s1 * w1 + s2 * w2
+
+
+def merge_models(base_model, model, config):
+    method = config.get("method", None)
+    if method != "slerp":
+        raise ValueError(f"Merge method {method} not supported")
+
+    num_layers = len(model.layers)
+
+    def unpack_values(vals):
+        if isinstance(vals, (int, float)):
+            return np.full(num_layers, vals)
+        bins = len(vals) - 1
+        sizes = [num_layers // bins] * bins
+        sizes[-1] = num_layers - sum(sizes[:-1])
+        return np.concatenate(
+            [np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
+        )
+
+    param_list = config["parameters"]["t"]
+    params = {}
+    filter_keys = set()
+    for pl in param_list[:-1]:
+        params[pl["filter"]] = unpack_values(pl["value"])
+        filter_keys.add(pl["filter"])
+    default = unpack_values(param_list[-1]["value"])
+
+    for e in range(num_layers):
+        bl = base_model.layers[e]
+        l = model.layers[e]
+        base_weights = bl.parameters()
+        weights = l.parameters()
+        for k, w1 in base_weights.items():
+            w2 = weights[k]
+            t = params.get(k, default)[e]
+            base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
+        bl.update(base_weights)
+
+
+def merge(
+    config: str,
+    mlx_path: str = "mlx_model",
+    upload_repo: str = None,
+):
+    with open(config, "r") as fid:
+        merge_conf = yaml.safe_load(fid)
+    print("[INFO] Loading")
+
+    model_paths = merge_conf.get("models", [])
+    if len(model_paths) < 2:
+        raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
+
+    # Load all models
+    base_hf_path = model_paths[0]
+    base_path = get_model_path(base_hf_path)
+    base_model, base_config, tokenizer = fetch_from_hub(base_path)
+    models = []
+    for mp in model_paths[1:]:
+        model, config, _ = fetch_from_hub(get_model_path(mp))
+        base_type = base_config["model_type"]
+        model_type = config["model_type"]
+        if base_type != model_type:
+            raise ValueError(
+                f"Can only merge models of the same type,"
+                f" but got {base_type} and {model_type}."
+            )
+        models.append(model)
+
+    # Merge models into base model
+    for m in models:
+        merge_models(base_model, m, merge_conf)
+
+    # Save base model
+    mlx_path = Path(mlx_path)
+    weights = dict(tree_flatten(base_model.parameters()))
+    save_weights(mlx_path, weights)
+    py_files = glob.glob(str(base_path / "*.py"))
+    for file in py_files:
+        shutil.copy(file, mlx_path)
+
+    tokenizer.save_pretrained(mlx_path)
+
+    with open(mlx_path / "config.json", "w") as fid:
+        json.dump(base_config, fid, indent=4)
+
+    if upload_repo is not None:
+        upload_to_hub(mlx_path, upload_repo, base_hf_path)
+
+
+if __name__ == "__main__":
+    parser = configure_parser()
+    args = parser.parse_args()
+    merge(**vars(args))
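For intuition about the `slerp` function above: `t = 0` and `t = 1` return the endpoints unchanged, and intermediate values interpolate along the arc between the two directions. A small sanity check, assuming `mlx-lm` is installed so that `mlx_lm.merge` is importable:

```python
import mlx.core as mx
from mlx_lm.merge import slerp

w1 = mx.array([1.0, 0.0, 0.0])
w2 = mx.array([0.0, 1.0, 0.0])

print(slerp(0.0, w1, w2))  # returns w1 unchanged
print(slerp(1.0, w1, w2))  # returns w2 unchanged
print(slerp(0.5, w1, w2))  # about [0.707, 0.707, 0], halfway along the arc
```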
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index f9f96525..a38db95b 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -205,3 +205,7 @@ class Model(nn.Module):
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
+
+    @property
+    def layers(self):
+        return self.model.layers
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index a2ff0d06..f584d509 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -11,7 +11,6 @@ from .base import BaseModelArgs
 @dataclass
 class ModelArgs(BaseModelArgs):
     model_type: str
-    vocab_size: int
     vocab_size: int = 32000
     max_position_embeddings: int = 4096 * 32
     hidden_size: int = 4096
@@ -260,3 +259,7 @@ class Model(nn.Module):
     ):
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
+
+    @property
+    def layers(self):
+        return self.model.layers
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 0a2c9c0d..f9fe1475 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -178,3 +178,7 @@ class Model(nn.Module):
         cache=None,
     ):
         return self.model(inputs, cache)
+
+    @property
+    def layers(self):
+        return self.model.transformer.blocks
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index ce8c226d..84c5d4f9 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -178,3 +178,7 @@ class Model(nn.Module):

         y, cache = self.model(x, mask, cache)
         return self.lm_head(y), cache
+
+    @property
+    def layers(self):
+        return self.model.layers
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index 14ef5d45..9cd23997 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -216,3 +216,7 @@ class Model(nn.Module):

         y, cache = self.transformer(x, mask, cache)
         return self.lm_head(y), cache
+
+    @property
+    def layers(self):
+        return self.transformer.h
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index f3f868ad..59aa6918 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -205,3 +205,7 @@ class Model(nn.Module):
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
+
+    @property
+    def layers(self):
+        return self.model.layers
diff --git a/llms/mlx_lm/models/stablelm_epoch.py b/llms/mlx_lm/models/stablelm_epoch.py
index 2d492295..2f88bd03 100644
--- a/llms/mlx_lm/models/stablelm_epoch.py
+++ b/llms/mlx_lm/models/stablelm_epoch.py
@@ -184,3 +184,7 @@ class Model(nn.Module):

         y, cache = self.model(x, mask, cache)
         return self.lm_head(y), cache
+
+    @property
+    def layers(self):
+        return self.model.layers
diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index defc3e78..420aa822 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -2,3 +2,4 @@ mlx>=0.1
 numpy
 transformers>=4.37.0
 protobuf
+pyyaml
diff --git a/llms/setup.py b/llms/setup.py
index abdf21cd..c29a08a6 100644
--- a/llms/setup.py
+++ b/llms/setup.py
@@ -8,7 +8,7 @@ with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
     requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
 setup(
     name="mlx-lm",
-    version="0.0.11",
+    version="0.0.12",
     description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
diff --git a/lora/lora.py b/lora/lora.py
index 1d847e97..64e8d671 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -139,12 +139,12 @@ class Dataset:

 def load(args):
     def load_and_check(name):
-        dataset_path = Path(args.data) / f"{name}.jsonl"
-        try:
-            train = Dataset(dataset_path)
-        except Exception as e:
-            print(f"Unable to build dataset {dataset_path} ({e})")
-            raise
+        dataset_path = Path(args.data) / f"{name}.jsonl"
+        try:
+            train = Dataset(dataset_path)
+        except Exception as e:
+            print(f"Unable to build dataset {dataset_path} ({e})")
+            raise

     names = ("train", "valid", "test")
     train, valid, test = (load_and_check(n) for n in names)
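Taken together, a minimal end-to-end flow with the new command might look like the following sketch, based on MERGE.md above; it assumes the example config has been copied locally and uses the default `mlx_merged_model` output directory:

```shell
# Merge the models listed in the example config using slerp
python -m mlx_lm.merge --config merge_config.yaml

# Generate text with the merged model from its local path
python -m mlx_lm.generate --model mlx_merged_model --prompt "hello"
```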