Adding full finetuning (#903)

* Adding full model weights finetuning

* Updating the LORA.md and ACKNOWLEDGMENTS.md files.

* removing --use-dora and --full-training and adding --fine-tune-type

* some clean up

* reformatting and fixing dora training

* updated CONFIG_DEFAULTS

* update config example

* update in the config example file

* Update LORA.md

* merge and commit

* adding argument for dora linear layer

* clean up

* clean up in the example yaml file

* fix

* final fix before sending

* small addition to the md file

* fix for loading the fully trained model by saving all the files and configs correctly

* clean up

* removing the unnecessary files

* changing lora layers back to 16

* removed max file size

* nits

* resolve merge

* some consistency changes

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Gökdeniz Gülmez 2024-09-30 02:12:47 +02:00 committed by GitHub
parent 7ec2021bb9
commit 50e5ca81a8
9 changed files with 79 additions and 70 deletions

ACKNOWLEDGMENTS.md

@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.

llms/README.md

@ -16,7 +16,7 @@ conda install -c conda-forge mlx-lm
The `mlx-lm` package also has:
- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)

llms/mlx_lm/LORA.md

@ -57,6 +57,9 @@ mlx_lm.lora \
--iters 600
```
To fine-tune the full model weights, add the `--fine-tune-type full` flag.
Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
details on the data format see the section on [Data](#Data).
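As a quick illustration, here is a hedged sketch of launching a full fine-tuning run from Python; it simply shells out to the `mlx_lm.lora` entry point, and the model name, data path, and iteration count are placeholder values, not part of this change:

```python
# Minimal sketch, assuming mlx-lm is installed so `python -m mlx_lm.lora` resolves.
import subprocess
import sys

subprocess.run(
    [
        sys.executable, "-m", "mlx_lm.lora",
        "--model", "mistralai/Mistral-7B-v0.1",
        "--train",
        "--fine-tune-type", "full",   # or "lora" (default) / "dora"
        "--data", "/path/to/training/data",
        "--iters", "600",
    ],
    check=True,
)
```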
@ -67,8 +70,8 @@ mistralai/Mistral-7B-v0.1`.
If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
By default, the adapter config and weights are saved in `adapters/`. You can
specify the output location with `--adapter-path`.
By default, the adapter config and learned weights are saved in `adapters/`.
You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.
@ -118,7 +121,7 @@ mlx_lm.fuse --model <path_to_model>
```
This will by default load the adapters from `adapters/`, and save the fused
model in the path `lora_fused_model/`. All of these are configurable.
model in the path `fused_model/`. All of these are configurable.
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
@ -141,7 +144,7 @@ mlx_lm.fuse \
--export-gguf
```
This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
can specify the file name with `--gguf-path`.
## Data
@ -301,7 +304,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:
setting this to `2` or `1` will reduce memory consumption. This may slow
things down a little, but will also reduce the memory use.
3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
3. Reduce the number of layers to fine-tune with `--num-layers`. The default
is `16`, so you can try `8` or `4`. This reduces the amount of memory
needed for back propagation. It may also reduce the quality of the
fine-tuned model if you are fine-tuning with a lot of data.
@ -323,7 +326,7 @@ mlx_lm.lora \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
--lora-layers 4 \
--num-layers 4 \
--data wikisql
```
@ -333,4 +336,5 @@ tokens-per-second, using the MLX Example
data set.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)

llms/mlx_lm/examples/lora_config.yaml

@ -1,8 +1,12 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"
# Whether or not to train (boolean)
train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"
@ -51,9 +55,6 @@ max_seq_length: 2048
# Use gradient checkpointing to reduce memory use.
grad_checkpoint: false
# Use DoRA instead of LoRA.
use_dora: false
# LoRA parameters can only be specified in a config file
lora_parameters:
# The layer keys to apply LoRA to.

llms/mlx_lm/fuse.py

@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten
from .gguf import convert_to_gguf
from .tuner.dora import DoRAEmbedding, DoRALinear
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
from .tuner.utils import apply_lora_layers, dequantize
from .tuner.utils import dequantize, load_adapters
from .utils import (
fetch_from_hub,
get_model_path,
@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
)
parser.add_argument(
"--save-path",
default="lora_fused_model",
default="fused_model",
help="The path to save the fused model.",
)
parser.add_argument(
@ -77,17 +77,14 @@ def main() -> None:
model, config, tokenizer = fetch_from_hub(model_path)
model.freeze()
model = apply_lora_layers(model, args.adapter_path)
model = load_adapters(model, args.adapter_path)
fused_linears = [
(n, m.fuse())
for n, m in model.named_modules()
if isinstance(
m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
)
(n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
]
model.update_modules(tree_unflatten(fused_linears))
if fused_linears:
model.update_modules(tree_unflatten(fused_linears))
if args.de_quantize:
print("De-quantizing model")

llms/mlx_lm/lora.py

@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
apply_lora_layers,
build_schedule,
linear_to_lora_layers,
load_adapters,
print_trainable_parameters,
)
from .utils import load, save_config
@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver(
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
"lora_layers": 16,
"num_layers": 16,
"batch_size": 4,
"iters": 1000,
"val_batches": 25,
@ -58,7 +59,6 @@ CONFIG_DEFAULTS = {
"max_seq_length": 2048,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"use_dora": False,
}
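For context, a hedged sketch of the precedence these defaults participate in: explicit CLI flags win over values from a YAML config, which win over `CONFIG_DEFAULTS`. The helper below is illustrative, not the actual implementation in `main()`:

```python
# Illustrative sketch of the config resolution order.
import yaml


def resolve_config(cli_args: dict, config_path: str, defaults: dict) -> dict:
    with open(config_path) as f:
        file_config = yaml.safe_load(f) or {}
    resolved = dict(defaults)                 # lowest precedence
    resolved.update(file_config)              # YAML config overrides defaults
    resolved.update(                          # explicit CLI flags win
        {k: v for k, v in cli_args.items() if v is not None}
    )
    return resolved
```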
@ -82,7 +82,14 @@ def build_parser():
help="Directory with {train, valid, test}.jsonl files",
)
parser.add_argument(
"--lora-layers",
"--fine-tune-type",
type=str,
choices=["lora", "dora", "full"],
default="lora",
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--num-layers",
type=int,
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
)
@ -107,12 +114,12 @@ def build_parser():
parser.add_argument(
"--resume-adapter-file",
type=str,
help="Load path to resume training with the given adapters.",
help="Load path to resume training from the given fine-tuned weights.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Save/load path for the adapters.",
help="Save/load path for the fine-tuned weights.",
)
parser.add_argument(
"--save-every",
@ -148,9 +155,6 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
parser.add_argument(
"--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
)
return parser
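A quick, hedged usage check of the reworked parser; it assumes `--model`, `--train`, and `--data` are defined elsewhere in `build_parser`, as the surrounding hunks suggest:

```python
# Sketch: --fine-tune-type replaces the removed --use-dora switch,
# and --num-layers replaces --lora-layers.
parser = build_parser()
args = parser.parse_args(
    [
        "--model", "mistralai/Mistral-7B-v0.1",
        "--train",
        "--fine-tune-type", "full",
        "--num-layers", "8",
        "--data", "data/",
    ]
)
assert args.fine_tune_type == "full" and args.num_layers == 8
```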
@ -162,21 +166,31 @@ def train_model(
valid_set,
training_callback: TrainingCallback = None,
):
# Freeze all layers
model.freeze()
if args.fine_tune_type == "full":
for l in model.layers[-min(args.num_layers, 0) :]:
l.unfreeze()
elif args.fine_tune_type in ["lora", "dora"]:
# Convert linear layers to lora/dora layers and unfreeze in the process
linear_to_lora_layers(
model,
args.num_layers,
args.lora_parameters,
use_dora=(args.fine_tune_type == "dora"),
)
else:
raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
# Convert linear layers to lora layers and unfreeze in the process
linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora)
# Resume training the given adapters.
# Resume from weights if provided
if args.resume_adapter_file is not None:
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file, strict=False)
print_trainable_parameters(model)
adapter_path = Path(args.adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
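Because `save_config(vars(args), ...)` persists the run arguments, the adapter directory's `adapter_config.json` is what later lets `load_adapters` (tuner/utils.py below) decide whether LoRA/DoRA layers need to be recreated. A hypothetical excerpt of such a file for a full fine-tuning run:

```python
# Hypothetical example: key names follow CONFIG_DEFAULTS above, values are illustrative.
import json

example_adapter_config = {
    "fine_tune_type": "full",
    "num_layers": 8,
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}
print(json.dumps(example_adapter_config, indent=2))
```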
@ -240,7 +254,7 @@ def run(args, training_callback: TrainingCallback = None):
if args.test and not args.train:
# Allow testing without LoRA layers by providing empty path
if args.adapter_path != "":
apply_lora_layers(model, args.adapter_path)
load_adapters(model, args.adapter_path)
elif args.train:
print("Training")

llms/mlx_lm/tuner/trainer.py

@ -1,5 +1,7 @@
# Copyright © 2024 Apple Inc.
import glob
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
@ -285,24 +287,18 @@ def train(
# Save adapter weights
if it % args.steps_per_save == 0:
save_adapter(model, args.adapter_file)
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
checkpoint = (
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
)
save_adapter(model, checkpoint)
mx.save_safetensors(str(checkpoint), adapter_weights)
print(
f"Iter {it}: Saved adapter weights to "
f"{args.adapter_file} and {checkpoint}."
)
# save final adapter weights
save_adapter(model, args.adapter_file)
print(f"Saved final adapter weights to {args.adapter_file}.")
def save_adapter(
model: nn.Module,
adapter_file: Union[str, Path],
):
flattened_tree = tree_flatten(model.trainable_parameters())
mx.save_safetensors(str(adapter_file), dict(flattened_tree))
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")

llms/mlx_lm/tuner/utils.py

@ -36,7 +36,7 @@ def build_schedule(schedule_config: Dict):
def linear_to_lora_layers(
model: nn.Module,
num_lora_layers: int,
num_layers: int,
config: Dict,
use_dora: bool = False,
):
@ -45,22 +45,17 @@ def linear_to_lora_layers(
Args:
model (nn.Module): The neural network model.
num_lora_layers (int): The number of blocks to convert to lora layers
num_layers (int): The number of blocks to convert to lora layers
starting from the last layer.
config (dict): More configuration parameters for LoRA, including the
rank, scale, and optional layer keys.
use_dora (bool): If True, uses DoRA instead of LoRA.
Default: ``False``
"""
num_layers = len(model.layers)
if num_lora_layers < 0:
num_lora_layers = num_layers
if num_lora_layers > num_layers:
if num_layers > len(model.layers):
raise ValueError(
f"Requested {num_lora_layers} LoRA layers "
f"but the model only has {num_layers} layers."
f"Requested {num_layers} LoRA layers "
f"but the model only has {len(model.layers)} layers."
)
def to_lora(layer):
@ -151,7 +146,7 @@ def linear_to_lora_layers(
else:
raise ValueError(f"Lora does not support {model.model_type}")
for l in model.layers[num_layers - num_lora_layers :]:
for l in model.layers[-min(num_layers, 0) :]:
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
if lora_layers:
l.update_modules(tree_unflatten(lora_layers))
@ -161,9 +156,9 @@ def linear_to_lora_layers(
model.update_modules(tree_unflatten(lora_modules))
def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
"""
Apply LoRA layers to the model.
Load any fine-tuned adapters / layers.
Args:
model (nn.Module): The neural network model.
@ -177,12 +172,14 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
with open(adapter_path / "adapter_config.json", "r") as fid:
config = types.SimpleNamespace(**json.load(fid))
linear_to_lora_layers(
model,
config.lora_layers,
config.lora_parameters,
getattr(config, "use_dora", False),
)
fine_tune_type = getattr(config, "fine_tune_type", "lora")
if fine_tune_type != "full":
linear_to_lora_layers(
model,
config.num_layers,
config.lora_parameters,
use_dora=(fine_tune_type == "dora"),
)
model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
return model

llms/mlx_lm/utils.py

@ -21,8 +21,8 @@ from transformers import PreTrainedTokenizer
from .models.base import KVCache, RotatingKVCache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters
# Constants
MODEL_REMAPPING = {
@ -515,7 +515,7 @@ def load(
model = load_model(model_path, lazy, model_config)
if adapter_path is not None:
model = apply_lora_layers(model, adapter_path)
model = load_adapters(model, adapter_path)
model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config)
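End to end, a hedged usage sketch: the same `adapter_path` argument now covers LoRA, DoRA, and full fine-tuning outputs, since `load_adapters` reads `fine_tune_type` from the saved config. The model name and generation call below are illustrative:

```python
# Sketch, assuming mlx-lm is installed and `adapters/` was produced by mlx_lm.lora.
from mlx_lm import generate, load

model, tokenizer = load(
    "mistralai/Mistral-7B-v0.1",
    adapter_path="adapters",  # contains adapter_config.json and adapters.safetensors
)
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))
```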