mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Support for slerp merging models (#455)
* support for slerp merging models
* docs
* update docs
* format
This commit is contained in:
parent 8c9148a8fd
commit 8fd953ee2b
@ -14,9 +14,11 @@ pip install mlx-lm
 conda install -c conda-forge mlx-lm
 ```
 
-The `mlx-lm` package also supports LoRA and QLoRA fine-tuning. For more details
-on this see the [LoRA
-documentation](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md).
+The `mlx-lm` package also has:
+
+- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
+- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
+- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
 
 ### Python API
 
@ -25,7 +27,7 @@ You can use `mlx-lm` as a module:
 ```python
 from mlx_lm import load, generate
 
-model, tokenizer = load("mistralai/Mistral-7B-v0.1")
+model, tokenizer = load("mistralai/Mistral-7B-Instruct-v0.1")
 
 response = generate(model, tokenizer, prompt="hello", verbose=True)
 ```
@ -44,7 +46,7 @@ You can convert models in the Python API with:
 ```python
 from mlx_lm import convert
 
-upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit"
+upload_repo = "mistralai/Mistral-7B-Instruct-v0.1"
 
 convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo)
 ```
@ -64,7 +66,7 @@ To see a description of all the arguments you can do:
 You can also use `mlx-lm` from the command line with:
 
 ```
-python -m mlx_lm.generate --model mistralai/Mistral-7B-v0.1 --prompt "hello"
+python -m mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.1 --prompt "hello"
 ```
 
 This will download a Mistral 7B model from the Hugging Face Hub and generate
@ -79,7 +81,7 @@ python -m mlx_lm.generate --help
|
|||||||
To quantize a model from the command line run:
|
To quantize a model from the command line run:
|
||||||
|
|
||||||
```
|
```
|
||||||
python -m mlx_lm.convert --hf-path mistralai/Mistral-7B-v0.1 -q
|
python -m mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.1 -q
|
||||||
```
|
```
|
||||||
|
|
||||||
For more options run:
|
For more options run:
|
||||||
|
@ -8,6 +8,8 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
 - Llama
 - Phi2
 - Mixtral
+- Qwen2
+- OLMo
 
 ## Contents
 
llms/mlx_lm/MERGE.md (new file, 50 lines)
@ -0,0 +1,50 @@
# Model Merging

You can use `mlx-lm` to merge models and upload them to the Hugging
Face Hub or save them locally for LoRA fine-tuning.

The main command is `mlx_lm.merge`:

```shell
python -m mlx_lm.merge --config config.yaml
```

The merged model will be saved by default in `mlx_merged_model`. To see a
full list of options run:

```shell
python -m mlx_lm.merge --help
```

Here is an example `config.yaml`:

```yaml
models:
  - OpenPipe/mistral-ft-optimized-1218
  - mlabonne/NeuralHermes-2.5-Mistral-7B
method: slerp
parameters:
  t:
    - filter: self_attn
      value: [0, 0.5, 0.3, 0.7, 1]
    - filter: mlp
      value: [1, 0.5, 0.7, 0.3, 0]
    - value: 0.5
```

The `models` field is a list of Hugging Face repo ids. The first model in the
list is treated as the base model into which the remaining models are merged.

The `method` field is the merging method. Right now `slerp` is the only
supported method.

The `parameters` field holds the parameters for the given `method`. Each
parameter is a list of entries, where `filter` determines which layers an
entry applies to and `value` gives the actual value used. The last entry in
the list, which has no `filter` field, is the default.

If `value` is a list, it specifies the start and end values for the
corresponding segments of blocks. In the example above, the models have 32
blocks. For blocks 1-8, the layers with `self_attn` in the name will use the
values `np.linspace(0, 0.5, 8)`, the same layers in the next 8 blocks (9-16)
will use `np.linspace(0.5, 0.3, 8)`, and so on.
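As an aside (not part of the diff), here is a minimal NumPy sketch of how such a `value` list expands into one interpolation weight per block; the helper name `expand_t` is hypothetical, and the logic mirrors the `unpack_values` helper in `merge.py` below.

```python
# Minimal sketch: expand a `value` list into one t per block for a 32-block model.
import numpy as np

def expand_t(vals, num_layers=32):
    if isinstance(vals, (int, float)):
        return np.full(num_layers, vals)      # a scalar applies to every block
    bins = len(vals) - 1
    sizes = [num_layers // bins] * bins
    sizes[-1] = num_layers - sum(sizes[:-1])  # last segment absorbs any remainder
    return np.concatenate(
        [np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
    )

print(expand_t([0, 0.5, 0.3, 0.7, 1])[:8])    # blocks 1-8:  np.linspace(0, 0.5, 8)
print(expand_t([0, 0.5, 0.3, 0.7, 1])[8:16])  # blocks 9-16: np.linspace(0.5, 0.3, 8)
```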
llms/mlx_lm/SERVER.md (new file, 63 lines)
@ -0,0 +1,63 @@
# HTTP Model Server

You can use `mlx-lm` to make an HTTP API for generating text with any supported
model. The HTTP API is intended to be similar to the [OpenAI chat
API](https://platform.openai.com/docs/api-reference).

Start the server with:

```shell
python -m mlx_lm.server --model <path_to_model_or_hf_repo>
```

For example:

```shell
python -m mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
```

This will start a text generation server on port `8080` of the `localhost`
using Mistral 7B Instruct. The model will be downloaded from the provided
Hugging Face repo if it is not already in the local cache.

To see a full list of options run:

```shell
python -m mlx_lm.server --help
```

You can make a request to the model by running:

```shell
curl localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [{"role": "user", "content": "Say this is a test!"}],
    "temperature": 0.7
  }'
```

### Request Fields

- `messages`: An array of message objects representing the conversation
  history. Each message object should have a role (e.g. user, assistant) and
  content (the message text).

- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
  the generated prompt. If not provided, the default mappings are used.

- `stop`: (Optional) An array of strings or a single string. These are
  sequences of tokens on which the generation should stop.

- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
  to generate. Defaults to `100`.

- `stream`: (Optional) A boolean indicating if the response should be
  streamed. If true, responses are sent as they are generated. Defaults to
  false.

- `temperature`: (Optional) A float specifying the sampling temperature.
  Defaults to `1.0`.

- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
  Defaults to `1.0`.
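As a non-authoritative illustration (not part of the diff), a small Python client using these fields, assuming the third-party `requests` package and a server already running on the default `localhost:8080`, might look like:

```python
# Sketch of a chat completion request against the server described above.
# Assumes `pip install requests` and `python -m mlx_lm.server --model ...` is running.
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say this is a test!"}],
        "temperature": 0.7,
        "max_tokens": 100,
    },
)
print(resp.json())
```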
llms/mlx_lm/examples/merge_config.yaml (new file, 11 lines)
@ -0,0 +1,11 @@
models:
  - OpenPipe/mistral-ft-optimized-1218
  - mlabonne/NeuralHermes-2.5-Mistral-7B
method: slerp
parameters:
  t:
    - filter: self_attn
      value: [0, 0.5, 0.3, 0.7, 1]
    - filter: mlp
      value: [1, 0.5, 0.7, 0.3, 0]
    - value: 0.5
llms/mlx_lm/merge.py (new file, 158 lines)
@ -0,0 +1,158 @@
import argparse
import glob
import json
import shutil
from pathlib import Path

import mlx.core as mx
import numpy as np
import yaml
from mlx.utils import tree_flatten, tree_map

from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub


def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(description="Merge multiple models.")

    parser.add_argument("--config", type=str, help="Path to the YAML config.")
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_merged_model",
        help="Path to save the MLX model.",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    return parser


def slerp(t, w1, w2, eps=1e-5):
    """
    Spherical linear interpolation

    Args:
        t (float): Interpolation weight in [0.0, 1.0]
        w1 (mx.array): First input
        w2 (mx.array): Second input
        eps (float): Constant for numerical stability

    Returns:
        mx.array: Interpolated result
    """
    t = float(t)
    if t == 0:
        return w1
    elif t == 1:
        return w2
    # Normalize
    v1 = w1 / mx.linalg.norm(w1)
    v2 = w2 / mx.linalg.norm(w2)
    # Angle
    dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
    theta = mx.arccos(dot)
    sin_theta = mx.sin(theta + eps)
    s1 = mx.sin(theta * (1 - t)) / sin_theta
    s2 = mx.sin(theta * t) / sin_theta
    return s1 * w1 + s2 * w2


def merge_models(base_model, model, config):
    method = config.get("method", None)
    if method != "slerp":
        raise ValueError(f"Merge method {method} not supported")

    num_layers = len(model.layers)

    def unpack_values(vals):
        # A scalar applies to every layer; a list is expanded into per-layer
        # values by linearly interpolating between consecutive entries.
        if isinstance(vals, (int, float)):
            return np.full(num_layers, vals)
        bins = len(vals) - 1
        sizes = [num_layers // bins] * bins
        sizes[-1] = num_layers - sum(sizes[:-1])
        return np.concatenate(
            [np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
        )

    param_list = config["parameters"]["t"]
    params = {}
    filter_keys = set()
    for pl in param_list[:-1]:
        params[pl["filter"]] = unpack_values(pl["value"])
        filter_keys.add(pl["filter"])
    default = unpack_values(param_list[-1]["value"])

    for e in range(num_layers):
        bl = base_model.layers[e]
        l = model.layers[e]
        base_weights = bl.parameters()
        weights = l.parameters()
        for k, w1 in base_weights.items():
            w2 = weights[k]
            t = params.get(k, default)[e]
            base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
        base_model.update(base_weights)


def merge(
    config: str,
    mlx_path: str = "mlx_model",
    upload_repo: str = None,
):
    with open(config, "r") as fid:
        merge_conf = yaml.safe_load(fid)
    print("[INFO] Loading")

    model_paths = merge_conf.get("models", [])
    if len(model_paths) < 2:
        raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")

    # Load all models
    base_hf_path = model_paths[0]
    base_path = get_model_path(base_hf_path)
    base_model, base_config, tokenizer = fetch_from_hub(base_path)
    models = []
    for mp in model_paths[1:]:
        model, config, _ = fetch_from_hub(get_model_path(mp))
        base_type = base_config["model_type"]
        model_type = config["model_type"]
        if base_type != model_type:
            raise ValueError(
                f"Can only merge models of the same type,"
                f" but got {base_type} and {model_type}."
            )
        models.append(model)

    # Merge models into base model
    for m in models:
        merge_models(base_model, m, merge_conf)

    # Save base model
    mlx_path = Path(mlx_path)
    weights = dict(tree_flatten(base_model.parameters()))
    save_weights(mlx_path, weights)
    py_files = glob.glob(str(base_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    with open(mlx_path / "config.json", "w") as fid:
        json.dump(base_config, fid, indent=4)

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, base_hf_path)


if __name__ == "__main__":
    parser = configure_parser()
    args = parser.parse_args()
    merge(**vars(args))
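As a quick, hypothetical sanity check of the `slerp` function above (not part of the diff), assuming `mlx` is installed and the new module is importable as `mlx_lm.merge`:

```python
# Toy check of the interpolation endpoints and midpoint on 2-D vectors.
import mlx.core as mx
from mlx_lm.merge import slerp  # assumed import path for the new module

w1 = mx.array([1.0, 0.0])
w2 = mx.array([0.0, 1.0])

print(slerp(0.0, w1, w2))  # returns w1 unchanged
print(slerp(1.0, w1, w2))  # returns w2 unchanged
print(slerp(0.5, w1, w2))  # ~[0.707, 0.707], halfway along the arc between the two
```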
@ -205,3 +205,7 @@ class Model(nn.Module):
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
+
+    @property
+    def layers(self):
+        return self.model.layers
@ -11,7 +11,6 @@ from .base import BaseModelArgs
 @dataclass
 class ModelArgs(BaseModelArgs):
     model_type: str
-    vocab_size: int
     vocab_size: int = 32000
     max_position_embeddings: int = 4096 * 32
     hidden_size: int = 4096
@ -260,3 +259,7 @@ class Model(nn.Module):
     ):
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
+
+    @property
+    def layers(self):
+        return self.model.layers
@ -178,3 +178,7 @@ class Model(nn.Module):
         cache=None,
     ):
         return self.model(inputs, cache)
+
+    @property
+    def layers(self):
+        return self.model.transformer.blocks
@ -178,3 +178,7 @@ class Model(nn.Module):
 
         y, cache = self.model(x, mask, cache)
         return self.lm_head(y), cache
+
+    @property
+    def layers(self):
+        return self.model.layers
@ -216,3 +216,7 @@ class Model(nn.Module):
 
         y, cache = self.transformer(x, mask, cache)
         return self.lm_head(y), cache
+
+    @property
+    def layers(self):
+        return self.transformer.h
@ -205,3 +205,7 @@ class Model(nn.Module):
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
+
+    @property
+    def layers(self):
+        return self.model.layers
@ -184,3 +184,7 @@ class Model(nn.Module):
 
         y, cache = self.model(x, mask, cache)
         return self.lm_head(y), cache
+
+    @property
+    def layers(self):
+        return self.model.layers
@ -2,3 +2,4 @@ mlx>=0.1
 numpy
 transformers>=4.37.0
 protobuf
+pyyaml
@ -8,7 +8,7 @@ with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
 requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
 setup(
     name="mlx-lm",
-    version="0.0.11",
+    version="0.0.12",
     description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
Loading…
Reference in New Issue
Block a user