diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 2037a076..41557c29 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba`, and full fine-tuning.
\ No newline at end of file
diff --git a/llms/README.md b/llms/README.md
index b8e1914d..75677865 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -16,7 +16,7 @@ conda install -c conda-forge mlx-lm
 
 The `mlx-lm` package also has:
 
-- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
+- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
 - [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
 - [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
 
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 8aec89ec..80c25b4b 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -57,6 +57,9 @@ mlx_lm.lora \
     --iters 600
 ```
 
+To fine-tune the full model weights, add the `--fine-tune-type full` flag.
+Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
+
 The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
 when using `--train` and a path to a `test.jsonl` when using `--test`. For more
 details on the data format see the section on [Data](#Data).
@@ -67,8 +70,8 @@ mistralai/Mistral-7B-v0.1`.
 If `--model` points to a quantized model, then the training will use QLoRA,
 otherwise it will use regular LoRA.
 
-By default, the adapter config and weights are saved in `adapters/`. You can
-specify the output location with `--adapter-path`.
+By default, the adapter config and learned weights are saved in `adapters/`.
+You can specify the output location with `--adapter-path`.
 
 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file <path_to_adapters.safetensors>`.
@@ -118,7 +121,7 @@ mlx_lm.fuse --model <path_to_model>
 ```
 
 This will by default load the adapters from `adapters/`, and save the fused
-model in the path `lora_fused_model/`. All of these are configurable.
+model in the path `fused_model/`. All of these are configurable.
 
 To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
 to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
@@ -141,7 +144,7 @@ mlx_lm.fuse \
     --export-gguf
 ```
 
-This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
+This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
 can specify the file name with `--gguf-path`.
 
 ## Data
@@ -301,7 +304,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:
    setting this to `2` or `1` will reduce memory consumption. This may slow
    things down a little, but will also reduce the memory use.
 
-3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
+3. Reduce the number of layers to fine-tune with `--num-layers`. The default
    is `16`, so you can try `8` or `4`. This reduces the amount of memory needed
    for back propagation. It may also reduce the quality of the fine-tuned model
    if you are fine-tuning with a lot of data.
@@ -323,7 +326,7 @@ mlx_lm.lora \
     --model mistralai/Mistral-7B-v0.1 \
     --train \
     --batch-size 1 \
-    --lora-layers 4 \
+    --num-layers 4 \
     --data wikisql
 ```
 
@@ -333,4 +336,5 @@ tokens-per-second, using the MLX Example data set.
 
 [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more
 details on LoRA.
+
 [^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml
index 073a5b6f..4ec9a23c 100644
--- a/llms/mlx_lm/examples/lora_config.yaml
+++ b/llms/mlx_lm/examples/lora_config.yaml
@@ -1,8 +1,12 @@
 # The path to the local model directory or Hugging Face repo.
 model: "mlx_model"
+
 # Whether or not to train (boolean)
 train: true
 
+# The fine-tuning method: "lora", "dora", or "full".
+fine_tune_type: lora
+
 # Directory with {train, valid, test}.jsonl files
 data: "/path/to/training/data"
 
@@ -51,9 +55,6 @@ max_seq_length: 2048
 # Use gradient checkpointing to reduce memory use.
 grad_checkpoint: false
 
-# Use DoRA instead of LoRA.
-use_dora: false
-
 # LoRA parameters can only be specified in a config file
 lora_parameters:
   # The layer keys to apply LoRA to.
diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
index 16457036..b0c46a74 100644
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten
 from .gguf import convert_to_gguf
 from .tuner.dora import DoRAEmbedding, DoRALinear
 from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
-from .tuner.utils import apply_lora_layers, dequantize
+from .tuner.utils import dequantize, load_adapters
 from .utils import (
     fetch_from_hub,
     get_model_path,
@@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
     )
     parser.add_argument(
         "--save-path",
-        default="lora_fused_model",
+        default="fused_model",
         help="The path to save the fused model.",
     )
     parser.add_argument(
@@ -77,17 +77,14 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
 
     model.freeze()
-    model = apply_lora_layers(model, args.adapter_path)
+    model = load_adapters(model, args.adapter_path)
 
     fused_linears = [
-        (n, m.fuse())
-        for n, m in model.named_modules()
-        if isinstance(
-            m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
-        )
+        (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
     ]
 
-    model.update_modules(tree_unflatten(fused_linears))
+    if fused_linears:
+        model.update_modules(tree_unflatten(fused_linears))
 
     if args.de_quantize:
         print("De-quantizing model")
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 580e3d3c..69232774 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import (
-    apply_lora_layers,
     build_schedule,
     linear_to_lora_layers,
+    load_adapters,
     print_trainable_parameters,
 )
 from .utils import load, save_config
@@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver(
 CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
+    "fine_tune_type": "lora",
     "data": "data/",
     "seed": 0,
-    "lora_layers": 16,
+    "num_layers": 16,
     "batch_size": 4,
     "iters": 1000,
     "val_batches": 25,
@@ -58,7 +59,6 @@ CONFIG_DEFAULTS = {
     "max_seq_length": 2048,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
-    "use_dora": False,
 }
 
 
@@ -82,7 +82,14 @@ def build_parser():
         help="Directory with {train, valid, test}.jsonl files",
     )
     parser.add_argument(
-        "--lora-layers",
+        "--fine-tune-type",
+        type=str,
+        choices=["lora", "dora", "full"],
+        default="lora",
+        help="Type of fine-tuning to perform: lora, dora, or full.",
+    )
+    parser.add_argument(
+        "--num-layers",
         type=int,
         help="Number of layers to fine-tune. Default is 16, use -1 for all.",
     )
@@ -107,12 +114,12 @@ def build_parser():
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        help="Load path to resume training with the given adapters.",
+        help="Load path to resume training from the given fine-tuned weights.",
     )
     parser.add_argument(
         "--adapter-path",
         type=str,
-        help="Save/load path for the adapters.",
+        help="Save/load path for the fine-tuned weights.",
     )
     parser.add_argument(
         "--save-every",
@@ -148,9 +155,6 @@ def build_parser():
         default=None,
     )
     parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
-    parser.add_argument(
-        "--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
-    )
     return parser
 
 
@@ -162,21 +166,31 @@ def train_model(
     valid_set,
     training_callback: TrainingCallback = None,
 ):
-    # Freeze all layers
     model.freeze()
+    if args.fine_tune_type == "full":
+        # Unfreeze only the last `num_layers` transformer blocks for full fine-tuning
+        for l in model.layers[-max(args.num_layers, 0) :]:
+            l.unfreeze()
+    elif args.fine_tune_type in ["lora", "dora"]:
+        # Convert linear layers to lora/dora layers and unfreeze in the process
+        linear_to_lora_layers(
+            model,
+            args.num_layers,
+            args.lora_parameters,
+            use_dora=(args.fine_tune_type == "dora"),
+        )
+    else:
+        raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
 
-    # Convert linear layers to lora layers and unfreeze in the process
-    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora)
-
-    # Resume training the given adapters.
+    # Resume from weights if provided
     if args.resume_adapter_file is not None:
-        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+        print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
     print_trainable_parameters(model)
 
     adapter_path = Path(args.adapter_path)
     adapter_path.mkdir(parents=True, exist_ok=True)
+
+    adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
 
@@ -240,7 +254,7 @@ def run(args, training_callback: TrainingCallback = None):
     if args.test and not args.train:
         # Allow testing without LoRA layers by providing empty path
         if args.adapter_path != "":
-            apply_lora_layers(model, args.adapter_path)
+            load_adapters(model, args.adapter_path)
 
     elif args.train:
         print("Training")
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index b15801a5..1d934a72 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -1,5 +1,7 @@
 # Copyright © 2024 Apple Inc.
 
+import glob
+import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -285,24 +287,18 @@ def train(
 
         # Save adapter weights
         if it % args.steps_per_save == 0:
-            save_adapter(model, args.adapter_file)
+            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+            mx.save_safetensors(str(args.adapter_file), adapter_weights)
             checkpoint = (
                 Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
             )
-            save_adapter(model, checkpoint)
+            mx.save_safetensors(str(checkpoint), adapter_weights)
             print(
                 f"Iter {it}: Saved adapter weights to "
                 f"{args.adapter_file} and {checkpoint}."
             )
 
-    # save final adapter weights
-    save_adapter(model, args.adapter_file)
-    print(f"Saved final adapter weights to {args.adapter_file}.")
-
-
-def save_adapter(
-    model: nn.Module,
-    adapter_file: Union[str, Path],
-):
-    flattened_tree = tree_flatten(model.trainable_parameters())
-    mx.save_safetensors(str(adapter_file), dict(flattened_tree))
+    # Save final weights
+    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+    mx.save_safetensors(str(args.adapter_file), adapter_weights)
+    print(f"Saved final weights to {args.adapter_file}.")
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index ab9d37aa..7c78ee91 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -36,7 +36,7 @@ def build_schedule(schedule_config: Dict):
 
 def linear_to_lora_layers(
     model: nn.Module,
-    num_lora_layers: int,
+    num_layers: int,
     config: Dict,
     use_dora: bool = False,
 ):
@@ -45,22 +45,17 @@ def linear_to_lora_layers(
 
     Args:
         model (nn.Module): The neural network model.
-        num_lora_layers (int): The number of blocks to convert to lora layers
+        num_layers (int): The number of blocks to convert to lora layers
             starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
             rank, scale, and optional layer keys.
         use_dora (bool): If True, uses DoRA instead of LoRA. Default: ``False``
     """
-    num_layers = len(model.layers)
-
-    if num_lora_layers < 0:
-        num_lora_layers = num_layers
-
-    if num_lora_layers > num_layers:
+    if num_layers > len(model.layers):
         raise ValueError(
-            f"Requested {num_lora_layers} LoRA layers "
-            f"but the model only has {num_layers} layers."
+            f"Requested {num_layers} LoRA layers "
+            f"but the model only has {len(model.layers)} layers."
         )
 
     def to_lora(layer):
@@ -151,7 +146,7 @@ def linear_to_lora_layers(
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
 
-    for l in model.layers[num_layers - num_lora_layers :]:
+    for l in model.layers[-max(num_layers, 0) :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         if lora_layers:
             l.update_modules(tree_unflatten(lora_layers))
@@ -161,9 +156,9 @@ def linear_to_lora_layers(
         model.update_modules(tree_unflatten(lora_modules))
 
 
-def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
+def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
     """
-    Apply LoRA layers to the model.
+    Load any fine-tuned adapters / layers.
 
     Args:
         model (nn.Module): The neural network model.
@@ -177,12 +172,14 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
         raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
     with open(adapter_path / "adapter_config.json", "r") as fid:
         config = types.SimpleNamespace(**json.load(fid))
-    linear_to_lora_layers(
-        model,
-        config.lora_layers,
-        config.lora_parameters,
-        getattr(config, "use_dora", False),
-    )
+    fine_tune_type = getattr(config, "fine_tune_type", "lora")
+    if fine_tune_type != "full":
+        linear_to_lora_layers(
+            model,
+            config.num_layers,
+            config.lora_parameters,
+            use_dora=(fine_tune_type == "dora"),
+        )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
     return model
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 16271c3e..9411138d 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -21,8 +21,8 @@ from transformers import PreTrainedTokenizer
 from .models.base import KVCache, RotatingKVCache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
-from .tuner.utils import apply_lora_layers
 from .tuner.utils import dequantize as dequantize_model
+from .tuner.utils import load_adapters
 
 # Constants
 MODEL_REMAPPING = {
@@ -515,7 +515,7 @@ def load(
     model = load_model(model_path, lazy, model_config)
 
     if adapter_path is not None:
-        model = load_adapters(model, adapter_path)
+        model = load_adapters(model, adapter_path)
         model.eval()
 
     tokenizer = load_tokenizer(model_path, tokenizer_config)
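
For reference, here is a small illustration of the layer-selection slice used by `train_model` and `linear_to_lora_layers` above: `model.layers[-max(num_layers, 0):]` keeps only the last `num_layers` transformer blocks, while a negative value such as `-1` keeps all of them, matching the `--num-layers` help text ("Default is 16, use -1 for all"). This is only a sketch of the slicing arithmetic: the plain Python list standing in for `model.layers`, the `layers_to_tune` helper, and the 32-block count are hypothetical and not part of the patch.

```python
# Minimal sketch of the slice that picks which blocks get fine-tuned.
# A plain list stands in for model.layers; layers_to_tune is a
# hypothetical helper used only for this illustration.

layers = list(range(32))  # pretend the model has 32 transformer blocks


def layers_to_tune(layers, num_layers):
    # num_layers >= 0: keep only the last `num_layers` blocks.
    # num_layers < 0 (e.g. -1): max(num_layers, 0) == 0, so keep every block.
    return layers[-max(num_layers, 0) :]


print(len(layers_to_tune(layers, 16)))  # 16 -> the default --num-layers
print(len(layers_to_tune(layers, 4)))   # 4  -> as in the --num-layers 4 example
print(len(layers_to_tune(layers, -1)))  # 32 -> all blocks
```

The same slice is what the memory-reduction tip in `LORA.md` relies on when it suggests lowering `--num-layers` to `8` or `4`.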