Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Adding full finetuning (#903)
* Adding full model weights fine-tuning
* Updating the LORA.md and ACKNOWLEDGMENTS.md files
* Removing --use-dora and --full-training and adding --fine-tune-type
* Some clean up
* Reformatting and fixing DoRA training
* Updated CONFIG_DEFAULTS
* Update config example
* Update in the config example file
* Update LORA.md
* Merge and commit
* Adding argument for DoRA linear layer
* Clean up
* Clean up in the example yaml file
* Fix
* Final fix before sending
* Small addition to the md file
* Fix for loading the fully trained model by saving all the files and configs correctly
* Clean up
* Removing the unnecessary files
* Changing lora layers back to 16
* Removed max file size
* Nits
* Resolve merge
* Some consistency changes

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent 7ec2021bb9
commit 50e5ca81a8
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
@@ -16,7 +16,7 @@ conda install -c conda-forge mlx-lm
 
 The `mlx-lm` package also has:
 
-- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
+- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
 - [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
 - [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
 
@@ -57,6 +57,9 @@ mlx_lm.lora \
     --iters 600
 ```
 
+To fine-tune the full model weights, add the `--fine-tune-type full` flag.
+Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
+
 The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
 when using `--train` and a path to a `test.jsonl` when using `--test`. For more
 details on the data format see the section on [Data](#Data).
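A minimal sketch (illustrative, not from the commit) of driving the new flag from Python via `subprocess`; the model repo and data directory are placeholders, and the flags used are the ones shown in this diff:

```python
import subprocess

# Illustrative sketch: launch mlx_lm.lora with the new --fine-tune-type flag.
# The model repo and data directory below are placeholders.
subprocess.run(
    [
        "mlx_lm.lora",
        "--model", "mistralai/Mistral-7B-v0.1",  # local path or Hugging Face repo
        "--train",
        "--data", "/path/to/training/data",      # directory with train/valid .jsonl files
        "--fine-tune-type", "full",              # "lora" (default), "dora", or "full"
        "--iters", "600",
    ],
    check=True,
)
```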
@@ -67,8 +70,8 @@ mistralai/Mistral-7B-v0.1`.
 If `--model` points to a quantized model, then the training will use QLoRA,
 otherwise it will use regular LoRA.
 
-By default, the adapter config and weights are saved in `adapters/`. You can
-specify the output location with `--adapter-path`.
+By default, the adapter config and learned weights are saved in `adapters/`.
+You can specify the output location with `--adapter-path`.
 
 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file <path_to_adapters.safetensors>`.
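Once training has written `adapters/`, the weights can be attached at load time through the `mlx_lm` Python API; a sketch assuming the `load` helper shown further down in `utils.py`, with a placeholder model repo:

```python
from mlx_lm import load

# Attach the fine-tuned weights saved under adapters/ (or whatever --adapter-path was).
# The base model below is a placeholder.
model, tokenizer = load(
    "mistralai/Mistral-7B-v0.1",
    adapter_path="adapters",
)
```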
@@ -118,7 +121,7 @@ mlx_lm.fuse --model <path_to_model>
 ```
 
 This will by default load the adapters from `adapters/`, and save the fused
-model in the path `lora_fused_model/`. All of these are configurable.
+model in the path `fused_model/`. All of these are configurable.
 
 To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
 to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
@@ -141,7 +144,7 @@ mlx_lm.fuse \
     --export-gguf
 ```
 
-This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
+This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
 can specify the file name with `--gguf-path`.
 
 ## Data
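A sketch of invoking the fuse step with the renamed default made explicit; `--save-path` comes from `fuse.py` below and the base model is a placeholder:

```python
import subprocess

# Fuse the trained adapters from adapters/ back into the base weights.
# The output directory now defaults to fused_model/ (previously lora_fused_model/).
subprocess.run(
    [
        "mlx_lm.fuse",
        "--model", "mistralai/Mistral-7B-v0.1",  # placeholder base model
        "--save-path", "fused_model",            # explicit, matches the new default
    ],
    check=True,
)
```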
@@ -301,7 +304,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:
    setting this to `2` or `1` will reduce memory consumption. This may slow
    things down a little, but will also reduce the memory use.
 
-3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
+3. Reduce the number of layers to fine-tune with `--num-layers`. The default
    is `16`, so you can try `8` or `4`. This reduces the amount of memory
    needed for back propagation. It may also reduce the quality of the
    fine-tuned model if you are fine-tuning with a lot of data.
@@ -323,7 +326,7 @@ mlx_lm.lora \
     --model mistralai/Mistral-7B-v0.1 \
     --train \
     --batch-size 1 \
-    --lora-layers 4 \
+    --num-layers 4 \
     --data wikisql
 ```
 
@@ -333,4 +336,5 @@ tokens-per-second, using the MLX Example
 data set.
 
 [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
+
 [^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
@@ -1,8 +1,12 @@
 # The path to the local model directory or Hugging Face repo.
 model: "mlx_model"
 
 # Whether or not to train (boolean)
 train: true
 
+# The fine-tuning method: "lora", "dora", or "full".
+fine_tune_type: lora
+
 # Directory with {train, valid, test}.jsonl files
 data: "/path/to/training/data"
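A rough sketch of how a config file like this combines with the `CONFIG_DEFAULTS` shown further down in `lora.py`; the merge below is illustrative and not necessarily the exact precedence the CLI applies:

```python
import yaml

# Subset of the CONFIG_DEFAULTS from lora.py that this commit touches.
CONFIG_DEFAULTS = {
    "model": "mlx_model",
    "train": False,
    "fine_tune_type": "lora",  # new in this commit
    "num_layers": 16,          # renamed from lora_layers
    "batch_size": 4,
    "iters": 1000,
}


def load_config(path: str) -> dict:
    """Read a YAML config and fill in any missing keys from the defaults."""
    with open(path) as f:
        config = yaml.safe_load(f) or {}
    return {**CONFIG_DEFAULTS, **config}


# Values present in the YAML file override the defaults above, e.g.:
# load_config("lora_config.yaml")["fine_tune_type"]  ->  "lora"
```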
@@ -51,9 +55,6 @@ max_seq_length: 2048
 # Use gradient checkpointing to reduce memory use.
 grad_checkpoint: false
 
-# Use DoRA instead of LoRA.
-use_dora: false
-
 # LoRA parameters can only be specified in a config file
 lora_parameters:
   # The layer keys to apply LoRA to.
@@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten
 from .gguf import convert_to_gguf
 from .tuner.dora import DoRAEmbedding, DoRALinear
 from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
-from .tuner.utils import apply_lora_layers, dequantize
+from .tuner.utils import dequantize, load_adapters
 from .utils import (
     fetch_from_hub,
     get_model_path,
@@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
     )
     parser.add_argument(
         "--save-path",
-        default="lora_fused_model",
+        default="fused_model",
         help="The path to save the fused model.",
     )
     parser.add_argument(
@@ -77,17 +77,14 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
 
     model.freeze()
-    model = apply_lora_layers(model, args.adapter_path)
+    model = load_adapters(model, args.adapter_path)
 
     fused_linears = [
-        (n, m.fuse())
-        for n, m in model.named_modules()
-        if isinstance(
-            m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
-        )
+        (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
     ]
 
-    model.update_modules(tree_unflatten(fused_linears))
+    if fused_linears:
+        model.update_modules(tree_unflatten(fused_linears))
 
     if args.de_quantize:
         print("De-quantizing model")
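The switch from an explicit `isinstance` check to `hasattr(m, "fuse")` means any adapter module that exposes a `fuse()` method is folded in, regardless of its class. A toy sketch of the pattern; `TinyAdapter` is hypothetical and only illustrates the duck typing:

```python
import mlx.core as mx
import mlx.nn as nn


class TinyAdapter(nn.Module):
    """Hypothetical adapter: wraps a linear layer and knows how to fuse itself."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        self.delta = mx.zeros(linear.weight.shape)  # learned update, zero here

    def __call__(self, x):
        return self.linear(x) + x @ self.delta.T

    def fuse(self) -> nn.Linear:
        # Fold the learned delta back into the wrapped weights.
        out_dims, in_dims = self.linear.weight.shape
        has_bias = "bias" in self.linear
        fused = nn.Linear(in_dims, out_dims, bias=has_bias)
        fused.weight = self.linear.weight + self.delta
        if has_bias:
            fused.bias = self.linear.bias
        return fused


# Anything with a .fuse() method is picked up by the comprehension in fuse.py:
layer = TinyAdapter(nn.Linear(4, 4))
print(hasattr(layer, "fuse"))  # True
```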
@@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import (
-    apply_lora_layers,
     build_schedule,
     linear_to_lora_layers,
+    load_adapters,
     print_trainable_parameters,
 )
 from .utils import load, save_config
@@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver(
 CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
+    "fine_tune_type": "lora",
     "data": "data/",
     "seed": 0,
-    "lora_layers": 16,
+    "num_layers": 16,
     "batch_size": 4,
     "iters": 1000,
     "val_batches": 25,
@@ -58,7 +59,6 @@ CONFIG_DEFAULTS = {
     "max_seq_length": 2048,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
-    "use_dora": False,
 }
 
 
@@ -82,7 +82,14 @@ def build_parser():
         help="Directory with {train, valid, test}.jsonl files",
     )
     parser.add_argument(
-        "--lora-layers",
+        "--fine-tune-type",
+        type=str,
+        choices=["lora", "dora", "full"],
+        default="lora",
+        help="Type of fine-tuning to perform: lora, dora, or full.",
+    )
+    parser.add_argument(
+        "--num-layers",
         type=int,
         help="Number of layers to fine-tune. Default is 16, use -1 for all.",
     )
@@ -107,12 +114,12 @@ def build_parser():
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        help="Load path to resume training with the given adapters.",
+        help="Load path to resume training from the given fine-tuned weights.",
     )
     parser.add_argument(
         "--adapter-path",
         type=str,
-        help="Save/load path for the adapters.",
+        help="Save/load path for the fine-tuned weights.",
     )
     parser.add_argument(
         "--save-every",
@@ -148,9 +155,6 @@ def build_parser():
         default=None,
     )
     parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
-    parser.add_argument(
-        "--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
-    )
     return parser
 
 
@@ -162,21 +166,31 @@ def train_model(
     valid_set,
     training_callback: TrainingCallback = None,
 ):
     # Freeze all layers
     model.freeze()
+    if args.fine_tune_type == "full":
+        for l in model.layers[-min(args.num_layers, 0) :]:
+            l.unfreeze()
+    elif args.fine_tune_type in ["lora", "dora"]:
+        # Convert linear layers to lora/dora layers and unfreeze in the process
+        linear_to_lora_layers(
+            model,
+            args.num_layers,
+            args.lora_parameters,
+            use_dora=(args.fine_tune_type == "dora"),
+        )
+    else:
+        raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
 
-    # Convert linear layers to lora layers and unfreeze in the process
-    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora)
-
-    # Resume training the given adapters.
+    # Resume from weights if provided
     if args.resume_adapter_file is not None:
-        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+        print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
     print_trainable_parameters(model)
 
     adapter_path = Path(args.adapter_path)
     adapter_path.mkdir(parents=True, exist_ok=True)
 
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
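Since `save_config(vars(args), ...)` dumps the parsed arguments, `adapter_config.json` now records the fine-tuning type and layer count that `load_adapters` later needs. A sketch of roughly what such a file might contain; the values are placeholders taken from the defaults in this diff:

```python
import json

# Illustrative only: approximate shape of adapter_config.json after this commit.
# Exact keys and values depend on the run; these come from CONFIG_DEFAULTS.
example_config = {
    "fine_tune_type": "lora",
    "num_layers": 16,
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
    "batch_size": 4,
    "iters": 1000,
}
print(json.dumps(example_config, indent=2))
```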
@@ -240,7 +254,7 @@ def run(args, training_callback: TrainingCallback = None):
     if args.test and not args.train:
         # Allow testing without LoRA layers by providing empty path
        if args.adapter_path != "":
-            apply_lora_layers(model, args.adapter_path)
+            load_adapters(model, args.adapter_path)
 
     elif args.train:
         print("Training")
@@ -1,5 +1,7 @@
 # Copyright © 2024 Apple Inc.
 
+import glob
+import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -285,24 +287,18 @@ def train(
 
         # Save adapter weights
         if it % args.steps_per_save == 0:
-            save_adapter(model, args.adapter_file)
+            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+            mx.save_safetensors(str(args.adapter_file), adapter_weights)
             checkpoint = (
                 Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
             )
-            save_adapter(model, checkpoint)
+            mx.save_safetensors(str(checkpoint), adapter_weights)
             print(
                 f"Iter {it}: Saved adapter weights to "
                 f"{args.adapter_file} and {checkpoint}."
             )
 
-    # save final adapter weights
-    save_adapter(model, args.adapter_file)
-    print(f"Saved final adapter weights to {args.adapter_file}.")
-
-
-def save_adapter(
-    model: nn.Module,
-    adapter_file: Union[str, Path],
-):
-    flattened_tree = tree_flatten(model.trainable_parameters())
-    mx.save_safetensors(str(adapter_file), dict(flattened_tree))
+    # Save final weights
+    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+    mx.save_safetensors(str(args.adapter_file), adapter_weights)
+    print(f"Saved final weights to {args.adapter_file}.")
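With `save_adapter` removed, checkpoints are plain safetensors dumps of the trainable parameters, and the periodic checkpoint name embeds the iteration number. A small sketch of the naming, with a placeholder adapter path:

```python
from pathlib import Path

# Periodic checkpoints sit next to the main adapter file and embed the iteration.
adapter_file = Path("adapters") / "adapters.safetensors"  # placeholder --adapter-path
it = 600
checkpoint = adapter_file.parent / f"{it:07d}_adapters.safetensors"
print(checkpoint)  # adapters/0000600_adapters.safetensors
```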
@@ -36,7 +36,7 @@ def build_schedule(schedule_config: Dict):
 
 def linear_to_lora_layers(
     model: nn.Module,
-    num_lora_layers: int,
+    num_layers: int,
     config: Dict,
     use_dora: bool = False,
 ):
@@ -45,22 +45,17 @@
 
     Args:
         model (nn.Module): The neural network model.
-        num_lora_layers (int): The number of blocks to convert to lora layers
+        num_layers (int): The number of blocks to convert to lora layers
             starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
             rank, scale, and optional layer keys.
         use_dora (bool): If True, uses DoRA instead of LoRA.
             Default: ``False``
     """
-    num_layers = len(model.layers)
-
-    if num_lora_layers < 0:
-        num_lora_layers = num_layers
-
-    if num_lora_layers > num_layers:
+    if num_layers > len(model.layers):
         raise ValueError(
-            f"Requested {num_lora_layers} LoRA layers "
-            f"but the model only has {num_layers} layers."
+            f"Requested {num_layers} LoRA layers "
+            f"but the model only has {len(model.layers)} layers."
         )
 
     def to_lora(layer):
@@ -151,7 +146,7 @@ def linear_to_lora_layers(
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
 
-    for l in model.layers[num_layers - num_lora_layers :]:
+    for l in model.layers[-min(num_layers, 0) :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         if lora_layers:
             l.update_modules(tree_unflatten(lora_layers))
@@ -161,9 +156,9 @@
         model.update_modules(tree_unflatten(lora_modules))
 
 
-def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
+def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
     """
-    Apply LoRA layers to the model.
+    Load any fine-tuned adapters / layers.
 
     Args:
         model (nn.Module): The neural network model.
@@ -177,12 +172,14 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
         raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
     with open(adapter_path / "adapter_config.json", "r") as fid:
         config = types.SimpleNamespace(**json.load(fid))
-    linear_to_lora_layers(
-        model,
-        config.lora_layers,
-        config.lora_parameters,
-        getattr(config, "use_dora", False),
-    )
+    fine_tune_type = getattr(config, "fine_tune_type", "lora")
+    if fine_tune_type != "full":
+        linear_to_lora_layers(
+            model,
+            config.num_layers,
+            config.lora_parameters,
+            use_dora=(fine_tune_type == "dora"),
+        )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
     return model
 
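The `getattr(config, "fine_tune_type", "lora")` fallback means a config written without the new key is simply treated as plain LoRA. A tiny sketch of that behaviour (the config contents are made up):

```python
import json
import types

# A config without fine_tune_type falls back to "lora"; newer configs carry the key.
old_config = types.SimpleNamespace(**json.loads('{"lora_parameters": {"rank": 8}}'))
new_config = types.SimpleNamespace(**json.loads('{"fine_tune_type": "full"}'))

print(getattr(old_config, "fine_tune_type", "lora"))  # lora
print(getattr(new_config, "fine_tune_type", "lora"))  # full
```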
@@ -21,8 +21,8 @@ from transformers import PreTrainedTokenizer
 from .models.base import KVCache, RotatingKVCache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
-from .tuner.utils import apply_lora_layers
 from .tuner.utils import dequantize as dequantize_model
+from .tuner.utils import load_adapters
 
 # Constants
 MODEL_REMAPPING = {
@@ -515,7 +515,7 @@ def load(
 
     model = load_model(model_path, lazy, model_config)
     if adapter_path is not None:
-        model = apply_lora_layers(model, adapter_path)
+        model = load_adapters(model, adapter_path)
         model.eval()
     tokenizer = load_tokenizer(model_path, tokenizer_config)
 