From 362e88a74465b0910aeb3edbb0033d181d5c1bdc Mon Sep 17 00:00:00 2001
From: Anchen
Date: Tue, 23 Jan 2024 08:44:37 -0800
Subject: [PATCH] feat: move lora into mlx-lm (#337)
* feat: Add lora and qlora training to mlx-lm
---------
Co-authored-by: Awni Hannun
---
llms/README.md | 4 +
llms/mlx_lm/LORA.md | 165 ++++++++++++++++++++++
llms/mlx_lm/README.md | 7 +-
llms/mlx_lm/convert.py | 119 ++++------------
llms/mlx_lm/fuse.py | 91 ++++++++++++
llms/mlx_lm/lora.py | 252 ++++++++++++++++++++++++++++++++++
llms/mlx_lm/models/mixtral.py | 33 +++--
llms/mlx_lm/tuner/__init__.py | 0
llms/mlx_lm/tuner/lora.py | 88 ++++++++++++
llms/mlx_lm/tuner/trainer.py | 204 +++++++++++++++++++++++++++
llms/mlx_lm/tuner/utils.py | 22 +++
llms/mlx_lm/utils.py | 109 ++++++++++++++-
llms/phixtral/phixtral.py | 4 +-
13 files changed, 987 insertions(+), 111 deletions(-)
create mode 100644 llms/mlx_lm/LORA.md
create mode 100644 llms/mlx_lm/fuse.py
create mode 100644 llms/mlx_lm/lora.py
create mode 100644 llms/mlx_lm/tuner/__init__.py
create mode 100644 llms/mlx_lm/tuner/lora.py
create mode 100644 llms/mlx_lm/tuner/trainer.py
create mode 100644 llms/mlx_lm/tuner/utils.py
diff --git a/llms/README.md b/llms/README.md
index 3787fc94..b2af7fb7 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -14,6 +14,10 @@ pip install mlx-lm
conda install -c conda-forge mlx-lm
```
+The `mlx-lm` package also supports LoRA and QLoRA fine-tuning. For more details
+on this see the [LoRA
+documentation](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md).
+
### Python API
You can use `mlx-lm` as a module:
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
new file mode 100644
index 00000000..b8377f88
--- /dev/null
+++ b/llms/mlx_lm/LORA.md
@@ -0,0 +1,165 @@
+# Fine-Tuning with LoRA or QLoRA
+
+You can use use the `mlx-lm` package to fine-tune an LLM with low rank
+adaptation (LoRA) for a target task.[^lora] The example also supports quantized
+LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
+
+- Mistral
+- Llama
+- Phi2
+- Mixtral
+
+## Contents
+
+* [Run](#Run)
+ * [Fine-tune](#Fine-tune)
+ * [Evaluate](#Evaluate)
+ * [Generate](#Generate)
+* [Fuse and Upload](#Fuse-and-Upload)
+* [Data](#Data)
+* [Memory Issues](#Memory-Issues)
+
+## Run
+
+The main command is `mlx_lm.lora`. To see a full list of options run:
+
+```shell
+python -m mlx_lm.lora --help
+```
+
+Note, in the following the `--model` argument can be any compatible Hugging
+Face repo or a local path to a converted model.
+
+### Fine-tune
+
+To fine-tune a model use:
+
+```shell
+python -m mlx_lm.lora \
+ --model \
+ --train \
+ --data \
+ --iters 600
+```
+
+The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
+when using `--train` and a path to a `test.jsonl` when using `--test`. For more
+details on the data format see the section on [Data](#Data).
+
+For example, to fine-tune a Mistral 7B you can use `--model
+mistralai/Mistral-7B-v0.1`.
+
+If `--model` points to a quantized model, then the training will use QLoRA,
+otherwise it will use regular LoRA.
+
+By default, the adapter weights are saved in `adapters.npz`. You can specify
+the output location with `--adapter-file`.
+
+You can resume fine-tuning with an existing adapter with
+`--resume-adapter-file `.
+
+### Evaluate
+
+To compute test set perplexity use:
+
+```shell
+python -m mlx_lm.lora \
+ --model \
+ --adapter-file \
+ --data \
+ --test
+```
+
+## Fuse and Upload
+
+You can generate a model fused with the low-rank adapters using the
+`mlx_lm.fuse` command. This command also allows you to upload the fused model
+to the Hugging Face Hub.
+
+To see supported options run:
+
+```shell
+python -m mlx_lm.fuse --help
+```
+
+To generate the fused model run:
+
+```shell
+python -m mlx_lm.fuse --model
+```
+
+This will by default load the adapters from `adapters.npz`, and save the fused
+model in the path `lora_fused_model/`. All of these are configurable.
+
+To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
+to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
+useful for the sake of attribution and model versioning.
+
+For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
+
+```shell
+python -m mlx_lm.fuse \
+ --model mistralai/Mistral-7B-v0.1 \
+ --upload-repo mlx-community/my-4bit-lora-mistral \
+ --hf-path mistralai/Mistral-7B-v0.1
+```
+
+## Data
+
+The LoRA command expects you to provide a dataset with `--data`. The MLX
+Examples GitHub repo has an [example of the WikiSQL
+data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
+correct format.
+
+For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
+`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
+loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
+file should look like:
+
+```
+{"text": "This is an example for the model."}
+```
+
+Note, other keys will be ignored by the loader.
+
+## Memory Issues
+
+Fine-tuning a large model with LoRA requires a machine with a decent amount
+of memory. Here are some tips to reduce memory use should you need to do so:
+
+1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
+ with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
+ more details.
+
+2. Try using a smaller batch size with `--batch-size`. The default is `4` so
+ setting this to `2` or `1` will reduce memory consumption. This may slow
+ things down a little, but will also reduce the memory use.
+
+3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
+ is `16`, so you can try `8` or `4`. This reduces the amount of memory
+ needed for back propagation. It may also reduce the quality of the
+ fine-tuned model if you are fine-tuning with a lot of data.
+
+4. Longer examples require more memory. If it makes sense for your data, one thing
+ you can do is break your examples into smaller
+ sequences when making the `{train, valid, test}.jsonl` files.
+
+For example, for a machine with 32 GB the following should run reasonably fast:
+
+```
+python lora.py \
+ --model mistralai/Mistral-7B-v0.1 \
+ --train \
+ --batch-size 1 \
+ --lora-layers 4 \
+ --data wikisql
+```
+
+The above command on an M1 Max with 32 GB runs at about 250
+tokens-per-second, using the MLX Example
+[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
+data set.
+
+
+[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
+[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
diff --git a/llms/mlx_lm/README.md b/llms/mlx_lm/README.md
index f7fd1dad..66f2b5e9 100644
--- a/llms/mlx_lm/README.md
+++ b/llms/mlx_lm/README.md
@@ -3,5 +3,8 @@
This an example of large language model text generation that can pull models from
the Hugging Face Hub.
-For more information on this example, see the
-[README](../README.md) in the parent directory.
+For more information on this example, see the [README](../README.md) in the
+parent directory.
+
+This package also supports fine tuning with LoRA or QLoRA. For more information
+see the [LoRA documentation](LORA.md).
diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py
index ead4f0e4..3f590f1c 100644
--- a/llms/mlx_lm/convert.py
+++ b/llms/mlx_lm/convert.py
@@ -2,17 +2,21 @@ import argparse
import copy
import glob
import json
+import shutil
from pathlib import Path
-from typing import Dict, Tuple
+from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
-import transformers
from mlx.utils import tree_flatten
-from .utils import get_model_path, linear_class_predicate, load_model
-
-MAX_FILE_SIZE_GB = 15
+from .utils import (
+ fetch_from_hub,
+ get_model_path,
+ linear_class_predicate,
+ save_weights,
+ upload_to_hub,
+)
def configure_parser() -> argparse.ArgumentParser:
@@ -55,22 +59,9 @@ def configure_parser() -> argparse.ArgumentParser:
return parser
-def fetch_from_hub(
- model_path: str,
-) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
- model_path = get_model_path(model_path)
-
- model = load_model(model_path)
-
- config = transformers.AutoConfig.from_pretrained(model_path)
- tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
-
- return model, config.to_dict(), tokenizer
-
-
def quantize_model(
model: nn.Module, config: dict, q_group_size: int, q_bits: int
-) -> tuple:
+) -> Tuple:
"""
Applies quantization to the model weights.
@@ -81,7 +72,7 @@ def quantize_model(
q_bits (int): Bits per weight for quantization.
Returns:
- tuple: Tuple containing quantized weights and config.
+ Tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)
@@ -94,76 +85,6 @@ def quantize_model(
return quantized_weights, quantized_config
-def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
- """
- Splits the weights into smaller shards.
-
- Args:
- weights (dict): Model weights.
- max_file_size_gb (int): Maximum size of each shard in gigabytes.
-
- Returns:
- list: List of weight shards.
- """
- max_file_size_bytes = max_file_size_gb << 30
- shards = []
- shard, shard_size = {}, 0
- for k, v in weights.items():
- estimated_size = v.size * v.dtype.size
- if shard_size + estimated_size > max_file_size_bytes:
- shards.append(shard)
- shard, shard_size = {}, 0
- shard[k] = v
- shard_size += estimated_size
- shards.append(shard)
- return shards
-
-
-def upload_to_hub(path: str, upload_repo: str, hf_path: str):
- """
- Uploads the model to Hugging Face hub.
-
- Args:
- path (str): Local path to the model.
- upload_repo (str): Name of the HF repo to upload to.
- hf_path (str): Path to the original Hugging Face model.
- """
- import os
-
- from huggingface_hub import HfApi, ModelCard, logging
-
- card = ModelCard.load(hf_path)
- card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
- card.text = f"""
-# {upload_repo}
-This model was converted to MLX format from [`{hf_path}`]().
-Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
-## Use with mlx
-
-```bash
-pip install mlx-lm
-```
-
-```python
-from mlx_lm import load, generate
-
-model, tokenizer = load("{upload_repo}")
-response = generate(model, tokenizer, prompt="hello", verbose=True)
-```
-"""
- card.save(os.path.join(path, "README.md"))
-
- logging.set_verbosity_info()
-
- api = HfApi()
- api.create_repo(repo_id=upload_repo, exist_ok=True)
- api.upload_folder(
- folder_path=path,
- repo_id=upload_repo,
- repo_type="model",
- )
-
-
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
@@ -174,7 +95,8 @@ def convert(
upload_repo: str = None,
):
print("[INFO] Loading")
- model, config, tokenizer = fetch_from_hub(hf_path)
+ model_path = get_model_path(hf_path)
+ model, config, tokenizer = fetch_from_hub(model_path)
weights = dict(tree_flatten(model.parameters()))
dtype = mx.float16 if quantize else getattr(mx, dtype)
@@ -185,12 +107,17 @@ def convert(
model.load_weights(list(weights.items()))
weights, config = quantize_model(model, config, q_group_size, q_bits)
- mlx_path = Path(mlx_path)
- mlx_path.mkdir(parents=True, exist_ok=True)
- shards = make_shards(weights)
- for i, shard in enumerate(shards):
- mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
+ if isinstance(mlx_path, str):
+ mlx_path = Path(mlx_path)
+
+ save_weights(mlx_path, weights)
+
+ py_files = glob.glob(str(model_path / "*.py"))
+ for file in py_files:
+ shutil.copy(file, mlx_path)
+
tokenizer.save_pretrained(mlx_path)
+
with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
new file mode 100644
index 00000000..60a13f94
--- /dev/null
+++ b/llms/mlx_lm/fuse.py
@@ -0,0 +1,91 @@
+import argparse
+import glob
+import json
+import shutil
+from pathlib import Path
+from typing import Any, Dict, Union
+
+from mlx.utils import tree_flatten, tree_unflatten
+
+from .tuner.lora import LoRALinear
+from .tuner.utils import apply_lora_layers
+from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
+
+
+def parse_arguments() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
+ parser.add_argument(
+ "--model",
+ default="mlx_model",
+ help="The path to the local model directory or Hugging Face repo.",
+ )
+ parser.add_argument(
+ "--save-path",
+ default="lora_fused_model",
+ help="The path to save the fused model.",
+ )
+ parser.add_argument(
+ "--adapter-file",
+ type=str,
+ default="adapters.npz",
+ help="Path to the trained adapter weights (npz or safetensors).",
+ )
+ parser.add_argument(
+ "--hf-path",
+ type=str,
+ default=None,
+ help="Path to the original Hugging Face model. Required for upload if --model is a local directory.",
+ )
+ parser.add_argument(
+ "--upload-repo",
+ help="The Hugging Face repo to upload the model to.",
+ type=str,
+ default=None,
+ )
+ return parser.parse_args()
+
+
+def main() -> None:
+ print("Loading pretrained model")
+ args = parse_arguments()
+
+ model_path = get_model_path(args.model)
+ model, config, tokenizer = fetch_from_hub(model_path)
+
+ model.freeze()
+ model = apply_lora_layers(model, args.adapter_file)
+ fused_linears = [
+ (n, m.to_linear())
+ for n, m in model.named_modules()
+ if isinstance(m, LoRALinear)
+ ]
+
+ model.update_modules(tree_unflatten(fused_linears))
+ weights = dict(tree_flatten(model.parameters()))
+
+ save_path = Path(args.save_path)
+
+ save_weights(save_path, weights)
+
+ py_files = glob.glob(str(model_path / "*.py"))
+ for file in py_files:
+ shutil.copy(file, save_path)
+
+ tokenizer.save_pretrained(save_path)
+
+ with open(save_path / "config.json", "w") as fid:
+ json.dump(config, fid, indent=4)
+
+ if args.upload_repo is not None:
+ hf_path = args.hf_path or (
+ args.model if not Path(args.model).exists() else None
+ )
+ if hf_path is None:
+ raise ValueError(
+ "Must provide original Hugging Face repo to upload local model."
+ )
+ upload_to_hub(args.save_path, args.upload_repo, hf_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
new file mode 100644
index 00000000..2bcb8099
--- /dev/null
+++ b/llms/mlx_lm/lora.py
@@ -0,0 +1,252 @@
+import argparse
+import json
+import math
+from pathlib import Path
+
+import mlx.optimizers as optim
+import numpy as np
+from mlx.utils import tree_flatten
+
+from .models import llama, mixtral, phi2
+from .tuner.lora import LoRALinear
+from .tuner.trainer import TrainingArgs, evaluate, train
+from .utils import generate, load
+
+SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]
+
+
+def build_parser():
+ parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
+ parser.add_argument(
+ "--model",
+ default="mlx_model",
+ help="The path to the local model directory or Hugging Face repo.",
+ )
+ # Generation args
+ parser.add_argument(
+ "--max-tokens",
+ "-m",
+ type=int,
+ default=100,
+ help="The maximum number of tokens to generate",
+ )
+ parser.add_argument(
+ "--temp", type=float, default=0.8, help="The sampling temperature"
+ )
+ parser.add_argument(
+ "--prompt",
+ "-p",
+ type=str,
+ help="The prompt for generation",
+ default=None,
+ )
+
+ # Training args
+ parser.add_argument(
+ "--train",
+ action="store_true",
+ help="Do training",
+ )
+ parser.add_argument(
+ "--data",
+ type=str,
+ default="data/",
+ help="Directory with {train, valid, test}.jsonl files",
+ )
+ parser.add_argument(
+ "--lora-layers",
+ type=int,
+ default=16,
+ help="Number of layers to fine-tune",
+ )
+ parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
+ parser.add_argument(
+ "--iters", type=int, default=1000, help="Iterations to train for."
+ )
+ parser.add_argument(
+ "--val-batches",
+ type=int,
+ default=25,
+ help="Number of validation batches, -1 uses the entire validation set.",
+ )
+ parser.add_argument(
+ "--learning-rate", type=float, default=1e-5, help="Adam learning rate."
+ )
+ parser.add_argument(
+ "--steps-per-report",
+ type=int,
+ default=10,
+ help="Number of training steps between loss reporting.",
+ )
+ parser.add_argument(
+ "--steps-per-eval",
+ type=int,
+ default=200,
+ help="Number of training steps between validations.",
+ )
+ parser.add_argument(
+ "--resume-adapter-file",
+ type=str,
+ default=None,
+ help="Load path to resume training with the given adapter weights.",
+ )
+ parser.add_argument(
+ "--adapter-file",
+ type=str,
+ default="adapters.npz",
+ help="Save/load path for the trained adapter weights.",
+ )
+ parser.add_argument(
+ "--save-every",
+ type=int,
+ default=100,
+ help="Save the model every N iterations.",
+ )
+ parser.add_argument(
+ "--test",
+ action="store_true",
+ help="Evaluate on the test set after training",
+ )
+ parser.add_argument(
+ "--test-batches",
+ type=int,
+ default=500,
+ help="Number of test set batches, -1 uses the entire test set.",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+ return parser
+
+
+class Dataset:
+ """
+ Light-weight wrapper to hold lines from a jsonl file
+ """
+
+ def __init__(self, path: Path, key: str = "text"):
+ if not path.exists():
+ self._data = None
+ else:
+ with open(path, "r") as fid:
+ self._data = [json.loads(l) for l in fid]
+ self._key = key
+
+ def __getitem__(self, idx: int):
+ return self._data[idx][self._key]
+
+ def __len__(self):
+ if self._data is None:
+ return 0
+ return len(self._data)
+
+
+def load_dataset(args):
+ names = ("train", "valid", "test")
+ train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
+ if args.train and len(train) == 0:
+ raise ValueError(
+ "Training set not found or empty. Must provide training set for fine-tuning."
+ )
+ if args.train and len(valid) == 0:
+ raise ValueError(
+ "Validation set not found or empty. Must provide validation set for fine-tuning."
+ )
+ if args.test and len(test) == 0:
+ raise ValueError(
+ "Test set not found or empty. Must provide test set for evaluation."
+ )
+ return train, valid, test
+
+
+if __name__ == "__main__":
+ parser = build_parser()
+ args = parser.parse_args()
+
+ np.random.seed(args.seed)
+
+ print("Loading pretrained model")
+ model, tokenizer = load(args.model)
+
+ if model.__class__ not in SUPPORTED_MODELS:
+ raise ValueError(
+ f"Model {model.__class__} not supported. "
+ f"Supported models: { SUPPORTED_MODELS}"
+ )
+
+ # Freeze all layers other than LORA linears
+ model.freeze()
+ for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
+ l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
+ l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
+ if hasattr(l, "block_sparse_moe"):
+ l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
+
+ p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+ print(f"Total parameters {p:.3f}M")
+ p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+ print(f"Trainable parameters {p:.3f}M")
+
+ print("Loading datasets")
+ train_set, valid_set, test_set = load_dataset(args)
+
+ # Resume training the given adapters.
+ if args.resume_adapter_file is not None:
+ print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+ model.load_weights(args.resume_adapter_file, strict=False)
+ # init training args
+ trainingArgs = TrainingArgs(
+ batch_size=args.batch_size,
+ iters=args.iters,
+ val_batches=args.val_batches,
+ steps_per_report=args.steps_per_report,
+ steps_per_eval=args.steps_per_eval,
+ steps_per_save=args.save_every,
+ adapter_file=args.adapter_file,
+ )
+ if args.train:
+ print("Training")
+ model.train()
+ opt = optim.Adam(learning_rate=args.learning_rate)
+ # Train model
+ train(
+ model=model,
+ tokenizer=tokenizer,
+ args=trainingArgs,
+ optimizer=opt,
+ train_dataset=train_set,
+ val_dataset=valid_set,
+ )
+
+ # Load the LoRA adapter weights which we assume should exist by this point
+ if not Path(args.adapter_file).is_file():
+ raise ValueError(
+ f"Adapter file {args.adapter_file} missing. "
+ "Use --train to learn and save the adapters.npz."
+ )
+ model.load_weights(args.adapter_file, strict=False)
+
+ if args.test:
+ print("Testing")
+ model.eval()
+
+ test_loss = evaluate(
+ model=model,
+ dataset=test_set,
+ tokenizer=tokenizer,
+ batch_size=args.batch_size,
+ num_batches=args.test_batches,
+ )
+
+ test_ppl = math.exp(test_loss)
+
+ print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+
+ if args.prompt is not None:
+ print("Generating")
+ generate(
+ model=model,
+ tokenizer=tokenizer,
+ temp=args.temp,
+ max_tokens=args.max_tokens,
+ prompt=args.prompt,
+ verbose=True,
+ )
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index d35900cf..63401de1 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -3,6 +3,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
+import numpy as np
from .base import BaseModelArgs
@@ -158,18 +159,34 @@ class MixtralSparseMoeBlock(nn.Module):
x = x.reshape(-1, x.shape[-1])
gates = self.gate(x)
- inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
+
+ inds = mx.stop_gradient(
+ mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
+ ) # TODO remove it once we figure out how to fine tune TopK in MOE
+
scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1,
).astype(gates.dtype)
- y = []
- for xt, st, it in zip(x, scores, inds.tolist()):
- yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
- yt = (yt * st).sum(axis=-1)
- y.append(yt[None, :])
- y = mx.concatenate(y)
+ if self.training:
+ mx.eval(inds)
+ inds = np.array(inds)
+ y = mx.zeros((x.shape[0], ne, x.shape[-1]))
+ for e, expert in enumerate(self.experts):
+ idx1, idx2 = map(mx.array, np.where(inds == e))
+ if idx1.size == 0:
+ continue
+ y[idx1, idx2] = expert(x[idx1])
+
+ y = (y * scores[:, :, None]).sum(axis=1)
+ else:
+ y = []
+ for xt, st, it in zip(x, scores, inds.tolist()):
+ yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
+ yt = (yt * st).sum(axis=-1)
+ y.append(yt[None, :])
+ y = mx.concatenate(y)
return y.reshape(orig_shape)
@@ -229,7 +246,7 @@ class MixtralModel(nn.Module):
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
- return self.norm(h[:, T - 1 : T, :]), cache
+ return self.norm(h), cache
class Model(nn.Module):
diff --git a/llms/mlx_lm/tuner/__init__.py b/llms/mlx_lm/tuner/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py
new file mode 100644
index 00000000..f0ec601b
--- /dev/null
+++ b/llms/mlx_lm/tuner/lora.py
@@ -0,0 +1,88 @@
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class LoRALinear(nn.Module):
+ @staticmethod
+ def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0):
+ # TODO remove when input_dims and output_dims are attributes
+ # on linear and quantized linear
+ output_dims, input_dims = linear.weight.shape
+ if isinstance(linear, nn.QuantizedLinear):
+ input_dims *= 32 // linear.bits
+ lora_lin = LoRALinear(
+ input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale
+ )
+ lora_lin.linear = linear
+ return lora_lin
+
+ def to_linear(self):
+ linear = self.linear
+ bias = "bias" in linear
+ weight = linear.weight
+ is_quantized = isinstance(linear, nn.QuantizedLinear)
+
+ # Use the same type as the linear weight if not quantized
+ dtype = weight.dtype
+
+ if is_quantized:
+ dtype = mx.float16
+ weight = mx.dequantize(
+ weight,
+ linear.scales,
+ linear.biases,
+ linear.group_size,
+ linear.bits,
+ )
+ output_dims, input_dims = weight.shape
+ fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+ lora_b = (self.scale * self.lora_b.T).astype(dtype)
+ lora_a = self.lora_a.T.astype(dtype)
+ fused_linear.weight = weight + lora_b @ lora_a
+ if bias:
+ fused_linear.bias = linear.bias
+
+ if is_quantized:
+ fused_linear = nn.QuantizedLinear.from_linear(
+ fused_linear,
+ linear.group_size,
+ linear.bits,
+ )
+
+ return fused_linear
+
+ def __init__(
+ self,
+ input_dims: int,
+ output_dims: int,
+ rank: int = 8,
+ bias: bool = False,
+ scale: float = 20.0,
+ ):
+ super().__init__()
+
+ # Regular linear layer weights
+ self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+ # Scale for low-rank update
+ self.scale = scale
+
+ # Low rank lora weights
+ scale = 1 / math.sqrt(input_dims)
+ self.lora_a = mx.random.uniform(
+ low=-scale,
+ high=scale,
+ shape=(input_dims, rank),
+ )
+ self.lora_b = mx.zeros(shape=(rank, output_dims))
+
+ def __call__(self, x):
+ dtype = self.linear.weight.dtype
+ if isinstance(self.linear, nn.QuantizedLinear):
+ dtype = self.linear.scales.dtype
+ y = self.linear(x.astype(dtype))
+ z = (x @ self.lora_a) @ self.lora_b
+ return y + self.scale * z
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
new file mode 100644
index 00000000..3c255703
--- /dev/null
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -0,0 +1,204 @@
+import os
+import time
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from mlx.utils import tree_flatten
+
+
+@dataclass
+class TrainingArgs:
+ lora_layers: int = field(
+ default=16, metadata={"help": "Number of layers to fine-tune"}
+ )
+ batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
+ iters: int = field(default=100, metadata={"help": "Iterations to train for."})
+ val_batches: int = field(
+ default=25,
+ metadata={
+ "help": "Number of validation batches, -1 uses the entire validation set."
+ },
+ )
+ steps_per_report: int = field(
+ default=10,
+ metadata={"help": "Number of training steps between loss reporting."},
+ )
+ steps_per_eval: int = field(
+ default=200, metadata={"help": "Number of training steps between validations."}
+ )
+ steps_per_save: int = field(
+ default=100, metadata={"help": "Save the model every number steps"}
+ )
+ max_seq_length: int = field(
+ default=2048, metadata={"help": "Maximum sequence length."}
+ )
+ adapter_file: str = field(
+ default="adapter.npz",
+ metadata={"help": "Save/load path for the trained adapter weights."},
+ )
+
+
+def default_loss(model, inputs, targets, lengths):
+ logits, _ = model(inputs)
+ logits = logits.astype(mx.float32)
+
+ length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+
+ ce = nn.losses.cross_entropy(logits, targets) * length_mask
+ ntoks = length_mask.sum()
+ ce = ce.sum() / ntoks
+
+ return ce, ntoks
+
+
+def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+ while True:
+ # Shuffle indices
+ indices = np.arange(len(dataset))
+ indices = np.random.permutation(indices)
+ # Collect batches from dataset
+ for i in range(0, len(indices) - batch_size + 1, batch_size):
+ # Encode batch
+ batch = [
+ tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size)
+ ]
+ lengths = [len(x) for x in batch]
+
+ # Check if any sequence is longer than max_seq_length
+ if max(lengths) > max_seq_length:
+ print(
+ "[WARNING] Some sequences are longer than 2048 tokens. "
+ "Consider pre-splitting your data to save memory."
+ )
+
+ # Pad to the max length
+ batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
+
+ for j in range(batch_size):
+ batch_arr[j, : lengths[j]] = batch[j]
+ batch = mx.array(batch_arr)
+ yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+
+ if not train:
+ break
+
+
+def evaluate(
+ model,
+ dataset,
+ tokenizer,
+ batch_size,
+ num_batches,
+ max_seq_length=2048,
+ loss: callable = default_loss,
+):
+ all_losses = []
+ ntokens = 0
+ for it, batch in zip(
+ range(num_batches),
+ iterate_batches(
+ dataset=dataset,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_length=max_seq_length,
+ ),
+ ):
+ losses, toks = loss(model, *batch)
+ all_losses.append((losses * toks).item())
+ ntokens += toks.item()
+
+ return np.sum(all_losses) / ntokens
+
+
+def train(
+ model,
+ tokenizer,
+ optimizer,
+ train_dataset,
+ val_dataset,
+ args: TrainingArgs = TrainingArgs(),
+ loss: callable = default_loss,
+):
+ # Create value and grad function for loss
+ loss_value_and_grad = nn.value_and_grad(model, loss)
+
+ losses = []
+ n_tokens = 0
+ print("Starting training..., iters:", args.iters)
+ # Main training loop
+ start = time.perf_counter()
+ for it, batch in zip(
+ range(args.iters),
+ iterate_batches(
+ dataset=train_dataset,
+ tokenizer=tokenizer,
+ batch_size=args.batch_size,
+ max_seq_length=args.max_seq_length,
+ train=True,
+ ),
+ ):
+ # Forward and backward pass
+ (lvalue, toks), grad = loss_value_and_grad(model, *batch)
+
+ # Model update
+ optimizer.update(model, grad)
+
+ mx.eval(model.parameters(), optimizer.state, lvalue)
+
+ # Record loss
+ losses.append(lvalue.item())
+ n_tokens += toks.item()
+
+ # Report training loss if needed
+ if (it + 1) % args.steps_per_report == 0:
+ train_loss = np.mean(losses)
+
+ stop = time.perf_counter()
+ print(
+ f"Iter {it + 1}: Train loss {train_loss:.3f}, "
+ f"It/sec {args.steps_per_report / (stop - start):.3f}, "
+ f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
+ )
+ losses = []
+ n_tokens = 0
+ start = time.perf_counter()
+
+ # Report validation loss if needed
+ if it == 0 or (it + 1) % args.steps_per_eval == 0:
+ stop = time.perf_counter()
+ val_loss = evaluate(
+ model=model,
+ dataset=val_dataset,
+ loss=loss,
+ tokenizer=tokenizer,
+ batch_size=args.batch_size,
+ num_batches=args.val_batches,
+ )
+ print(
+ f"Iter {it + 1}: "
+ f"Val loss {val_loss:.3f}, "
+ f"Val took {(time.perf_counter() - stop):.3f}s"
+ )
+
+ start = time.perf_counter()
+
+ # Save adapter weights if needed
+ if (it + 1) % args.steps_per_save == 0:
+ save_adapter(model=model, adapter_file=args.adapter_file)
+ print(
+ f"Iter {it + 1}: Saved adapter weights to {os.path.join(args.adapter_file)}."
+ )
+ # save final adapter weights
+ save_adapter(model=model, adapter_file=args.adapter_file)
+ print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
+
+
+def save_adapter(
+ model: nn.Module,
+ adapter_file: str,
+):
+ flattened_tree = tree_flatten(model.trainable_parameters())
+
+ mx.savez(adapter_file, **dict(flattened_tree))
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
new file mode 100644
index 00000000..993f216a
--- /dev/null
+++ b/llms/mlx_lm/tuner/utils.py
@@ -0,0 +1,22 @@
+import mlx.core as mx
+from mlx.utils import tree_unflatten
+
+from .lora import LoRALinear
+
+
+def apply_lora_layers(model, adapter_file: str):
+ adapters = list(mx.load(adapter_file).items())
+ linear_replacements = {}
+ lora_layers = set(
+ [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
+ )
+
+ for name, module in model.named_modules():
+ if name in lora_layers:
+ replacement_module = LoRALinear.from_linear(module)
+ linear_replacements[name] = replacement_module
+
+ model.update_modules(tree_unflatten(list(linear_replacements.items())))
+
+ model.update(tree_unflatten(adapters))
+ return model
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b9a81e16..bc35f9f6 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -1,13 +1,14 @@
+import copy
import glob
import json
import logging
from pathlib import Path
-from typing import Generator, Tuple
+from typing import Any, Dict, Generator, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
-from transformers import AutoTokenizer, PreTrainedTokenizer
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
# Local imports
from .models import llama, mixtral, phi2, plamo, qwen
@@ -21,6 +22,7 @@ MODEL_MAPPING = {
"qwen": qwen,
"plamo": plamo,
}
+MAX_FILE_SIZE_GB = 15
linear_class_predicate = (
lambda m: isinstance(m, nn.Linear)
@@ -204,6 +206,7 @@ def load_model(model_path: Path) -> nn.Module:
mx.eval(model.parameters())
+ model.eval()
return model
@@ -211,7 +214,7 @@ def load(
path_or_hf_repo: str, tokenizer_config={}
) -> Tuple[nn.Module, PreTrainedTokenizer]:
"""
- Load the model from a given path or a huggingface repository.
+ Load the model and tokenizer from a given path or a huggingface repository.
Args:
model_path (Path): The path or the huggingface repository to load the model from.
@@ -229,3 +232,103 @@ def load(
model = load_model(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
return model, tokenizer
+
+
+def fetch_from_hub(
+ model_path: Path,
+) -> Tuple[Dict, dict, PreTrainedTokenizer]:
+ model = load_model(model_path)
+
+ config = AutoConfig.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+ return model, config.to_dict(), tokenizer
+
+
+def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
+ """
+ Splits the weights into smaller shards.
+
+ Args:
+ weights (dict): Model weights.
+ max_file_size_gb (int): Maximum size of each shard in gigabytes.
+
+ Returns:
+ list: List of weight shards.
+ """
+ max_file_size_bytes = max_file_size_gb << 30
+ shards = []
+ shard, shard_size = {}, 0
+ for k, v in weights.items():
+ estimated_size = v.size * v.dtype.size
+ if shard_size + estimated_size > max_file_size_bytes:
+ shards.append(shard)
+ shard, shard_size = {}, 0
+ shard[k] = v
+ shard_size += estimated_size
+ shards.append(shard)
+ return shards
+
+
+def upload_to_hub(path: str, upload_repo: str, hf_path: str):
+ """
+ Uploads the model to Hugging Face hub.
+
+ Args:
+ path (str): Local path to the model.
+ upload_repo (str): Name of the HF repo to upload to.
+ hf_path (str): Path to the original Hugging Face model.
+ """
+ import os
+
+ from huggingface_hub import HfApi, ModelCard, logging
+
+ card = ModelCard.load(hf_path)
+ card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
+ card.text = f"""
+# {upload_repo}
+This model was converted to MLX format from [`{hf_path}`]().
+Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
+## Use with mlx
+
+```bash
+pip install mlx-lm
+```
+
+```python
+from mlx_lm import load, generate
+
+model, tokenizer = load("{upload_repo}")
+response = generate(model, tokenizer, prompt="hello", verbose=True)
+```
+"""
+ card.save(os.path.join(path, "README.md"))
+
+ logging.set_verbosity_info()
+
+ api = HfApi()
+ api.create_repo(repo_id=upload_repo, exist_ok=True)
+ api.upload_folder(
+ folder_path=path,
+ repo_id=upload_repo,
+ repo_type="model",
+ )
+
+
+def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
+ """Save model weights into specified directory."""
+ if isinstance(save_path, str):
+ save_path = Path(save_path)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ shards = make_shards(weights)
+ shards_count = len(shards)
+ shard_file_format = (
+ "model-{:05d}-of-{:05d}.safetensors"
+ if shards_count > 1
+ else "model.safetensors"
+ )
+
+ for i, shard in enumerate(shards):
+ shard_name = shard_file_format.format(i + 1, shards_count)
+ mx.save_safetensors(str(save_path / shard_name), shard)
diff --git a/llms/phixtral/phixtral.py b/llms/phixtral/phixtral.py
index 6774bc1f..0aa37ebf 100644
--- a/llms/phixtral/phixtral.py
+++ b/llms/phixtral/phixtral.py
@@ -4,7 +4,7 @@ import json
import math
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -198,7 +198,7 @@ class Model(nn.Module):
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
- ) -> tuple[mx.array, mx.array]:
+ ) -> Tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])