import argparse
import glob
import json
import shutil
from pathlib import Path

import mlx.core as mx
import numpy as np
import yaml
from mlx.utils import tree_flatten, tree_map

from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub


def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(description="Merge multiple models.")

    parser.add_argument("--config", type=str, help="Path to the YAML config.")
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_merged_model",
        help="Path to save the MLX model.",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    return parser


def slerp(t, w1, w2, eps=1e-5):
    """
    Spherical linear interpolation.

    Args:
        t (float): Interpolation weight in [0.0, 1.0]
        w1 (mx.array): First input
        w2 (mx.array): Second input
        eps (float): Constant for numerical stability

    Returns:
        mx.array: Interpolated result
    """
    t = float(t)
    if t == 0:
        return w1
    elif t == 1:
        return w2
    # Normalize
    v1 = w1 / mx.linalg.norm(w1)
    v2 = w2 / mx.linalg.norm(w2)
    # Angle between the normalized inputs
    dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
    theta = mx.arccos(dot)
    sin_theta = mx.sin(theta + eps)
    s1 = mx.sin(theta * (1 - t)) / sin_theta
    s2 = mx.sin(theta * t) / sin_theta
    return s1 * w1 + s2 * w2


def merge_models(base_model, model, config):
    """Merge ``model`` into ``base_model`` in place using the given config."""
    method = config.get("method", None)
    if method != "slerp":
        raise ValueError(f"Merge method {method} not supported.")

    num_layers = len(model.layers)

    def unpack_values(vals):
        # A scalar applies the same weight to every layer; a list of anchor
        # values is linearly interpolated across the layers.
        if isinstance(vals, (int, float)):
            return np.full(num_layers, vals)
        bins = len(vals) - 1
        sizes = [num_layers // bins] * bins
        sizes[-1] = num_layers - sum(sizes[:-1])
        return np.concatenate(
            [np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
        )

    # The last entry of the list is the default; the others are keyed by a
    # "filter" that must match a layer-level parameter name (e.g. "self_attn").
    param_list = config["parameters"]["t"]
    params = {}
    filter_keys = set()
    for pl in param_list[:-1]:
        params[pl["filter"]] = unpack_values(pl["value"])
        filter_keys.add(pl["filter"])
    default = unpack_values(param_list[-1]["value"])

    for e in range(num_layers):
        bl = base_model.layers[e]
        l = model.layers[e]
        base_weights = bl.parameters()
        weights = l.parameters()
        for k, w1 in base_weights.items():
            w2 = weights[k]
            t = params.get(k, default)[e]
            base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
        # Update the layer (not the whole model): the keys in base_weights
        # are layer-level parameter names.
        bl.update(base_weights)


def merge(
    config: str,
    mlx_path: str = "mlx_model",
    upload_repo: str = None,
):
    with open(config, "r") as fid:
        merge_conf = yaml.safe_load(fid)
    print("[INFO] Loading")

    model_paths = merge_conf.get("models", [])
    if len(model_paths) < 2:
        raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")

    # Load all models
    base_hf_path = model_paths[0]
    base_path = get_model_path(base_hf_path)
    base_model, base_config, tokenizer = fetch_from_hub(base_path)
    models = []
    for mp in model_paths[1:]:
        model, model_config, _ = fetch_from_hub(get_model_path(mp))
        base_type = base_config["model_type"]
        model_type = model_config["model_type"]
        if base_type != model_type:
            raise ValueError(
                f"Can only merge models of the same type,"
                f" but got {base_type} and {model_type}."
            )
        models.append(model)

    # Merge models into the base model
    for m in models:
        merge_models(base_model, m, merge_conf)

    # Save the merged model
    mlx_path = Path(mlx_path)
    weights = dict(tree_flatten(base_model.parameters()))
    save_weights(mlx_path, weights)
    py_files = glob.glob(str(base_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)
    tokenizer.save_pretrained(mlx_path)
    with open(mlx_path / "config.json", "w") as fid:
        json.dump(base_config, fid, indent=4)

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, base_hf_path)


if __name__ == "__main__":
    parser = configure_parser()
    args = parser.parse_args()
    merge(**vars(args))
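

# Example config, for reference. This is an illustrative sketch, not part of
# the script: the keys mirror exactly what `merge` and `merge_models` read
# above, but the model paths are placeholders, and the module path in the
# command assumes this file lives in a package (as the relative import of
# `.utils` suggests).
#
#   models:
#     - path/to/base-model
#     - path/to/other-model
#   method: slerp
#   parameters:
#     t:
#       - filter: self_attn
#         value: [0.0, 0.5, 0.3, 0.7, 1.0]
#       - filter: mlp
#         value: [1.0, 0.5, 0.7, 0.3, 0.0]
#       - value: 0.5
#
# The last `t` entry (no filter) is the default interpolation weight; list
# values are interpolated across layers by `unpack_values`. Invoked as, e.g.:
#
#   python -m mlx_lm.merge --config config.yaml --mlx-path mlx_merged_model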