mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
feat: move lora into mlx-lm (#337)
* feat: Add lora and qlora training to mlx-lm --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
85c1ff8fd6
commit
362e88a744
@@ -14,6 +14,10 @@ pip install mlx-lm
 conda install -c conda-forge mlx-lm
 ```

+The `mlx-lm` package also supports LoRA and QLoRA fine-tuning. For more details
+on this see the [LoRA
+documentation](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md).
+
 ### Python API

 You can use `mlx-lm` as a module:
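The module usage the README goes on to show mirrors the snippet that `upload_to_hub` writes into generated model cards (see `utils.py` later in this commit); as a sketch:

```python
from mlx_lm import load, generate

# Any compatible Hugging Face repo or local converted model path works here.
model, tokenizer = load("mistralai/Mistral-7B-v0.1")
response = generate(model, tokenizer, prompt="hello", verbose=True)
```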
llms/mlx_lm/LORA.md (new file, 165 lines)
@@ -0,0 +1,165 @@
# Fine-Tuning with LoRA or QLoRA

You can use the `mlx-lm` package to fine-tune an LLM with low rank
adaptation (LoRA) for a target task.[^lora] The example also supports quantized
LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:

- Mistral
- Llama
- Phi2
- Mixtral
## Contents

* [Run](#Run)
* [Fine-tune](#Fine-tune)
* [Evaluate](#Evaluate)
* [Generate](#Generate)
* [Fuse and Upload](#Fuse-and-Upload)
* [Data](#Data)
* [Memory Issues](#Memory-Issues)
## Run

The main command is `mlx_lm.lora`. To see a full list of options run:

```shell
python -m mlx_lm.lora --help
```

Note that in the following, the `--model` argument can be any compatible Hugging
Face repo or a local path to a converted model.
### Fine-tune

To fine-tune a model use:

```shell
python -m mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --data <path_to_data> \
    --iters 600
```

The `--data` argument must specify a directory containing `train.jsonl` and
`valid.jsonl` when using `--train`, and `test.jsonl` when using `--test`. For
more details on the data format see the section on [Data](#Data).

For example, to fine-tune a Mistral 7B you can use `--model
mistralai/Mistral-7B-v0.1`.

If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
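For instance, one way to obtain a quantized model is the package's converter with the `-q` flag referenced under Memory Issues below; the exact invocation here is a sketch (the default output directory `mlx_model` comes from the converter's defaults):

```shell
# Produce a quantized copy of the base model, then fine-tune it with QLoRA.
python -m mlx_lm.convert --hf-path mistralai/Mistral-7B-v0.1 -q
python -m mlx_lm.lora --model mlx_model --train --data <path_to_data>
```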
By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter-file`.

You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.npz>`.
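A resumed run reuses the same training flags plus the saved adapter file; paths below are placeholders:

```shell
python -m mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --data <path_to_data> \
    --resume-adapter-file adapters.npz
```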
### Evaluate

To compute test set perplexity use:

```shell
python -m mlx_lm.lora \
    --model <path_to_model> \
    --adapter-file <path_to_adapters.npz> \
    --data <path_to_data> \
    --test
```
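The contents list above also includes a Generate step; `lora.py` exposes `--prompt`, `--max-tokens`, and `--temp` for generating with the trained adapter, so a generation run looks roughly like the following (a sketch based on those flags, not a command quoted from this document):

```shell
python -m mlx_lm.lora \
    --model <path_to_model> \
    --adapter-file <path_to_adapters.npz> \
    --max-tokens 50 \
    --prompt "<your prompt>"
```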
## Fuse and Upload

You can generate a model fused with the low-rank adapters using the
`mlx_lm.fuse` command. This command also allows you to upload the fused model
to the Hugging Face Hub.

To see supported options run:

```shell
python -m mlx_lm.fuse --help
```

To generate the fused model run:

```shell
python -m mlx_lm.fuse --model <path_to_model>
```

This will by default load the adapters from `adapters.npz`, and save the fused
model in the path `lora_fused_model/`. All of these are configurable.

To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
useful for the sake of attribution and model versioning.

For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:

```shell
python -m mlx_lm.fuse \
    --model mistralai/Mistral-7B-v0.1 \
    --upload-repo mlx-community/my-4bit-lora-mistral \
    --hf-path mistralai/Mistral-7B-v0.1
```
## Data

The LoRA command expects you to provide a dataset with `--data`. The MLX
Examples GitHub repo has an [example of the WikiSQL
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
correct format.

For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
file should look like:

```
{"text": "This is an example for the model."}
```

Note that other keys will be ignored by the loader.
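A minimal sketch of writing these files from in-memory strings (only the file names and the `"text"` key come from the loader's expectations; everything else is illustrative):

```python
import json
from pathlib import Path

def write_jsonl(texts, path):
    # One JSON object per line; only the "text" key is read by the loader.
    with open(path, "w") as fid:
        for t in texts:
            fid.write(json.dumps({"text": t}) + "\n")

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
write_jsonl(["example one", "example two"], data_dir / "train.jsonl")
write_jsonl(["held-out example"], data_dir / "valid.jsonl")
write_jsonl(["test example"], data_dir / "test.jsonl")
```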
## Memory Issues

Fine-tuning a large model with LoRA requires a machine with a decent amount
of memory. Here are some tips to reduce memory use should you need to do so:

1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
   with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
   more details.

2. Try using a smaller batch size with `--batch-size`. The default is `4` so
   setting this to `2` or `1` will reduce memory consumption. This may slow
   things down a little, but will also reduce the memory use.

3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
   is `16`, so you can try `8` or `4`. This reduces the amount of memory
   needed for back propagation. It may also reduce the quality of the
   fine-tuned model if you are fine-tuning with a lot of data.

4. Longer examples require more memory. If it makes sense for your data, one
   thing you can do is break your examples into smaller sequences when making
   the `{train, valid, test}.jsonl` files.
For example, for a machine with 32 GB the following should run reasonably fast:

```
python -m mlx_lm.lora \
    --model mistralai/Mistral-7B-v0.1 \
    --train \
    --batch-size 1 \
    --lora-layers 4 \
    --data wikisql
```

The above command on an M1 Max with 32 GB runs at about 250
tokens-per-second, using the MLX Example
[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
data set.

[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
@@ -3,5 +3,8 @@
 This is an example of large language model text generation that can pull models from
 the Hugging Face Hub.

-For more information on this example, see the
-[README](../README.md) in the parent directory.
+For more information on this example, see the [README](../README.md) in the
+parent directory.
+
+This package also supports fine tuning with LoRA or QLoRA. For more information
+see the [LoRA documentation](LORA.md).
llms/mlx_lm/convert.py

@@ -2,17 +2,21 @@ import argparse
 import copy
 import glob
 import json
+import shutil
 from pathlib import Path
-from typing import Dict, Tuple
+from typing import Tuple

 import mlx.core as mx
 import mlx.nn as nn
-import transformers
 from mlx.utils import tree_flatten

-from .utils import get_model_path, linear_class_predicate, load_model
-
-MAX_FILE_SIZE_GB = 15
+from .utils import (
+    fetch_from_hub,
+    get_model_path,
+    linear_class_predicate,
+    save_weights,
+    upload_to_hub,
+)


 def configure_parser() -> argparse.ArgumentParser:

@@ -55,22 +59,9 @@ def configure_parser() -> argparse.ArgumentParser:
     return parser


-def fetch_from_hub(
-    model_path: str,
-) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
-    model_path = get_model_path(model_path)
-
-    model = load_model(model_path)
-
-    config = transformers.AutoConfig.from_pretrained(model_path)
-    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
-
-    return model, config.to_dict(), tokenizer
-
-
 def quantize_model(
     model: nn.Module, config: dict, q_group_size: int, q_bits: int
-) -> tuple:
+) -> Tuple:
     """
     Applies quantization to the model weights.

@@ -81,7 +72,7 @@ def quantize_model(
         q_bits (int): Bits per weight for quantization.

     Returns:
-        tuple: Tuple containing quantized weights and config.
+        Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)

@@ -94,76 +85,6 @@ def quantize_model(
     return quantized_weights, quantized_config


-def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
-    ...  # body moved verbatim to utils.py (shown in full below)
-
-
-def upload_to_hub(path: str, upload_repo: str, hf_path: str):
-    ...  # body moved verbatim to utils.py (shown in full below)
-
-
 def convert(
     hf_path: str,
     mlx_path: str = "mlx_model",

@@ -174,7 +95,8 @@ def convert(
     upload_repo: str = None,
 ):
     print("[INFO] Loading")
-    model, config, tokenizer = fetch_from_hub(hf_path)
+    model_path = get_model_path(hf_path)
+    model, config, tokenizer = fetch_from_hub(model_path)

     weights = dict(tree_flatten(model.parameters()))
     dtype = mx.float16 if quantize else getattr(mx, dtype)

@@ -185,12 +107,17 @@ def convert(
         model.load_weights(list(weights.items()))
         weights, config = quantize_model(model, config, q_group_size, q_bits)

+    if isinstance(mlx_path, str):
         mlx_path = Path(mlx_path)
-    mlx_path.mkdir(parents=True, exist_ok=True)
-    shards = make_shards(weights)
-    for i, shard in enumerate(shards):
-        mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
+
+    save_weights(mlx_path, weights)
+
+    py_files = glob.glob(str(model_path / "*.py"))
+    for file in py_files:
+        shutil.copy(file, mlx_path)
+
     tokenizer.save_pretrained(mlx_path)

     with open(mlx_path / "config.json", "w") as fid:
         json.dump(config, fid, indent=4)
llms/mlx_lm/fuse.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import argparse
import glob
import json
import shutil
from pathlib import Path
from typing import Any, Dict, Union

from mlx.utils import tree_flatten, tree_unflatten

from .tuner.lora import LoRALinear
from .tuner.utils import apply_lora_layers
from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub


def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--save-path",
        default="lora_fused_model",
        help="The path to save the fused model.",
    )
    parser.add_argument(
        "--adapter-file",
        type=str,
        default="adapters.npz",
        help="Path to the trained adapter weights (npz or safetensors).",
    )
    parser.add_argument(
        "--hf-path",
        type=str,
        default=None,
        help="Path to the original Hugging Face model. Required for upload if --model is a local directory.",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    return parser.parse_args()


def main() -> None:
    print("Loading pretrained model")
    args = parse_arguments()

    model_path = get_model_path(args.model)
    model, config, tokenizer = fetch_from_hub(model_path)

    model.freeze()
    model = apply_lora_layers(model, args.adapter_file)
    fused_linears = [
        (n, m.to_linear())
        for n, m in model.named_modules()
        if isinstance(m, LoRALinear)
    ]

    model.update_modules(tree_unflatten(fused_linears))
    weights = dict(tree_flatten(model.parameters()))

    save_path = Path(args.save_path)

    save_weights(save_path, weights)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, save_path)

    tokenizer.save_pretrained(save_path)

    with open(save_path / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)

    if args.upload_repo is not None:
        hf_path = args.hf_path or (
            args.model if not Path(args.model).exists() else None
        )
        if hf_path is None:
            raise ValueError(
                "Must provide original Hugging Face repo to upload local model."
            )
        upload_to_hub(args.save_path, args.upload_repo, hf_path)


if __name__ == "__main__":
    main()
llms/mlx_lm/lora.py (new file, 252 lines)
@@ -0,0 +1,252 @@
import argparse
import json
import math
from pathlib import Path

import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten

from .models import llama, mixtral, phi2
from .tuner.lora import LoRALinear
from .tuner.trainer import TrainingArgs, evaluate, train
from .utils import generate, load

SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]


def build_parser():
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    # Generation args
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="The maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp", type=float, default=0.8, help="The sampling temperature"
    )
    parser.add_argument(
        "--prompt",
        "-p",
        type=str,
        help="The prompt for generation",
        default=None,
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
    )
    parser.add_argument(
        "--data",
        type=str,
        default="data/",
        help="Directory with {train, valid, test}.jsonl files",
    )
    parser.add_argument(
        "--lora-layers",
        type=int,
        default=16,
        help="Number of layers to fine-tune",
    )
    parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
    parser.add_argument(
        "--iters", type=int, default=1000, help="Iterations to train for."
    )
    parser.add_argument(
        "--val-batches",
        type=int,
        default=25,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument(
        "--learning-rate", type=float, default=1e-5, help="Adam learning rate."
    )
    parser.add_argument(
        "--steps-per-report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        default=200,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        default=None,
        help="Load path to resume training with the given adapter weights.",
    )
    parser.add_argument(
        "--adapter-file",
        type=str,
        default="adapters.npz",
        help="Save/load path for the trained adapter weights.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=100,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        default=500,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    return parser


class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path, key: str = "text"):
        if not path.exists():
            self._data = None
        else:
            with open(path, "r") as fid:
                self._data = [json.loads(l) for l in fid]
        self._key = key

    def __getitem__(self, idx: int):
        return self._data[idx][self._key]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)


def load_dataset(args):
    names = ("train", "valid", "test")
    train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
    if args.train and len(train) == 0:
        raise ValueError(
            "Training set not found or empty. Must provide training set for fine-tuning."
        )
    if args.train and len(valid) == 0:
        raise ValueError(
            "Validation set not found or empty. Must provide validation set for fine-tuning."
        )
    if args.test and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )
    return train, valid, test


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()

    np.random.seed(args.seed)

    print("Loading pretrained model")
    model, tokenizer = load(args.model)

    if model.__class__ not in SUPPORTED_MODELS:
        raise ValueError(
            f"Model {model.__class__} not supported. "
            f"Supported models: {SUPPORTED_MODELS}"
        )

    # Freeze all layers other than LORA linears
    model.freeze()
    for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
        if hasattr(l, "block_sparse_moe"):
            l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)

    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
    print(f"Total parameters {p:.3f}M")
    p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    print(f"Trainable parameters {p:.3f}M")

    print("Loading datasets")
    train_set, valid_set, test_set = load_dataset(args)

    # Resume training the given adapters.
    if args.resume_adapter_file is not None:
        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
        model.load_weights(args.resume_adapter_file, strict=False)

    # init training args
    trainingArgs = TrainingArgs(
        batch_size=args.batch_size,
        iters=args.iters,
        val_batches=args.val_batches,
        steps_per_report=args.steps_per_report,
        steps_per_eval=args.steps_per_eval,
        steps_per_save=args.save_every,
        adapter_file=args.adapter_file,
    )
    if args.train:
        print("Training")
        model.train()
        opt = optim.Adam(learning_rate=args.learning_rate)
        # Train model
        train(
            model=model,
            tokenizer=tokenizer,
            args=trainingArgs,
            optimizer=opt,
            train_dataset=train_set,
            val_dataset=valid_set,
        )

    # Load the LoRA adapter weights which we assume should exist by this point
    if not Path(args.adapter_file).is_file():
        raise ValueError(
            f"Adapter file {args.adapter_file} missing. "
            "Use --train to learn and save the adapters.npz."
        )
    model.load_weights(args.adapter_file, strict=False)

    if args.test:
        print("Testing")
        model.eval()

        test_loss = evaluate(
            model=model,
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
        )

        test_ppl = math.exp(test_loss)

        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")

    if args.prompt is not None:
        print("Generating")
        generate(
            model=model,
            tokenizer=tokenizer,
            temp=args.temp,
            max_tokens=args.max_tokens,
            prompt=args.prompt,
            verbose=True,
        )
llms/mlx_lm/models/mixtral.py

@@ -3,6 +3,7 @@ from typing import Dict, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np

 from .base import BaseModelArgs

@@ -158,12 +159,28 @@ class MixtralSparseMoeBlock(nn.Module):
         x = x.reshape(-1, x.shape[-1])

         gates = self.gate(x)
-        inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
+
+        inds = mx.stop_gradient(
+            mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
+        )  # TODO remove it once we figure out how to fine tune TopK in MOE
+
         scores = mx.softmax(
             mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
             axis=-1,
         ).astype(gates.dtype)

+        if self.training:
+            mx.eval(inds)
+            inds = np.array(inds)
+            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
+            for e, expert in enumerate(self.experts):
+                idx1, idx2 = map(mx.array, np.where(inds == e))
+                if idx1.size == 0:
+                    continue
+                y[idx1, idx2] = expert(x[idx1])
+
+            y = (y * scores[:, :, None]).sum(axis=1)
+        else:
             y = []
             for xt, st, it in zip(x, scores, inds.tolist()):
                 yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)

@@ -229,7 +246,7 @@ class MixtralModel(nn.Module):
         for e, layer in enumerate(self.layers):
             h, cache[e] = layer(h, mask, cache[e])

-        return self.norm(h[:, T - 1 : T, :]), cache
+        return self.norm(h), cache


 class Model(nn.Module):
llms/mlx_lm/tuner/__init__.py (new, empty file)

llms/mlx_lm/tuner/lora.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import math

import mlx.core as mx
import mlx.nn as nn


class LoRALinear(nn.Module):
    @staticmethod
    def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0):
        # TODO remove when input_dims and output_dims are attributes
        # on linear and quantized linear
        output_dims, input_dims = linear.weight.shape
        if isinstance(linear, nn.QuantizedLinear):
            input_dims *= 32 // linear.bits
        lora_lin = LoRALinear(
            input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale
        )
        lora_lin.linear = linear
        return lora_lin

    def to_linear(self):
        linear = self.linear
        bias = "bias" in linear
        weight = linear.weight
        is_quantized = isinstance(linear, nn.QuantizedLinear)

        # Use the same type as the linear weight if not quantized
        dtype = weight.dtype

        if is_quantized:
            dtype = mx.float16
            weight = mx.dequantize(
                weight,
                linear.scales,
                linear.biases,
                linear.group_size,
                linear.bits,
            )
        output_dims, input_dims = weight.shape
        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)

        lora_b = (self.scale * self.lora_b.T).astype(dtype)
        lora_a = self.lora_a.T.astype(dtype)
        fused_linear.weight = weight + lora_b @ lora_a
        if bias:
            fused_linear.bias = linear.bias

        if is_quantized:
            fused_linear = nn.QuantizedLinear.from_linear(
                fused_linear,
                linear.group_size,
                linear.bits,
            )

        return fused_linear

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        rank: int = 8,
        bias: bool = False,
        scale: float = 20.0,
    ):
        super().__init__()

        # Regular linear layer weights
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)

        # Scale for low-rank update
        self.scale = scale

        # Low rank lora weights
        scale = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, rank),
        )
        self.lora_b = mx.zeros(shape=(rank, output_dims))

    def __call__(self, x):
        dtype = self.linear.weight.dtype
        if isinstance(self.linear, nn.QuantizedLinear):
            dtype = self.linear.scales.dtype
        y = self.linear(x.astype(dtype))
        z = (x @ self.lora_a) @ self.lora_b
        return y + self.scale * z
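This adapter layer is used by `lora.py` (wrapping `q_proj`/`v_proj`) and by `fuse.py` (calling `to_linear()`); a minimal standalone sketch of that round trip, with illustrative dimensions and an import path that assumes the installed `mlx_lm` package:

```python
import mlx.nn as nn
from mlx_lm.tuner.lora import LoRALinear

base = nn.Linear(4096, 4096)

# Wrap the frozen projection; lora_a / lora_b are the only trainable parameters.
lora = LoRALinear.from_linear(base, rank=8)

# After training, fold the low-rank update back into a plain nn.Linear for inference.
fused = lora.to_linear()
```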
llms/mlx_lm/tuner/trainer.py (new file, 204 lines)
@@ -0,0 +1,204 @@
import os
import time
from dataclasses import dataclass, field

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten


@dataclass
class TrainingArgs:
    lora_layers: int = field(
        default=16, metadata={"help": "Number of layers to fine-tune"}
    )
    batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
    iters: int = field(default=100, metadata={"help": "Iterations to train for."})
    val_batches: int = field(
        default=25,
        metadata={
            "help": "Number of validation batches, -1 uses the entire validation set."
        },
    )
    steps_per_report: int = field(
        default=10,
        metadata={"help": "Number of training steps between loss reporting."},
    )
    steps_per_eval: int = field(
        default=200, metadata={"help": "Number of training steps between validations."}
    )
    steps_per_save: int = field(
        default=100, metadata={"help": "Save the model every number steps"}
    )
    max_seq_length: int = field(
        default=2048, metadata={"help": "Maximum sequence length."}
    )
    adapter_file: str = field(
        default="adapter.npz",
        metadata={"help": "Save/load path for the trained adapter weights."},
    )


def default_loss(model, inputs, targets, lengths):
    logits, _ = model(inputs)
    logits = logits.astype(mx.float32)

    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

    ce = nn.losses.cross_entropy(logits, targets) * length_mask
    ntoks = length_mask.sum()
    ce = ce.sum() / ntoks

    return ce, ntoks


def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    while True:
        # Shuffle indices
        indices = np.arange(len(dataset))
        indices = np.random.permutation(indices)
        # Collect batches from dataset
        for i in range(0, len(indices) - batch_size + 1, batch_size):
            # Encode batch
            batch = [
                tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size)
            ]
            lengths = [len(x) for x in batch]

            # Check if any sequence is longer than max_seq_length
            if max(lengths) > max_seq_length:
                print(
                    "[WARNING] Some sequences are longer than 2048 tokens. "
                    "Consider pre-splitting your data to save memory."
                )

            # Pad to the max length
            batch_arr = np.zeros((batch_size, max(lengths)), np.int32)

            for j in range(batch_size):
                batch_arr[j, : lengths[j]] = batch[j]
            batch = mx.array(batch_arr)
            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)

        if not train:
            break


def evaluate(
    model,
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    max_seq_length=2048,
    loss: callable = default_loss,
):
    all_losses = []
    ntokens = 0
    for it, batch in zip(
        range(num_batches),
        iterate_batches(
            dataset=dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        losses, toks = loss(model, *batch)
        all_losses.append((losses * toks).item())
        ntokens += toks.item()

    return np.sum(all_losses) / ntokens


def train(
    model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    args: TrainingArgs = TrainingArgs(),
    loss: callable = default_loss,
):
    # Create value and grad function for loss
    loss_value_and_grad = nn.value_and_grad(model, loss)

    losses = []
    n_tokens = 0
    print("Starting training..., iters:", args.iters)
    # Main training loop
    start = time.perf_counter()
    for it, batch in zip(
        range(args.iters),
        iterate_batches(
            dataset=train_dataset,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        # Forward and backward pass
        (lvalue, toks), grad = loss_value_and_grad(model, *batch)

        # Model update
        optimizer.update(model, grad)

        mx.eval(model.parameters(), optimizer.state, lvalue)

        # Record loss
        losses.append(lvalue.item())
        n_tokens += toks.item()

        # Report training loss if needed
        if (it + 1) % args.steps_per_report == 0:
            train_loss = np.mean(losses)

            stop = time.perf_counter()
            print(
                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                f"It/sec {args.steps_per_report / (stop - start):.3f}, "
                f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
            )
            losses = []
            n_tokens = 0
            start = time.perf_counter()

        # Report validation loss if needed
        if it == 0 or (it + 1) % args.steps_per_eval == 0:
            stop = time.perf_counter()
            val_loss = evaluate(
                model=model,
                dataset=val_dataset,
                loss=loss,
                tokenizer=tokenizer,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
            )
            print(
                f"Iter {it + 1}: "
                f"Val loss {val_loss:.3f}, "
                f"Val took {(time.perf_counter() - stop):.3f}s"
            )

            start = time.perf_counter()

        # Save adapter weights if needed
        if (it + 1) % args.steps_per_save == 0:
            save_adapter(model=model, adapter_file=args.adapter_file)
            print(
                f"Iter {it + 1}: Saved adapter weights to {os.path.join(args.adapter_file)}."
            )
    # save final adapter weights
    save_adapter(model=model, adapter_file=args.adapter_file)
    print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")


def save_adapter(
    model: nn.Module,
    adapter_file: str,
):
    flattened_tree = tree_flatten(model.trainable_parameters())

    mx.savez(adapter_file, **dict(flattened_tree))
llms/mlx_lm/tuner/utils.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import mlx.core as mx
from mlx.utils import tree_unflatten

from .lora import LoRALinear


def apply_lora_layers(model, adapter_file: str):
    adapters = list(mx.load(adapter_file).items())
    linear_replacements = {}
    lora_layers = set(
        [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
    )

    for name, module in model.named_modules():
        if name in lora_layers:
            replacement_module = LoRALinear.from_linear(module)
            linear_replacements[name] = replacement_module

    model.update_modules(tree_unflatten(list(linear_replacements.items())))

    model.update(tree_unflatten(adapters))
    return model
llms/mlx_lm/utils.py

@@ -1,13 +1,14 @@
+import copy
 import glob
 import json
 import logging
 from pathlib import Path
-from typing import Generator, Tuple
+from typing import Any, Dict, Generator, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from transformers import AutoTokenizer, PreTrainedTokenizer
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

 # Local imports
 from .models import llama, mixtral, phi2, plamo, qwen

@@ -21,6 +22,7 @@ MODEL_MAPPING = {
     "qwen": qwen,
     "plamo": plamo,
 }
+MAX_FILE_SIZE_GB = 15

 linear_class_predicate = (
     lambda m: isinstance(m, nn.Linear)

@@ -204,6 +206,7 @@ def load_model(model_path: Path) -> nn.Module:

     mx.eval(model.parameters())

+    model.eval()
     return model


@@ -211,7 +214,7 @@ def load(
     path_or_hf_repo: str, tokenizer_config={}
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
-    Load the model from a given path or a huggingface repository.
+    Load the model and tokenizer from a given path or a huggingface repository.

     Args:
         model_path (Path): The path or the huggingface repository to load the model from.

@@ -229,3 +232,103 @@ def load(
     model = load_model(model_path)
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
     return model, tokenizer
+
+
+def fetch_from_hub(
+    model_path: Path,
+) -> Tuple[Dict, dict, PreTrainedTokenizer]:
+    model = load_model(model_path)
+
+    config = AutoConfig.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    return model, config.to_dict(), tokenizer
+
+
+def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
+    """
+    Splits the weights into smaller shards.
+
+    Args:
+        weights (dict): Model weights.
+        max_file_size_gb (int): Maximum size of each shard in gigabytes.
+
+    Returns:
+        list: List of weight shards.
+    """
+    max_file_size_bytes = max_file_size_gb << 30
+    shards = []
+    shard, shard_size = {}, 0
+    for k, v in weights.items():
+        estimated_size = v.size * v.dtype.size
+        if shard_size + estimated_size > max_file_size_bytes:
+            shards.append(shard)
+            shard, shard_size = {}, 0
+        shard[k] = v
+        shard_size += estimated_size
+    shards.append(shard)
+    return shards
+
+
+def upload_to_hub(path: str, upload_repo: str, hf_path: str):
+    """
+    Uploads the model to Hugging Face hub.
+
+    Args:
+        path (str): Local path to the model.
+        upload_repo (str): Name of the HF repo to upload to.
+        hf_path (str): Path to the original Hugging Face model.
+    """
+    import os
+
+    from huggingface_hub import HfApi, ModelCard, logging
+
+    card = ModelCard.load(hf_path)
+    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
+    card.text = f"""
+# {upload_repo}
+This model was converted to MLX format from [`{hf_path}`]().
+Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
+## Use with mlx
+
+```bash
+pip install mlx-lm
+```
+
+```python
+from mlx_lm import load, generate
+
+model, tokenizer = load("{upload_repo}")
+response = generate(model, tokenizer, prompt="hello", verbose=True)
+```
+"""
+    card.save(os.path.join(path, "README.md"))
+
+    logging.set_verbosity_info()
+
+    api = HfApi()
+    api.create_repo(repo_id=upload_repo, exist_ok=True)
+    api.upload_folder(
+        folder_path=path,
+        repo_id=upload_repo,
+        repo_type="model",
+    )
+
+
+def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
+    """Save model weights into specified directory."""
+    if isinstance(save_path, str):
+        save_path = Path(save_path)
+    save_path.mkdir(parents=True, exist_ok=True)
+
+    shards = make_shards(weights)
+    shards_count = len(shards)
+    shard_file_format = (
+        "model-{:05d}-of-{:05d}.safetensors"
+        if shards_count > 1
+        else "model.safetensors"
+    )
+
+    for i, shard in enumerate(shards):
+        shard_name = shard_file_format.format(i + 1, shards_count)
+        mx.save_safetensors(str(save_path / shard_name), shard)
@@ -4,7 +4,7 @@ import json
 import math
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

@@ -198,7 +198,7 @@ class Model(nn.Module):
         x: mx.array,
         mask: mx.array = None,
         cache: mx.array = None,
-    ) -> tuple[mx.array, mx.array]:
+    ) -> Tuple[mx.array, mx.array]:
         mask = None
         if x.shape[1] > 1:
             mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])