2024-02-28 00:47:56 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
|
2024-02-20 12:37:15 +08:00
|
|
|
import argparse
|
|
|
|
import glob
|
2024-02-28 05:27:08 +08:00
|
|
|
import shutil
|
2024-02-20 12:37:15 +08:00
|
|
|
from pathlib import Path
|
Refactoring of mlx_lm example (#501)
* Use named tuple from typing for typehints
* Add type hints
* Simplify expression
* Type hint fix
* Improved do_POST logic
Use a map of endpoints to methods to reduce redundancy in code
* Fix format
* Improve redundancy
Call method dynamically instead of writing out all arguments twice
* Send response instead of returning
* Fix typo
* Revert change
* Make adapter_file as Optional
* Mark formatter as optional
* format
* Create message generator
Store response data that stays static for the duration of the response inside of the object:
system_fingerprint
request_id
object_type
requested_model
Created a message generator, that dynamically creates messages from the metadata stored inside of the object, and the data from the model pipeline
* Remove leftover
* Update parameters to reflect new object structure
No longer pass all arguments between functions, but use the stores values inside of the object
* Parse body before calling request specific methods
* Call super init
* Update server.py
* Fixed outdated documentation parameter name
* Add documentation
* Fix sending headers twice
During testing I found that when using the streaming option, headers have always been sent twice. This should fix that
* Simplify streaming code by using guard clauses
Don't wrap wfile writes in try blocks, the server class has its own try block to prevent crashing
* Bug fix
* Use Content-Length header
Let the completion type specific methods finish sending the headers. This allows us to send the Content-Length header as the model returns a completion.
* Update utils.py
* Add top_p documentation
* Type hint model and tokenizer as required
* Use static system fingerprint
System fingerprint now stays the same across requests
* Make type hint more specific
* Bug Fix
Supplying less than 2 models to merge would raise ValueError and calls len on unbound "models". Should be "model_paths" instead.
Mark upload_repo as optional
* Move more of the shared code into do_POST
Processing stop_id_sequences is done no matter the request endpoint or type, move it into the shared section. handle_ methods now just return the prompt in mx.array form.
* Store stop_id_sequences as lists instead of np
During testing I found that letting the tokenizer return values as python lists and converting them to mlx arrays was around 20% faster than having the tokenizer convert them to np, and from np to mlx. This allows makes it so numpy no longer needs to be imported.
* Update stop_id_sequences docs
* Turn if check to non-inclusive
Only continue if buffer is smaller
* Documentation fix
* Cleared method names
Instead of handle_stream and generate_competion, we should name it handle_completion.
Instead of handle_completions and handle_chat_completions, we should name it handle_text_completions, since both are completions, calling it text completions should make it more descriptive
* Make comment clearer
* fix format
* format
2024-03-06 22:24:31 +08:00
|
|
|
from typing import Optional
|
2024-02-20 12:37:15 +08:00
|
|
|
|
|
|
|
import mlx.core as mx
|
2024-02-28 11:40:42 +08:00
|
|
|
import mlx.nn as nn
|
2024-02-20 12:37:15 +08:00
|
|
|
import numpy as np
|
|
|
|
import yaml
|
|
|
|
from mlx.utils import tree_flatten, tree_map
|
|
|
|
|
2024-03-14 21:36:05 +08:00
|
|
|
from .utils import (
|
|
|
|
fetch_from_hub,
|
|
|
|
get_model_path,
|
|
|
|
save_config,
|
|
|
|
save_weights,
|
|
|
|
upload_to_hub,
|
|
|
|
)
|
2024-02-20 12:37:15 +08:00
|
|
|
|
|
|
|
|
|
|
|
def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    The parsed namespace is forwarded by keyword to ``merge``, so the option
    names map to its parameters (``--mlx-path`` -> ``mlx_path``, etc.).

    ``--config`` is required: ``merge`` opens it unconditionally, so a missing
    value would otherwise only surface later as an error from ``open(None)``.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(description="Merge multiple models.")

    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the YAML config.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_merged_model",
        help="Path to save the MLX model.",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    return parser
|
|
|
|
|
|
|
|
|
|
|
|
def slerp(t, w1, w2, eps=1e-5):
    """
    Spherical linear interpolation between two arrays.

    Args:
        t (float): Interpolation weight in [0.0, 1.0]; anything accepted by
            ``float()`` is allowed.
        w1 (mx.array): First input, returned unchanged when ``t == 0``.
        w2 (mx.array): Second input, returned unchanged when ``t == 1``.
        eps (float): Constant added to the angle before taking the sine that
            is used as the divisor, for numerical stability.

    Returns:
        mx.array: Interpolated result
    """
    weight = float(t)

    # The endpoints short-circuit: no arithmetic is performed on the inputs.
    if weight == 0:
        return w1
    if weight == 1:
        return w2

    # Unit-length copies are only used to measure the angle; the final blend
    # below is applied to the un-normalized inputs.
    unit1 = w1 / mx.linalg.norm(w1)
    unit2 = w2 / mx.linalg.norm(w2)

    # Cosine of the angle, clipped to [0, 1] before arccos.
    # NOTE(review): a negative cosine is clipped to 0 here rather than -1;
    # confirm that is the intended range.
    cosine = mx.clip((unit1 * unit2).sum(), 0.0, 1.0)
    angle = mx.arccos(cosine)
    divisor = mx.sin(angle + eps)

    coef1 = mx.sin(angle * (1 - weight)) / divisor
    coef2 = mx.sin(angle * weight) / divisor
    return coef1 * w1 + coef2 * w2
|
|
|
|
|
|
|
|
|
2024-02-28 11:40:42 +08:00
|
|
|
def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
    """
    Merge the layers of ``model`` into ``base_model`` in place using slerp.

    Args:
        base_model (nn.Module): Model whose ``layers`` receive the merged
            parameters.
        model (nn.Module): Model to merge in; ``len(model.layers)`` decides
            how many layers are processed.
        config (dict): Merge configuration. ``config["method"]`` must be
            ``"slerp"``. ``config["parameters"]["t"]`` is a list of entries:
            every entry has a ``"value"`` and all but the last also have a
            ``"filter"`` naming a top-level key of a layer's parameters. The
            last entry supplies the default for keys without a filter.
            A ``"value"`` is either a number (same weight for every layer) or
            a list of numbers that is linearly interpolated across the layers.

    Raises:
        ValueError: If the merge method is not ``"slerp"`` or a ``"value"``
            list is empty.
    """
    method = config.get("method", None)
    if method != "slerp":
        raise ValueError(f"Merge method {method} not supported")

    num_layers = len(model.layers)

    def unpack_values(vals):
        """Expand a scalar or a list of anchors into one weight per layer."""
        if isinstance(vals, (int, float)):
            return np.full(num_layers, vals)
        if len(vals) == 0:
            raise ValueError("Interpolation values must not be empty.")
        if len(vals) == 1:
            # A single anchor leaves nothing to interpolate between (and
            # would otherwise divide by zero bins): treat it as a constant.
            return np.full(num_layers, vals[0])
        # Each pair of neighbouring anchors covers one bin of layers; the
        # last bin absorbs whatever remains after the integer division.
        bins = len(vals) - 1
        sizes = [num_layers // bins] * bins
        sizes[-1] = num_layers - sum(sizes[:-1])
        return np.concatenate(
            [np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
        )

    param_list = config["parameters"]["t"]
    params = {pl["filter"]: unpack_values(pl["value"]) for pl in param_list[:-1]}
    default = unpack_values(param_list[-1]["value"])

    for e in range(num_layers):
        base_layer = base_model.layers[e]
        other_layer = model.layers[e]
        base_weights = base_layer.parameters()
        weights = other_layer.parameters()
        for k, w1 in base_weights.items():
            w2 = weights[k]
            t = params.get(k, default)[e]
            base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
        # `base_weights` is keyed relative to this layer (it came from
        # `base_layer.parameters()`), so it has to be applied to the layer
        # itself rather than to the top-level model.
        base_layer.update(base_weights)
|
|
|
|
|
|
|
|
|
|
|
|
def merge(
    config: str,
    mlx_path: str = "mlx_model",
    upload_repo: Optional[str] = None,
):
    """
    Merge the models listed in a YAML config and save the merged model.

    The first entry of the config's ``models`` list is the base model; every
    following model is merged into it in order with ``merge_models``. The
    merged weights, the base model's ``*.py`` files, its tokenizer and its
    config are then written to ``mlx_path``.

    Args:
        config (str): Path to the YAML merge configuration.
        mlx_path (str): Directory to save the merged model to.
        upload_repo (Optional[str]): If given, the saved model is passed to
            ``upload_to_hub`` with this repo name.

    Raises:
        ValueError: If fewer than two models are listed, or a model's
            ``model_type`` differs from the base model's.
    """
    with open(config, "r") as fid:
        # An empty YAML document loads as None; use an empty mapping so the
        # "at least 2 models" check below reports the problem.
        merge_conf = yaml.safe_load(fid) or {}
    print("[INFO] Loading")

    model_paths = merge_conf.get("models", [])
    if len(model_paths) < 2:
        raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")

    # Load all models
    base_hf_path = model_paths[0]
    base_path = get_model_path(base_hf_path)
    base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
    base_type = base_config["model_type"]
    models = []
    for mp in model_paths[1:]:
        model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
        model_type = model_config["model_type"]
        if base_type != model_type:
            raise ValueError(
                f"Can only merge models of the same type,"
                f" but got {base_type} and {model_type}."
            )
        models.append(model)

    # Merge models into base model
    for m in models:
        merge_models(base_model, m, merge_conf)

    # Save base model
    mlx_path = Path(mlx_path)
    weights = dict(tree_flatten(base_model.parameters()))
    # Drop the model references so only `weights` holds the arrays while
    # they are being saved.
    del models, base_model
    save_weights(mlx_path, weights, donate_weights=True)

    # Copy any Python files shipped alongside the base model.
    py_files = glob.glob(str(base_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    # Persist the base model's config, not `config`, which is only the path
    # to the merge YAML.
    save_config(base_config, config_path=mlx_path / "config.json")

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, base_hf_path)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the command line and forward every option to merge() by keyword;
    # argparse turns "--mlx-path" into the "mlx_path" attribute, and so on.
    cli_args = configure_parser().parse_args()
    merge(**vars(cli_args))
|