mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
feat: move lora into mlx-lm (#337)
* feat: Add lora and qlora training to mlx-lm --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
85c1ff8fd6
commit
362e88a744
@ -14,6 +14,10 @@ pip install mlx-lm
|
||||
conda install -c conda-forge mlx-lm
|
||||
```
|
||||
|
||||
The `mlx-lm` package also supports LoRA and QLoRA fine-tuning. For more details
|
||||
on this see the [LoRA
|
||||
documentation](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md).
|
||||
|
||||
### Python API
|
||||
|
||||
You can use `mlx-lm` as a module:
|
||||
|
165
llms/mlx_lm/LORA.md
Normal file
165
llms/mlx_lm/LORA.md
Normal file
@ -0,0 +1,165 @@
|
||||
# Fine-Tuning with LoRA or QLoRA
|
||||
|
||||
You can use use the `mlx-lm` package to fine-tune an LLM with low rank
|
||||
adaptation (LoRA) for a target task.[^lora] The example also supports quantized
|
||||
LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
|
||||
|
||||
- Mistral
|
||||
- Llama
|
||||
- Phi2
|
||||
- Mixtral
|
||||
|
||||
## Contents
|
||||
|
||||
* [Run](#Run)
|
||||
* [Fine-tune](#Fine-tune)
|
||||
* [Evaluate](#Evaluate)
|
||||
* [Generate](#Generate)
|
||||
* [Fuse and Upload](#Fuse-and-Upload)
|
||||
* [Data](#Data)
|
||||
* [Memory Issues](#Memory-Issues)
|
||||
|
||||
## Run
|
||||
|
||||
The main command is `mlx_lm.lora`. To see a full list of options run:
|
||||
|
||||
```shell
|
||||
python -m mlx_lm.lora --help
|
||||
```
|
||||
|
||||
Note, in the following the `--model` argument can be any compatible Hugging
|
||||
Face repo or a local path to a converted model.
|
||||
|
||||
### Fine-tune
|
||||
|
||||
To fine-tune a model use:
|
||||
|
||||
```shell
|
||||
python -m mlx_lm.lora \
|
||||
--model <path_to_model> \
|
||||
--train \
|
||||
--data <path_to_data> \
|
||||
--iters 600
|
||||
```
|
||||
|
||||
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
|
||||
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
|
||||
details on the data format see the section on [Data](#Data).
|
||||
|
||||
For example, to fine-tune a Mistral 7B you can use `--model
|
||||
mistralai/Mistral-7B-v0.1`.
|
||||
|
||||
If `--model` points to a quantized model, then the training will use QLoRA,
|
||||
otherwise it will use regular LoRA.
|
||||
|
||||
By default, the adapter weights are saved in `adapters.npz`. You can specify
|
||||
the output location with `--adapter-file`.
|
||||
|
||||
You can resume fine-tuning with an existing adapter with
|
||||
`--resume-adapter-file <path_to_adapters.npz>`.
|
||||
|
||||
### Evaluate
|
||||
|
||||
To compute test set perplexity use:
|
||||
|
||||
```shell
|
||||
python -m mlx_lm.lora \
|
||||
--model <path_to_model> \
|
||||
--adapter-file <path_to_adapters.npz> \
|
||||
--data <path_to_data> \
|
||||
--test
|
||||
```
|
||||
|
||||
## Fuse and Upload
|
||||
|
||||
You can generate a model fused with the low-rank adapters using the
|
||||
`mlx_lm.fuse` command. This command also allows you to upload the fused model
|
||||
to the Hugging Face Hub.
|
||||
|
||||
To see supported options run:
|
||||
|
||||
```shell
|
||||
python -m mlx_lm.fuse --help
|
||||
```
|
||||
|
||||
To generate the fused model run:
|
||||
|
||||
```shell
|
||||
python -m mlx_lm.fuse --model <path_to_model>
|
||||
```
|
||||
|
||||
This will by default load the adapters from `adapters.npz`, and save the fused
|
||||
model in the path `lora_fused_model/`. All of these are configurable.
|
||||
|
||||
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
|
||||
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
|
||||
useful for the sake of attribution and model versioning.
|
||||
|
||||
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
|
||||
|
||||
```shell
|
||||
python -m mlx_lm.fuse \
|
||||
--model mistralai/Mistral-7B-v0.1 \
|
||||
--upload-repo mlx-community/my-4bit-lora-mistral \
|
||||
--hf-path mistralai/Mistral-7B-v0.1
|
||||
```
|
||||
|
||||
## Data
|
||||
|
||||
The LoRA command expects you to provide a dataset with `--data`. The MLX
|
||||
Examples GitHub repo has an [example of the WikiSQL
|
||||
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
|
||||
correct format.
|
||||
|
||||
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
|
||||
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
||||
loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
|
||||
file should look like:
|
||||
|
||||
```
|
||||
{"text": "This is an example for the model."}
|
||||
```
|
||||
|
||||
Note, other keys will be ignored by the loader.
|
||||
|
||||
## Memory Issues
|
||||
|
||||
Fine-tuning a large model with LoRA requires a machine with a decent amount
|
||||
of memory. Here are some tips to reduce memory use should you need to do so:
|
||||
|
||||
1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
|
||||
with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
|
||||
more details.
|
||||
|
||||
2. Try using a smaller batch size with `--batch-size`. The default is `4` so
|
||||
setting this to `2` or `1` will reduce memory consumption. This may slow
|
||||
things down a little, but will also reduce the memory use.
|
||||
|
||||
3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
|
||||
is `16`, so you can try `8` or `4`. This reduces the amount of memory
|
||||
needed for back propagation. It may also reduce the quality of the
|
||||
fine-tuned model if you are fine-tuning with a lot of data.
|
||||
|
||||
4. Longer examples require more memory. If it makes sense for your data, one thing
|
||||
you can do is break your examples into smaller
|
||||
sequences when making the `{train, valid, test}.jsonl` files.
|
||||
|
||||
For example, for a machine with 32 GB the following should run reasonably fast:
|
||||
|
||||
```
|
||||
python lora.py \
|
||||
--model mistralai/Mistral-7B-v0.1 \
|
||||
--train \
|
||||
--batch-size 1 \
|
||||
--lora-layers 4 \
|
||||
--data wikisql
|
||||
```
|
||||
|
||||
The above command on an M1 Max with 32 GB runs at about 250
|
||||
tokens-per-second, using the MLX Example
|
||||
[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
|
||||
data set.
|
||||
|
||||
|
||||
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
|
||||
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
|
@ -3,5 +3,8 @@
|
||||
This an example of large language model text generation that can pull models from
|
||||
the Hugging Face Hub.
|
||||
|
||||
For more information on this example, see the
|
||||
[README](../README.md) in the parent directory.
|
||||
For more information on this example, see the [README](../README.md) in the
|
||||
parent directory.
|
||||
|
||||
This package also supports fine tuning with LoRA or QLoRA. For more information
|
||||
see the [LoRA documentation](LORA.md).
|
||||
|
@ -2,17 +2,21 @@ import argparse
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import transformers
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
from .utils import get_model_path, linear_class_predicate, load_model
|
||||
|
||||
MAX_FILE_SIZE_GB = 15
|
||||
from .utils import (
|
||||
fetch_from_hub,
|
||||
get_model_path,
|
||||
linear_class_predicate,
|
||||
save_weights,
|
||||
upload_to_hub,
|
||||
)
|
||||
|
||||
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
@ -55,22 +59,9 @@ def configure_parser() -> argparse.ArgumentParser:
|
||||
return parser
|
||||
|
||||
|
||||
def fetch_from_hub(
|
||||
model_path: str,
|
||||
) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
|
||||
model_path = get_model_path(model_path)
|
||||
|
||||
model = load_model(model_path)
|
||||
|
||||
config = transformers.AutoConfig.from_pretrained(model_path)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
return model, config.to_dict(), tokenizer
|
||||
|
||||
|
||||
def quantize_model(
|
||||
model: nn.Module, config: dict, q_group_size: int, q_bits: int
|
||||
) -> tuple:
|
||||
) -> Tuple:
|
||||
"""
|
||||
Applies quantization to the model weights.
|
||||
|
||||
@ -81,7 +72,7 @@ def quantize_model(
|
||||
q_bits (int): Bits per weight for quantization.
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing quantized weights and config.
|
||||
Tuple: Tuple containing quantized weights and config.
|
||||
"""
|
||||
quantized_config = copy.deepcopy(config)
|
||||
|
||||
@ -94,76 +85,6 @@ def quantize_model(
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
|
||||
"""
|
||||
Splits the weights into smaller shards.
|
||||
|
||||
Args:
|
||||
weights (dict): Model weights.
|
||||
max_file_size_gb (int): Maximum size of each shard in gigabytes.
|
||||
|
||||
Returns:
|
||||
list: List of weight shards.
|
||||
"""
|
||||
max_file_size_bytes = max_file_size_gb << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
estimated_size = v.size * v.dtype.size
|
||||
if shard_size + estimated_size > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += estimated_size
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
||||
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
"""
|
||||
Uploads the model to Hugging Face hub.
|
||||
|
||||
Args:
|
||||
path (str): Local path to the model.
|
||||
upload_repo (str): Name of the HF repo to upload to.
|
||||
hf_path (str): Path to the original Hugging Face model.
|
||||
"""
|
||||
import os
|
||||
|
||||
from huggingface_hub import HfApi, ModelCard, logging
|
||||
|
||||
card = ModelCard.load(hf_path)
|
||||
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
|
||||
card.text = f"""
|
||||
# {upload_repo}
|
||||
This model was converted to MLX format from [`{hf_path}`]().
|
||||
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
|
||||
## Use with mlx
|
||||
|
||||
```bash
|
||||
pip install mlx-lm
|
||||
```
|
||||
|
||||
```python
|
||||
from mlx_lm import load, generate
|
||||
|
||||
model, tokenizer = load("{upload_repo}")
|
||||
response = generate(model, tokenizer, prompt="hello", verbose=True)
|
||||
```
|
||||
"""
|
||||
card.save(os.path.join(path, "README.md"))
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=upload_repo, exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path=path,
|
||||
repo_id=upload_repo,
|
||||
repo_type="model",
|
||||
)
|
||||
|
||||
|
||||
def convert(
|
||||
hf_path: str,
|
||||
mlx_path: str = "mlx_model",
|
||||
@ -174,7 +95,8 @@ def convert(
|
||||
upload_repo: str = None,
|
||||
):
|
||||
print("[INFO] Loading")
|
||||
model, config, tokenizer = fetch_from_hub(hf_path)
|
||||
model_path = get_model_path(hf_path)
|
||||
model, config, tokenizer = fetch_from_hub(model_path)
|
||||
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
dtype = mx.float16 if quantize else getattr(mx, dtype)
|
||||
@ -185,12 +107,17 @@ def convert(
|
||||
model.load_weights(list(weights.items()))
|
||||
weights, config = quantize_model(model, config, q_group_size, q_bits)
|
||||
|
||||
if isinstance(mlx_path, str):
|
||||
mlx_path = Path(mlx_path)
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
shards = make_shards(weights)
|
||||
for i, shard in enumerate(shards):
|
||||
mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
|
||||
|
||||
save_weights(mlx_path, weights)
|
||||
|
||||
py_files = glob.glob(str(model_path / "*.py"))
|
||||
for file in py_files:
|
||||
shutil.copy(file, mlx_path)
|
||||
|
||||
tokenizer.save_pretrained(mlx_path)
|
||||
|
||||
with open(mlx_path / "config.json", "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
||||
|
91
llms/mlx_lm/fuse.py
Normal file
91
llms/mlx_lm/fuse.py
Normal file
@ -0,0 +1,91 @@
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
from .tuner.lora import LoRALinear
|
||||
from .tuner.utils import apply_lora_layers
|
||||
from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
|
||||
|
||||
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
default="lora_fused_model",
|
||||
help="The path to save the fused model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-file",
|
||||
type=str,
|
||||
default="adapters.npz",
|
||||
help="Path to the trained adapter weights (npz or safetensors).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the original Hugging Face model. Required for upload if --model is a local directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
print("Loading pretrained model")
|
||||
args = parse_arguments()
|
||||
|
||||
model_path = get_model_path(args.model)
|
||||
model, config, tokenizer = fetch_from_hub(model_path)
|
||||
|
||||
model.freeze()
|
||||
model = apply_lora_layers(model, args.adapter_file)
|
||||
fused_linears = [
|
||||
(n, m.to_linear())
|
||||
for n, m in model.named_modules()
|
||||
if isinstance(m, LoRALinear)
|
||||
]
|
||||
|
||||
model.update_modules(tree_unflatten(fused_linears))
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
save_path = Path(args.save_path)
|
||||
|
||||
save_weights(save_path, weights)
|
||||
|
||||
py_files = glob.glob(str(model_path / "*.py"))
|
||||
for file in py_files:
|
||||
shutil.copy(file, save_path)
|
||||
|
||||
tokenizer.save_pretrained(save_path)
|
||||
|
||||
with open(save_path / "config.json", "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
||||
if args.upload_repo is not None:
|
||||
hf_path = args.hf_path or (
|
||||
args.model if not Path(args.model).exists() else None
|
||||
)
|
||||
if hf_path is None:
|
||||
raise ValueError(
|
||||
"Must provide original Hugging Face repo to upload local model."
|
||||
)
|
||||
upload_to_hub(args.save_path, args.upload_repo, hf_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
252
llms/mlx_lm/lora.py
Normal file
252
llms/mlx_lm/lora.py
Normal file
@ -0,0 +1,252 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
from .models import llama, mixtral, phi2
|
||||
from .tuner.lora import LoRALinear
|
||||
from .tuner.trainer import TrainingArgs, evaluate, train
|
||||
from .utils import generate, load
|
||||
|
||||
SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]
|
||||
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
# Generation args
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
help="The maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=0.8, help="The sampling temperature"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
"-p",
|
||||
type=str,
|
||||
help="The prompt for generation",
|
||||
default=None,
|
||||
)
|
||||
|
||||
# Training args
|
||||
parser.add_argument(
|
||||
"--train",
|
||||
action="store_true",
|
||||
help="Do training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
type=str,
|
||||
default="data/",
|
||||
help="Directory with {train, valid, test}.jsonl files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-layers",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of layers to fine-tune",
|
||||
)
|
||||
parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
|
||||
parser.add_argument(
|
||||
"--iters", type=int, default=1000, help="Iterations to train for."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val-batches",
|
||||
type=int,
|
||||
default=25,
|
||||
help="Number of validation batches, -1 uses the entire validation set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning-rate", type=float, default=1e-5, help="Adam learning rate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps-per-report",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of training steps between loss reporting.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps-per-eval",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Number of training steps between validations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume-adapter-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Load path to resume training with the given adapter weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-file",
|
||||
type=str,
|
||||
default="adapters.npz",
|
||||
help="Save/load path for the trained adapter weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-every",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Save the model every N iterations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
action="store_true",
|
||||
help="Evaluate on the test set after training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-batches",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of test set batches, -1 uses the entire test set.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
return parser
|
||||
|
||||
|
||||
class Dataset:
|
||||
"""
|
||||
Light-weight wrapper to hold lines from a jsonl file
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path, key: str = "text"):
|
||||
if not path.exists():
|
||||
self._data = None
|
||||
else:
|
||||
with open(path, "r") as fid:
|
||||
self._data = [json.loads(l) for l in fid]
|
||||
self._key = key
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx][self._key]
|
||||
|
||||
def __len__(self):
|
||||
if self._data is None:
|
||||
return 0
|
||||
return len(self._data)
|
||||
|
||||
|
||||
def load_dataset(args):
|
||||
names = ("train", "valid", "test")
|
||||
train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
|
||||
if args.train and len(train) == 0:
|
||||
raise ValueError(
|
||||
"Training set not found or empty. Must provide training set for fine-tuning."
|
||||
)
|
||||
if args.train and len(valid) == 0:
|
||||
raise ValueError(
|
||||
"Validation set not found or empty. Must provide validation set for fine-tuning."
|
||||
)
|
||||
if args.test and len(test) == 0:
|
||||
raise ValueError(
|
||||
"Test set not found or empty. Must provide test set for evaluation."
|
||||
)
|
||||
return train, valid, test
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
np.random.seed(args.seed)
|
||||
|
||||
print("Loading pretrained model")
|
||||
model, tokenizer = load(args.model)
|
||||
|
||||
if model.__class__ not in SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Model {model.__class__} not supported. "
|
||||
f"Supported models: { SUPPORTED_MODELS}"
|
||||
)
|
||||
|
||||
# Freeze all layers other than LORA linears
|
||||
model.freeze()
|
||||
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
|
||||
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
|
||||
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
|
||||
if hasattr(l, "block_sparse_moe"):
|
||||
l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
|
||||
|
||||
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
|
||||
print(f"Total parameters {p:.3f}M")
|
||||
p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
|
||||
print(f"Trainable parameters {p:.3f}M")
|
||||
|
||||
print("Loading datasets")
|
||||
train_set, valid_set, test_set = load_dataset(args)
|
||||
|
||||
# Resume training the given adapters.
|
||||
if args.resume_adapter_file is not None:
|
||||
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
||||
model.load_weights(args.resume_adapter_file, strict=False)
|
||||
# init training args
|
||||
trainingArgs = TrainingArgs(
|
||||
batch_size=args.batch_size,
|
||||
iters=args.iters,
|
||||
val_batches=args.val_batches,
|
||||
steps_per_report=args.steps_per_report,
|
||||
steps_per_eval=args.steps_per_eval,
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=args.adapter_file,
|
||||
)
|
||||
if args.train:
|
||||
print("Training")
|
||||
model.train()
|
||||
opt = optim.Adam(learning_rate=args.learning_rate)
|
||||
# Train model
|
||||
train(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
args=trainingArgs,
|
||||
optimizer=opt,
|
||||
train_dataset=train_set,
|
||||
val_dataset=valid_set,
|
||||
)
|
||||
|
||||
# Load the LoRA adapter weights which we assume should exist by this point
|
||||
if not Path(args.adapter_file).is_file():
|
||||
raise ValueError(
|
||||
f"Adapter file {args.adapter_file} missing. "
|
||||
"Use --train to learn and save the adapters.npz."
|
||||
)
|
||||
model.load_weights(args.adapter_file, strict=False)
|
||||
|
||||
if args.test:
|
||||
print("Testing")
|
||||
model.eval()
|
||||
|
||||
test_loss = evaluate(
|
||||
model=model,
|
||||
dataset=test_set,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.test_batches,
|
||||
)
|
||||
|
||||
test_ppl = math.exp(test_loss)
|
||||
|
||||
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||
|
||||
if args.prompt is not None:
|
||||
print("Generating")
|
||||
generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
temp=args.temp,
|
||||
max_tokens=args.max_tokens,
|
||||
prompt=args.prompt,
|
||||
verbose=True,
|
||||
)
|
@ -3,6 +3,7 @@ from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
@ -158,12 +159,28 @@ class MixtralSparseMoeBlock(nn.Module):
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
gates = self.gate(x)
|
||||
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
|
||||
|
||||
inds = mx.stop_gradient(
|
||||
mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
|
||||
) # TODO remove it once we figure out how to fine tune TopK in MOE
|
||||
|
||||
scores = mx.softmax(
|
||||
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
|
||||
axis=-1,
|
||||
).astype(gates.dtype)
|
||||
|
||||
if self.training:
|
||||
mx.eval(inds)
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
|
||||
for e, expert in enumerate(self.experts):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
|
||||
y = (y * scores[:, :, None]).sum(axis=1)
|
||||
else:
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
|
||||
@ -229,7 +246,7 @@ class MixtralModel(nn.Module):
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.norm(h[:, T - 1 : T, :]), cache
|
||||
return self.norm(h), cache
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
|
0
llms/mlx_lm/tuner/__init__.py
Normal file
0
llms/mlx_lm/tuner/__init__.py
Normal file
88
llms/mlx_lm/tuner/lora.py
Normal file
88
llms/mlx_lm/tuner/lora.py
Normal file
@ -0,0 +1,88 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0):
|
||||
# TODO remove when input_dims and output_dims are attributes
|
||||
# on linear and quantized linear
|
||||
output_dims, input_dims = linear.weight.shape
|
||||
if isinstance(linear, nn.QuantizedLinear):
|
||||
input_dims *= 32 // linear.bits
|
||||
lora_lin = LoRALinear(
|
||||
input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale
|
||||
)
|
||||
lora_lin.linear = linear
|
||||
return lora_lin
|
||||
|
||||
def to_linear(self):
|
||||
linear = self.linear
|
||||
bias = "bias" in linear
|
||||
weight = linear.weight
|
||||
is_quantized = isinstance(linear, nn.QuantizedLinear)
|
||||
|
||||
# Use the same type as the linear weight if not quantized
|
||||
dtype = weight.dtype
|
||||
|
||||
if is_quantized:
|
||||
dtype = mx.float16
|
||||
weight = mx.dequantize(
|
||||
weight,
|
||||
linear.scales,
|
||||
linear.biases,
|
||||
linear.group_size,
|
||||
linear.bits,
|
||||
)
|
||||
output_dims, input_dims = weight.shape
|
||||
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
lora_b = (self.scale * self.lora_b.T).astype(dtype)
|
||||
lora_a = self.lora_a.T.astype(dtype)
|
||||
fused_linear.weight = weight + lora_b @ lora_a
|
||||
if bias:
|
||||
fused_linear.bias = linear.bias
|
||||
|
||||
if is_quantized:
|
||||
fused_linear = nn.QuantizedLinear.from_linear(
|
||||
fused_linear,
|
||||
linear.group_size,
|
||||
linear.bits,
|
||||
)
|
||||
|
||||
return fused_linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
rank: int = 8,
|
||||
bias: bool = False,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(input_dims, rank),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(rank, output_dims))
|
||||
|
||||
def __call__(self, x):
|
||||
dtype = self.linear.weight.dtype
|
||||
if isinstance(self.linear, nn.QuantizedLinear):
|
||||
dtype = self.linear.scales.dtype
|
||||
y = self.linear(x.astype(dtype))
|
||||
z = (x @ self.lora_a) @ self.lora_b
|
||||
return y + self.scale * z
|
204
llms/mlx_lm/tuner/trainer.py
Normal file
204
llms/mlx_lm/tuner/trainer.py
Normal file
@ -0,0 +1,204 @@
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArgs:
|
||||
lora_layers: int = field(
|
||||
default=16, metadata={"help": "Number of layers to fine-tune"}
|
||||
)
|
||||
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
|
||||
iters: int = field(default=100, metadata={"help": "Iterations to train for."})
|
||||
val_batches: int = field(
|
||||
default=25,
|
||||
metadata={
|
||||
"help": "Number of validation batches, -1 uses the entire validation set."
|
||||
},
|
||||
)
|
||||
steps_per_report: int = field(
|
||||
default=10,
|
||||
metadata={"help": "Number of training steps between loss reporting."},
|
||||
)
|
||||
steps_per_eval: int = field(
|
||||
default=200, metadata={"help": "Number of training steps between validations."}
|
||||
)
|
||||
steps_per_save: int = field(
|
||||
default=100, metadata={"help": "Save the model every number steps"}
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048, metadata={"help": "Maximum sequence length."}
|
||||
)
|
||||
adapter_file: str = field(
|
||||
default="adapter.npz",
|
||||
metadata={"help": "Save/load path for the trained adapter weights."},
|
||||
)
|
||||
|
||||
|
||||
def default_loss(model, inputs, targets, lengths):
|
||||
logits, _ = model(inputs)
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
|
||||
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
||||
ntoks = length_mask.sum()
|
||||
ce = ce.sum() / ntoks
|
||||
|
||||
return ce, ntoks
|
||||
|
||||
|
||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
while True:
|
||||
# Shuffle indices
|
||||
indices = np.arange(len(dataset))
|
||||
indices = np.random.permutation(indices)
|
||||
# Collect batches from dataset
|
||||
for i in range(0, len(indices) - batch_size + 1, batch_size):
|
||||
# Encode batch
|
||||
batch = [
|
||||
tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size)
|
||||
]
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
# Check if any sequence is longer than max_seq_length
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
"[WARNING] Some sequences are longer than 2048 tokens. "
|
||||
"Consider pre-splitting your data to save memory."
|
||||
)
|
||||
|
||||
# Pad to the max length
|
||||
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
|
||||
|
||||
for j in range(batch_size):
|
||||
batch_arr[j, : lengths[j]] = batch[j]
|
||||
batch = mx.array(batch_arr)
|
||||
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
|
||||
|
||||
if not train:
|
||||
break
|
||||
|
||||
|
||||
def evaluate(
|
||||
model,
|
||||
dataset,
|
||||
tokenizer,
|
||||
batch_size,
|
||||
num_batches,
|
||||
max_seq_length=2048,
|
||||
loss: callable = default_loss,
|
||||
):
|
||||
all_losses = []
|
||||
ntokens = 0
|
||||
for it, batch in zip(
|
||||
range(num_batches),
|
||||
iterate_batches(
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=batch_size,
|
||||
max_seq_length=max_seq_length,
|
||||
),
|
||||
):
|
||||
losses, toks = loss(model, *batch)
|
||||
all_losses.append((losses * toks).item())
|
||||
ntokens += toks.item()
|
||||
|
||||
return np.sum(all_losses) / ntokens
|
||||
|
||||
|
||||
def train(
|
||||
model,
|
||||
tokenizer,
|
||||
optimizer,
|
||||
train_dataset,
|
||||
val_dataset,
|
||||
args: TrainingArgs = TrainingArgs(),
|
||||
loss: callable = default_loss,
|
||||
):
|
||||
# Create value and grad function for loss
|
||||
loss_value_and_grad = nn.value_and_grad(model, loss)
|
||||
|
||||
losses = []
|
||||
n_tokens = 0
|
||||
print("Starting training..., iters:", args.iters)
|
||||
# Main training loop
|
||||
start = time.perf_counter()
|
||||
for it, batch in zip(
|
||||
range(args.iters),
|
||||
iterate_batches(
|
||||
dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
max_seq_length=args.max_seq_length,
|
||||
train=True,
|
||||
),
|
||||
):
|
||||
# Forward and backward pass
|
||||
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
||||
|
||||
# Model update
|
||||
optimizer.update(model, grad)
|
||||
|
||||
mx.eval(model.parameters(), optimizer.state, lvalue)
|
||||
|
||||
# Record loss
|
||||
losses.append(lvalue.item())
|
||||
n_tokens += toks.item()
|
||||
|
||||
# Report training loss if needed
|
||||
if (it + 1) % args.steps_per_report == 0:
|
||||
train_loss = np.mean(losses)
|
||||
|
||||
stop = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
|
||||
f"It/sec {args.steps_per_report / (stop - start):.3f}, "
|
||||
f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
|
||||
)
|
||||
losses = []
|
||||
n_tokens = 0
|
||||
start = time.perf_counter()
|
||||
|
||||
# Report validation loss if needed
|
||||
if it == 0 or (it + 1) % args.steps_per_eval == 0:
|
||||
stop = time.perf_counter()
|
||||
val_loss = evaluate(
|
||||
model=model,
|
||||
dataset=val_dataset,
|
||||
loss=loss,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.val_batches,
|
||||
)
|
||||
print(
|
||||
f"Iter {it + 1}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val took {(time.perf_counter() - stop):.3f}s"
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights if needed
|
||||
if (it + 1) % args.steps_per_save == 0:
|
||||
save_adapter(model=model, adapter_file=args.adapter_file)
|
||||
print(
|
||||
f"Iter {it + 1}: Saved adapter weights to {os.path.join(args.adapter_file)}."
|
||||
)
|
||||
# save final adapter weights
|
||||
save_adapter(model=model, adapter_file=args.adapter_file)
|
||||
print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
|
||||
|
||||
|
||||
def save_adapter(
|
||||
model: nn.Module,
|
||||
adapter_file: str,
|
||||
):
|
||||
flattened_tree = tree_flatten(model.trainable_parameters())
|
||||
|
||||
mx.savez(adapter_file, **dict(flattened_tree))
|
22
llms/mlx_lm/tuner/utils.py
Normal file
22
llms/mlx_lm/tuner/utils.py
Normal file
@ -0,0 +1,22 @@
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from .lora import LoRALinear
|
||||
|
||||
|
||||
def apply_lora_layers(model, adapter_file: str):
|
||||
adapters = list(mx.load(adapter_file).items())
|
||||
linear_replacements = {}
|
||||
lora_layers = set(
|
||||
[name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
|
||||
)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if name in lora_layers:
|
||||
replacement_module = LoRALinear.from_linear(module)
|
||||
linear_replacements[name] = replacement_module
|
||||
|
||||
model.update_modules(tree_unflatten(list(linear_replacements.items())))
|
||||
|
||||
model.update(tree_unflatten(adapters))
|
||||
return model
|
@ -1,13 +1,14 @@
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Generator, Tuple
|
||||
from typing import Any, Dict, Generator, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models import llama, mixtral, phi2, plamo, qwen
|
||||
@ -21,6 +22,7 @@ MODEL_MAPPING = {
|
||||
"qwen": qwen,
|
||||
"plamo": plamo,
|
||||
}
|
||||
MAX_FILE_SIZE_GB = 15
|
||||
|
||||
linear_class_predicate = (
|
||||
lambda m: isinstance(m, nn.Linear)
|
||||
@ -204,6 +206,7 @@ def load_model(model_path: Path) -> nn.Module:
|
||||
|
||||
mx.eval(model.parameters())
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
@ -211,7 +214,7 @@ def load(
|
||||
path_or_hf_repo: str, tokenizer_config={}
|
||||
) -> Tuple[nn.Module, PreTrainedTokenizer]:
|
||||
"""
|
||||
Load the model from a given path or a huggingface repository.
|
||||
Load the model and tokenizer from a given path or a huggingface repository.
|
||||
|
||||
Args:
|
||||
model_path (Path): The path or the huggingface repository to load the model from.
|
||||
@ -229,3 +232,103 @@ def load(
|
||||
model = load_model(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def fetch_from_hub(
|
||||
model_path: Path,
|
||||
) -> Tuple[Dict, dict, PreTrainedTokenizer]:
|
||||
model = load_model(model_path)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
return model, config.to_dict(), tokenizer
|
||||
|
||||
|
||||
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
|
||||
"""
|
||||
Splits the weights into smaller shards.
|
||||
|
||||
Args:
|
||||
weights (dict): Model weights.
|
||||
max_file_size_gb (int): Maximum size of each shard in gigabytes.
|
||||
|
||||
Returns:
|
||||
list: List of weight shards.
|
||||
"""
|
||||
max_file_size_bytes = max_file_size_gb << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
estimated_size = v.size * v.dtype.size
|
||||
if shard_size + estimated_size > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += estimated_size
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
||||
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
"""
|
||||
Uploads the model to Hugging Face hub.
|
||||
|
||||
Args:
|
||||
path (str): Local path to the model.
|
||||
upload_repo (str): Name of the HF repo to upload to.
|
||||
hf_path (str): Path to the original Hugging Face model.
|
||||
"""
|
||||
import os
|
||||
|
||||
from huggingface_hub import HfApi, ModelCard, logging
|
||||
|
||||
card = ModelCard.load(hf_path)
|
||||
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
|
||||
card.text = f"""
|
||||
# {upload_repo}
|
||||
This model was converted to MLX format from [`{hf_path}`]().
|
||||
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
|
||||
## Use with mlx
|
||||
|
||||
```bash
|
||||
pip install mlx-lm
|
||||
```
|
||||
|
||||
```python
|
||||
from mlx_lm import load, generate
|
||||
|
||||
model, tokenizer = load("{upload_repo}")
|
||||
response = generate(model, tokenizer, prompt="hello", verbose=True)
|
||||
```
|
||||
"""
|
||||
card.save(os.path.join(path, "README.md"))
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=upload_repo, exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path=path,
|
||||
repo_id=upload_repo,
|
||||
repo_type="model",
|
||||
)
|
||||
|
||||
|
||||
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
|
||||
"""Save model weights into specified directory."""
|
||||
if isinstance(save_path, str):
|
||||
save_path = Path(save_path)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
shards = make_shards(weights)
|
||||
shards_count = len(shards)
|
||||
shard_file_format = (
|
||||
"model-{:05d}-of-{:05d}.safetensors"
|
||||
if shards_count > 1
|
||||
else "model.safetensors"
|
||||
)
|
||||
|
||||
for i, shard in enumerate(shards):
|
||||
shard_name = shard_file_format.format(i + 1, shards_count)
|
||||
mx.save_safetensors(str(save_path / shard_name), shard)
|
||||
|
@ -4,7 +4,7 @@ import json
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -198,7 +198,7 @@ class Model(nn.Module):
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
|
Loading…
Reference in New Issue
Block a user