Support for slerp merging models (#455)

* support for slerp merging models

* docs

* update docs

* format
Awni Hannun
2024-02-19 20:37:15 -08:00
committed by GitHub
parent 8c9148a8fd
commit 8fd953ee2b
16 changed files with 329 additions and 15 deletions


@@ -8,6 +8,8 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- Llama
- Phi2
- Mixtral
- Qwen2
- OLMo
## Contents

llms/mlx_lm/MERGE.md Normal file

@@ -0,0 +1,50 @@
# Model Merging
You can use `mlx-lm` to merge models and upload them to the Hugging Face Hub
or save them locally for LoRA fine-tuning.
The main command is `mlx_lm.merge`:
```shell
python -m mlx_lm.merge --config config.yaml
```
The merged model will be saved by default in `mlx_merged_model`. To see a
full list of options run:
```shell
python -m mlx_lm.merge --help
```
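Once the merge finishes, the output directory can be loaded like any other
local model. As a quick sanity check, here is a minimal sketch using the
`mlx_lm` Python API (assuming the default `mlx_merged_model` output path):
```python
from mlx_lm import load, generate

# Load the merged model from the default output directory.
model, tokenizer = load("mlx_merged_model")

# Generate a short completion to verify the merged weights behave sensibly.
print(generate(model, tokenizer, prompt="Hello", max_tokens=50))
```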
Here is an example `config.yaml`:
```yaml
models:
  - OpenPipe/mistral-ft-optimized-1218
  - mlabonne/NeuralHermes-2.5-Mistral-7B
method: slerp
parameters:
  t:
    - filter: self_attn
      value: [0, 0.5, 0.3, 0.7, 1]
    - filter: mlp
      value: [1, 0.5, 0.7, 0.3, 0]
    - value: 0.5
```
The `models` field is a list of Hugging Face repo ids. The first model in the
list is treated as the base model into which the remaining models are merged.
The `method` field is the merging method. Right now `slerp` is the only
supported method.
The `parameters` field gives the corresponding parameters for the given
`method`. Each parameter is a list of entries: `filter` determines which
layers the entry applies to, and `value` determines the actual value used.
The last entry, which has no `filter` field, is the default.
If `value` is a list, it specifies the start and end values for the
corresponding segments of blocks. In the example above, the models have 32
blocks. For blocks 1-8, the layers with `self_attn` in the name use the
values `np.linspace(0, 0.5, 8)`; the same layers in the next 8 blocks (9-16)
use `np.linspace(0.5, 0.3, 8)`, and so on.
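To make the schedule concrete, here is a small NumPy sketch that mirrors this
behavior (the helper name `expand_schedule` is just for illustration),
assuming 32 blocks as in the example:
```python
import numpy as np

def expand_schedule(vals, num_blocks=32):
    # A scalar value applies the same weight to every block.
    if isinstance(vals, (int, float)):
        return np.full(num_blocks, vals)
    # Each consecutive pair of anchor values covers an equal-sized span of blocks.
    bins = len(vals) - 1
    sizes = [num_blocks // bins] * bins
    sizes[-1] = num_blocks - sum(sizes[:-1])
    return np.concatenate(
        [np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
    )

# Per-block weights for layers matching the `self_attn` filter above:
# 0 -> 0.5 over blocks 1-8, 0.5 -> 0.3 over blocks 9-16, and so on.
print(expand_schedule([0, 0.5, 0.3, 0.7, 1]))
```
Each of the 32 resulting values is used as the interpolation weight `t` for
the corresponding block.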

llms/mlx_lm/SERVER.md Normal file

@@ -0,0 +1,63 @@
# HTTP Model Server
You can use `mlx-lm` to make an HTTP API for generating text with any supported
model. The HTTP API is intended to be similar to the [OpenAI chat
API](https://platform.openai.com/docs/api-reference).
Start the server with:
```shell
python -m mlx_lm.server --model <path_to_model_or_hf_repo>
```
For example:
```shell
python -m mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
```
This will start a text generation server on `localhost` port `8080` using
Mistral 7B Instruct. The model will be downloaded from the provided
Hugging Face repo if it is not already in the local cache.
To see a full list of options run:
```shell
python -m mlx_lm.server --help
```
You can make a request to the model by running:
```shell
curl localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
     "messages": [{"role": "user", "content": "Say this is a test!"}],
     "temperature": 0.7
   }'
```
### Request Fields
- `messages`: An array of message objects representing the conversation
history. Each message object should have a role (e.g. user, assistant) and
content (the message text).
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
the generated prompt. If not provided, the default mappings are used.
- `stop`: (Optional) An array of strings or a single string. These are
sequences of tokens on which the generation should stop.
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
to generate. Defaults to `100`.
- `stream`: (Optional) A boolean indicating if the response should be
streamed. If true, responses are sent as they are generated. Defaults to
false.
- `temperature`: (Optional) A float specifying the sampling temperature.
Defaults to `1.0`.
- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
Defaults to `1.0`.
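For reference, the same request as the `curl` example above can be sent from
Python. This is a minimal sketch assuming the `requests` package is installed
and the server is running locally on port `8080`:
```python
import requests

# Same payload as the curl example; any of the optional fields above can be added.
payload = {
    "messages": [{"role": "user", "content": "Say this is a test!"}],
    "temperature": 0.7,
    "max_tokens": 100,
}

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json=payload,
)
print(response.json())
```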


@@ -0,0 +1,11 @@
models:
  - OpenPipe/mistral-ft-optimized-1218
  - mlabonne/NeuralHermes-2.5-Mistral-7B
method: slerp
parameters:
  t:
    - filter: self_attn
      value: [0, 0.5, 0.3, 0.7, 1]
    - filter: mlp
      value: [1, 0.5, 0.7, 0.3, 0]
    - value: 0.5

llms/mlx_lm/merge.py Normal file

@@ -0,0 +1,158 @@
import argparse
import glob
import json
import shutil
from pathlib import Path

import mlx.core as mx
import numpy as np
import yaml
from mlx.utils import tree_flatten, tree_map

from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub


def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(description="Merge multiple models.")
    parser.add_argument("--config", type=str, help="Path to the YAML config.")
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_merged_model",
        help="Path to save the MLX model.",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    return parser


def slerp(t, w1, w2, eps=1e-5):
    """
    Spherical linear interpolation

    Args:
        t (float): Interpolation weight in [0.0, 1.0]
        w1 (mx.array): First input
        w2 (mx.array): Second input
        eps (float): Constant for numerical stability

    Returns:
        mx.array: Interpolated result
    """
    t = float(t)
    if t == 0:
        return w1
    elif t == 1:
        return w2

    # Normalize
    v1 = w1 / mx.linalg.norm(w1)
    v2 = w2 / mx.linalg.norm(w2)

    # Angle
    dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
    theta = mx.arccos(dot)

    sin_theta = mx.sin(theta + eps)
    s1 = mx.sin(theta * (1 - t)) / sin_theta
    s2 = mx.sin(theta * t) / sin_theta
    return s1 * w1 + s2 * w2


def merge_models(base_model, model, config):
    method = config.get("method", None)
    if method != "slerp":
        raise ValueError(f"Merge method {method} not supported")

    num_layers = len(model.layers)

    def unpack_values(vals):
        if isinstance(vals, (int, float)):
            return np.full(num_layers, vals)
        bins = len(vals) - 1
        sizes = [num_layers // bins] * bins
        sizes[-1] = num_layers - sum(sizes[:-1])
        return np.concatenate(
            [np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
        )

    param_list = config["parameters"]["t"]
    params = {}
    filter_keys = set()
    for pl in param_list[:-1]:
        params[pl["filter"]] = unpack_values(pl["value"])
        filter_keys.add(pl["filter"])
    default = unpack_values(param_list[-1]["value"])

    for e in range(num_layers):
        bl = base_model.layers[e]
        l = model.layers[e]
        base_weights = bl.parameters()
        weights = l.parameters()
        for k, w1 in base_weights.items():
            w2 = weights[k]
            t = params.get(k, default)[e]
            base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
        # Write the merged weights back into the corresponding base model block
        bl.update(base_weights)


def merge(
    config: str,
    mlx_path: str = "mlx_model",
    upload_repo: str = None,
):
    with open(config, "r") as fid:
        merge_conf = yaml.safe_load(fid)
    print("[INFO] Loading")

    model_paths = merge_conf.get("models", [])
    if len(model_paths) < 2:
        raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")

    # Load all models
    base_hf_path = model_paths[0]
    base_path = get_model_path(base_hf_path)
    base_model, base_config, tokenizer = fetch_from_hub(base_path)
    models = []
    for mp in model_paths[1:]:
        model, config, _ = fetch_from_hub(get_model_path(mp))
        base_type = base_config["model_type"]
        model_type = config["model_type"]
        if base_type != model_type:
            raise ValueError(
                f"Can only merge models of the same type,"
                f" but got {base_type} and {model_type}."
            )
        models.append(model)

    # Merge models into base model
    for m in models:
        merge_models(base_model, m, merge_conf)

    # Save base model
    mlx_path = Path(mlx_path)
    weights = dict(tree_flatten(base_model.parameters()))
    save_weights(mlx_path, weights)
    py_files = glob.glob(str(base_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)
    tokenizer.save_pretrained(mlx_path)
    with open(mlx_path / "config.json", "w") as fid:
        json.dump(base_config, fid, indent=4)

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, base_hf_path)


if __name__ == "__main__":
    parser = configure_parser()
    args = parser.parse_args()
    merge(**vars(args))


@@ -205,3 +205,7 @@ class Model(nn.Module):
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers


@@ -11,7 +11,6 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int
    vocab_size: int = 32000
    max_position_embeddings: int = 4096 * 32
    hidden_size: int = 4096
@@ -260,3 +259,7 @@ class Model(nn.Module):
    ):
        out, cache = self.model(inputs, cache)
        return self.lm_head(out), cache

    @property
    def layers(self):
        return self.model.layers


@@ -178,3 +178,7 @@ class Model(nn.Module):
        cache=None,
    ):
        return self.model(inputs, cache)

    @property
    def layers(self):
        return self.model.transformer.blocks


@@ -178,3 +178,7 @@ class Model(nn.Module):
        y, cache = self.model(x, mask, cache)
        return self.lm_head(y), cache

    @property
    def layers(self):
        return self.model.layers


@@ -216,3 +216,7 @@ class Model(nn.Module):
        y, cache = self.transformer(x, mask, cache)
        return self.lm_head(y), cache

    @property
    def layers(self):
        return self.transformer.h


@@ -205,3 +205,7 @@ class Model(nn.Module):
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers


@@ -184,3 +184,7 @@ class Model(nn.Module):
        y, cache = self.model(x, mask, cache)
        return self.lm_head(y), cache

    @property
    def layers(self):
        return self.model.layers


@@ -2,3 +2,4 @@ mlx>=0.1
numpy
transformers>=4.37.0
protobuf
pyyaml