Adding full finetuning (#903)

* Adding full model weights finetuning

* Updating the LORA.md and ACKNOWLEDGMENTS.md files.

* removing --use-dora and --full-training and adding --fine-tune-type

* some clean up

* reformatting and fixing dora training

* updated CONFIG_DEFAULTS

* update config example

* update in the config example file

* Update LORA.md

* merge and commit

* adding argument for dora linear layer

* clean up

* clean up in the example yaml file

* fix

* final fix before sending

* small addition to the md file

* fix for loading the fully trained model by saving all the files and configs correctly

* clean up

* removing the unnecessary files

* changing lora layers back to 16

* removed max file size

* nits

* resolve merge

* some consistency changes

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Gökdeniz Gülmez 2024-09-30 02:12:47 +02:00 committed by GitHub
parent 7ec2021bb9
commit 50e5ca81a8
9 changed files with 79 additions and 70 deletions

ACKNOWLEDGMENTS.md

@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.

llms/README.md

@@ -16,7 +16,7 @@ conda install -c conda-forge mlx-lm
 The `mlx-lm` package also has:
-- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
+- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
 - [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
 - [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)

llms/mlx_lm/LORA.md

@@ -57,6 +57,9 @@ mlx_lm.lora \
     --iters 600
 ```
+To fine-tune the full model weights, add the `--fine-tune-type full` flag.
+Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
 The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
 when using `--train` and a path to a `test.jsonl` when using `--test`. For more
 details on the data format see the section on [Data](#Data).
@@ -67,8 +70,8 @@ mistralai/Mistral-7B-v0.1`.
 If `--model` points to a quantized model, then the training will use QLoRA,
 otherwise it will use regular LoRA.
-By default, the adapter config and weights are saved in `adapters/`. You can
-specify the output location with `--adapter-path`.
+By default, the adapter config and learned weights are saved in `adapters/`.
+You can specify the output location with `--adapter-path`.
 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file <path_to_adapters.safetensors>`.
@@ -118,7 +121,7 @@ mlx_lm.fuse --model <path_to_model>
 ```
 This will by default load the adapters from `adapters/`, and save the fused
-model in the path `lora_fused_model/`. All of these are configurable.
+model in the path `fused_model/`. All of these are configurable.
 To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
 to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
@@ -141,7 +144,7 @@ mlx_lm.fuse \
     --export-gguf
 ```
-This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
+This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
 can specify the file name with `--gguf-path`.
 ## Data
@@ -301,7 +304,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:
    setting this to `2` or `1` will reduce memory consumption. This may slow
    things down a little, but will also reduce the memory use.
-3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
+3. Reduce the number of layers to fine-tune with `--num-layers`. The default
    is `16`, so you can try `8` or `4`. This reduces the amount of memory
    needed for back propagation. It may also reduce the quality of the
    fine-tuned model if you are fine-tuning with a lot of data.
@@ -323,7 +326,7 @@ mlx_lm.lora \
     --model mistralai/Mistral-7B-v0.1 \
     --train \
     --batch-size 1 \
-    --lora-layers 4 \
+    --num-layers 4 \
     --data wikisql
 ```
@@ -333,4 +336,5 @@ tokens-per-second, using the MLX Example
 data set.
 [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
 [^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
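
As a quick illustration of the `{train, valid, test}.jsonl` layout that the `--data` directory is expected to contain, here is a hedged sketch that writes a couple of `text`-style records. The record contents are made up, and the other accepted formats are described in the Data section of LORA.md.

```python
import json

# Two made-up records in the simple {"text": ...} format; chat and
# prompt/completion formats are also accepted (see the Data section).
records = [
    {"text": "Q: What is MLX?\nA: An array framework for Apple silicon."},
    {"text": "Q: What does --fine-tune-type full do?\nA: It trains all model weights."},
]

with open("train.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")
```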

llms/mlx_lm/examples/lora_config.yaml

@@ -1,8 +1,12 @@
 # The path to the local model directory or Hugging Face repo.
 model: "mlx_model"
 # Whether or not to train (boolean)
 train: true
+# The fine-tuning method: "lora", "dora", or "full".
+fine_tune_type: lora
 # Directory with {train, valid, test}.jsonl files
 data: "/path/to/training/data"
@@ -51,9 +55,6 @@ max_seq_length: 2048
 # Use gradient checkpointing to reduce memory use.
 grad_checkpoint: false
-# Use DoRA instead of LoRA.
-use_dora: false
 # LoRA parameters can only be specified in a config file
 lora_parameters:
   # The layer keys to apply LoRA to.
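
For context, here is a minimal sketch of how a YAML config like the one above could be read and back-filled with defaults, mirroring the (abbreviated) `CONFIG_DEFAULTS` from the `lora.py` diff further down. The merge logic and the `lora_config.yaml` path are assumptions for illustration; the real `mlx_lm.lora` behavior may differ in detail.

```python
import yaml

# Abbreviated defaults, mirroring CONFIG_DEFAULTS in the lora.py diff below.
CONFIG_DEFAULTS = {
    "model": "mlx_model",
    "train": False,
    "fine_tune_type": "lora",
    "num_layers": 16,
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}


def load_config(path: str) -> dict:
    """Read a YAML config and fill in any keys it omits from the defaults."""
    with open(path, "r") as f:
        config = yaml.safe_load(f) or {}
    for key, value in CONFIG_DEFAULTS.items():
        config.setdefault(key, value)
    return config


config = load_config("lora_config.yaml")  # placeholder path
print(config["fine_tune_type"])  # "lora", "dora", or "full"
```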

llms/mlx_lm/fuse.py

@@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten
 from .gguf import convert_to_gguf
 from .tuner.dora import DoRAEmbedding, DoRALinear
 from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
-from .tuner.utils import apply_lora_layers, dequantize
+from .tuner.utils import dequantize, load_adapters
 from .utils import (
     fetch_from_hub,
     get_model_path,
@@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
     )
     parser.add_argument(
         "--save-path",
-        default="lora_fused_model",
+        default="fused_model",
         help="The path to save the fused model.",
     )
     parser.add_argument(
@@ -77,17 +77,14 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
     model.freeze()
-    model = apply_lora_layers(model, args.adapter_path)
+    model = load_adapters(model, args.adapter_path)
     fused_linears = [
-        (n, m.fuse())
-        for n, m in model.named_modules()
-        if isinstance(
-            m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
-        )
+        (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
     ]
-    model.update_modules(tree_unflatten(fused_linears))
+    if fused_linears:
+        model.update_modules(tree_unflatten(fused_linears))
     if args.de_quantize:
         print("De-quantizing model")

llms/mlx_lm/lora.py

@@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import (
-    apply_lora_layers,
     build_schedule,
     linear_to_lora_layers,
+    load_adapters,
     print_trainable_parameters,
 )
 from .utils import load, save_config
@@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver(
 CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
+    "fine_tune_type": "lora",
     "data": "data/",
     "seed": 0,
-    "lora_layers": 16,
+    "num_layers": 16,
     "batch_size": 4,
     "iters": 1000,
     "val_batches": 25,
@@ -58,7 +59,6 @@ CONFIG_DEFAULTS = {
     "max_seq_length": 2048,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
-    "use_dora": False,
 }
@@ -82,7 +82,14 @@ def build_parser():
         help="Directory with {train, valid, test}.jsonl files",
     )
     parser.add_argument(
-        "--lora-layers",
+        "--fine-tune-type",
+        type=str,
+        choices=["lora", "dora", "full"],
+        default="lora",
+        help="Type of fine-tuning to perform: lora, dora, or full.",
+    )
+    parser.add_argument(
+        "--num-layers",
         type=int,
         help="Number of layers to fine-tune. Default is 16, use -1 for all.",
     )
@@ -107,12 +114,12 @@ def build_parser():
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        help="Load path to resume training with the given adapters.",
+        help="Load path to resume training from the given fine-tuned weights.",
     )
     parser.add_argument(
         "--adapter-path",
         type=str,
-        help="Save/load path for the adapters.",
+        help="Save/load path for the fine-tuned weights.",
     )
     parser.add_argument(
         "--save-every",
@@ -148,9 +155,6 @@ def build_parser():
         default=None,
     )
     parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
-    parser.add_argument(
-        "--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
-    )
     return parser
@@ -162,21 +166,31 @@ def train_model(
     valid_set,
     training_callback: TrainingCallback = None,
 ):
+    # Freeze all layers
     model.freeze()
+    if args.fine_tune_type == "full":
+        for l in model.layers[-min(args.num_layers, 0) :]:
+            l.unfreeze()
+    elif args.fine_tune_type in ["lora", "dora"]:
+        # Convert linear layers to lora/dora layers and unfreeze in the process
+        linear_to_lora_layers(
+            model,
+            args.num_layers,
+            args.lora_parameters,
+            use_dora=(args.fine_tune_type == "dora"),
+        )
+    else:
+        raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
-    # Convert linear layers to lora layers and unfreeze in the process
-    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora)
-    # Resume training the given adapters.
+    # Resume from weights if provided
     if args.resume_adapter_file is not None:
-        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+        print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
     print_trainable_parameters(model)
     adapter_path = Path(args.adapter_path)
     adapter_path.mkdir(parents=True, exist_ok=True)
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
@@ -240,7 +254,7 @@ def run(args, training_callback: TrainingCallback = None):
     if args.test and not args.train:
         # Allow testing without LoRA layers by providing empty path
         if args.adapter_path != "":
-            apply_lora_layers(model, args.adapter_path)
+            load_adapters(model, args.adapter_path)
     elif args.train:
         print("Training")

llms/mlx_lm/tuner/trainer.py

@@ -1,5 +1,7 @@
 # Copyright © 2024 Apple Inc.
+import glob
+import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -285,24 +287,18 @@
         # Save adapter weights
         if it % args.steps_per_save == 0:
-            save_adapter(model, args.adapter_file)
+            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+            mx.save_safetensors(str(args.adapter_file), adapter_weights)
             checkpoint = (
                 Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
             )
-            save_adapter(model, checkpoint)
+            mx.save_safetensors(str(checkpoint), adapter_weights)
             print(
                 f"Iter {it}: Saved adapter weights to "
                 f"{args.adapter_file} and {checkpoint}."
             )
 
-    # save final adapter weights
-    save_adapter(model, args.adapter_file)
-    print(f"Saved final adapter weights to {args.adapter_file}.")
-
-
-def save_adapter(
-    model: nn.Module,
-    adapter_file: Union[str, Path],
-):
-    flattened_tree = tree_flatten(model.trainable_parameters())
-    mx.save_safetensors(str(adapter_file), dict(flattened_tree))
+    # Save final weights
+    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+    mx.save_safetensors(str(args.adapter_file), adapter_weights)
+    print(f"Saved final weights to {args.adapter_file}.")
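
With `save_adapter` removed, checkpointing reduces to "flatten whatever is currently trainable and write it to a safetensors file", which works the same whether the trainable set is a handful of LoRA matrices or fully unfrozen blocks. A minimal sketch of that flatten-and-save step (toy model, placeholder file name):

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
model.freeze()
model.layers[-1].unfreeze()  # pretend only the last block is being fine-tuned

# Only trainable parameters land in the checkpoint, so the same call covers
# LoRA, DoRA, and full fine-tuning.
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors("adapters.safetensors", adapter_weights)  # placeholder path

print(list(adapter_weights))  # ['layers.1.weight', 'layers.1.bias']
```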

llms/mlx_lm/tuner/utils.py

@@ -36,7 +36,7 @@ def build_schedule(schedule_config: Dict):
 def linear_to_lora_layers(
     model: nn.Module,
-    num_lora_layers: int,
+    num_layers: int,
     config: Dict,
     use_dora: bool = False,
 ):
@@ -45,22 +45,17 @@ def linear_to_lora_layers(
     Args:
         model (nn.Module): The neural network model.
-        num_lora_layers (int): The number of blocks to convert to lora layers
+        num_layers (int): The number of blocks to convert to lora layers
             starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
             rank, scale, and optional layer keys.
         use_dora (bool): If True, uses DoRA instead of LoRA.
             Default: ``False``
     """
-    num_layers = len(model.layers)
-    if num_lora_layers < 0:
-        num_lora_layers = num_layers
-    if num_lora_layers > num_layers:
+    if num_layers > len(model.layers):
         raise ValueError(
-            f"Requested {num_lora_layers} LoRA layers "
-            f"but the model only has {num_layers} layers."
+            f"Requested {num_layers} LoRA layers "
+            f"but the model only has {len(model.layers)} layers."
         )
 
     def to_lora(layer):
@@ -151,7 +146,7 @@ def linear_to_lora_layers(
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
-    for l in model.layers[num_layers - num_lora_layers :]:
+    for l in model.layers[-min(num_layers, 0) :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         if lora_layers:
             l.update_modules(tree_unflatten(lora_layers))
@@ -161,9 +156,9 @@ def linear_to_lora_layers(
     model.update_modules(tree_unflatten(lora_modules))
 
-def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
+def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
     """
-    Apply LoRA layers to the model.
+    Load any fine-tuned adapters / layers.
     Args:
         model (nn.Module): The neural network model.
@@ -177,12 +172,14 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
         raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
     with open(adapter_path / "adapter_config.json", "r") as fid:
         config = types.SimpleNamespace(**json.load(fid))
-    linear_to_lora_layers(
-        model,
-        config.lora_layers,
-        config.lora_parameters,
-        getattr(config, "use_dora", False),
-    )
+    fine_tune_type = getattr(config, "fine_tune_type", "lora")
+    if fine_tune_type != "full":
+        linear_to_lora_layers(
+            model,
+            config.num_layers,
+            config.lora_parameters,
+            use_dora=(fine_tune_type == "dora"),
+        )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
     return model
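
For reference, here is a hedged usage sketch of the renamed `linear_to_lora_layers` signature. The model id is a placeholder and the `config` values simply mirror the `lora_parameters` defaults from the `lora.py` diff above, so treat it as illustrative rather than canonical.

```python
from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers

# Placeholder model id; any model supported by mlx-lm would do.
model, tokenizer = load("mistralai/Mistral-7B-v0.1")

model.freeze()
# Convert the last 8 transformer blocks to LoRA layers (use_dora=True for DoRA).
linear_to_lora_layers(
    model,
    num_layers=8,
    config={"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
    use_dora=False,
)
```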

llms/mlx_lm/utils.py

@@ -21,8 +21,8 @@ from transformers import PreTrainedTokenizer
 from .models.base import KVCache, RotatingKVCache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
-from .tuner.utils import apply_lora_layers
 from .tuner.utils import dequantize as dequantize_model
+from .tuner.utils import load_adapters
 
 # Constants
 MODEL_REMAPPING = {
@@ -515,7 +515,7 @@ def load(
     model = load_model(model_path, lazy, model_config)
     if adapter_path is not None:
-        model = apply_lora_layers(model, adapter_path)
+        model = load_adapters(model, adapter_path)
         model.eval()
     tokenizer = load_tokenizer(model_path, tokenizer_config)
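
Since `load` now routes through `load_adapters`, the same `adapter_path` argument transparently restores LoRA, DoRA, or fully fine-tuned weights saved by `mlx_lm.lora`. A hedged usage sketch, where both the base model id and the adapter directory are placeholders:

```python
from mlx_lm import generate, load

# adapter_path is whatever --adapter-path pointed at during training; its
# adapter_config.json tells load_adapters which fine-tuning type was used.
model, tokenizer = load(
    "mistralai/Mistral-7B-v0.1",  # placeholder base model
    adapter_path="adapters",  # placeholder fine-tuned weights directory
)

text = generate(model, tokenizer, prompt="Q: What is LoRA?\nA:", max_tokens=64)
print(text)
```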