From 2bd64b78cf304387e8fea7cc20db684b3a5f8459 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 2 Apr 2024 13:52:53 -0700
Subject: [PATCH] Save lora config (#636)

* lora config

* comments

* version bump
---
 llms/mlx_lm/LORA.md          | 12 ++++----
 llms/mlx_lm/fuse.py          |  9 +++---
 llms/mlx_lm/generate.py      |  6 ++--
 llms/mlx_lm/lora.py          | 26 +++++++++--------
 llms/mlx_lm/merge.py         |  1 -
 llms/mlx_lm/server.py        |  6 ++--
 llms/mlx_lm/tuner/trainer.py | 54 +++++++++++++++---------------------
 llms/mlx_lm/tuner/utils.py   | 35 +++++++++--------------
 llms/mlx_lm/utils.py         | 12 ++++----
 llms/mlx_lm/version.py       |  2 +-
 10 files changed, 73 insertions(+), 90 deletions(-)

diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index d48e7937..04d00ead 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -65,11 +65,11 @@ mistralai/Mistral-7B-v0.1`.
 If `--model` points to a quantized model, then the training will use QLoRA,
 otherwise it will use regular LoRA.
 
-By default, the adapter weights are saved in `adapters.npz`. You can specify
-the output location with `--adapter-file`.
+By default, the adapter config and weights are saved in `adapters/`. You can
+specify the output location with `--adapter-path`.
 
 You can resume fine-tuning with an existing adapter with
-`--resume-adapter-file <path_to_adapters.npz>`.
+`--resume-adapter-file <path_to_adapters.safetensors>`.
 
 ### Evaluate
 
@@ -78,7 +78,7 @@ To compute test set perplexity use:
 ```shell
 python -m mlx_lm.lora \
     --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
+    --adapter-path <path_to_adapters> \
     --data <path_to_data> \
     --test
 ```
@@ -90,7 +90,7 @@ For generation use `mlx_lm.generate`:
 ```shell
 python -m mlx_lm.generate \
     --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
+    --adapter-path <path_to_adapters> \
     --prompt "<your_prompt>"
 ```
@@ -115,7 +115,7 @@ To generate the fused model run:
 python -m mlx_lm.fuse --model <path_to_model>
 ```
 
-This will by default load the adapters from `adapters.npz`, and save the fused
+This will by default load the adapters from `adapters/`, and save the fused
 model in the path `lora_fused_model/`. All of these are configurable.
 
 To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
index 44c7eaaa..b362e7b7 100644
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -1,6 +1,5 @@
 import argparse
 import glob
-import json
 import shutil
 from pathlib import Path
 
@@ -31,10 +30,10 @@ def parse_arguments() -> argparse.Namespace:
         help="The path to save the fused model.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        default="adapters.npz",
-        help="Path to the trained adapter weights (npz or safetensors).",
+        default="adapters",
+        help="Path to the trained adapter weights and config.",
     )
     parser.add_argument(
         "--hf-path",
@@ -75,7 +74,7 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
 
     model.freeze()
-    model = apply_lora_layers(model, args.adapter_file)
+    model = apply_lora_layers(model, args.adapter_path)
 
     fused_linears = [
         (n, m.to_linear())
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index dfdffa1b..6d859c3c 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -24,9 +24,9 @@ def setup_arg_parser():
         help="The path to the local model directory or Hugging Face repo.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
        type=str,
-        help="Optional path for the trained adapter weights.",
+        help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--trust-remote-code",
@@ -110,7 +110,7 @@ def main(args):
         tokenizer_config["eos_token"] = args.eos_token
 
     model, tokenizer = load(
-        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
     if args.use_default_chat_template:
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 9e94868e..36343262 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -1,7 +1,6 @@
 # Copyright © 2024 Apple Inc.
 
 import argparse
-import json
 import math
 import re
 import types
@@ -16,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import build_schedule, linear_to_lora_layers
-from .utils import load
+from .utils import load, save_config
 
 yaml_loader = yaml.SafeLoader
 yaml_loader.add_implicit_resolver(
@@ -48,7 +47,7 @@ CONFIG_DEFAULTS = {
     "steps_per_report": 10,
     "steps_per_eval": 200,
     "resume_adapter_file": None,
-    "adapter_file": "adapters.npz",
+    "adapter_path": "adapters",
     "save_every": 100,
     "test": False,
     "test_batches": 500,
@@ -102,12 +101,12 @@ def build_parser():
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        help="Load path to resume training with the given adapter weights.",
+        help="Load path to resume training with the given adapters.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Save/load path for the trained adapter weights.",
+        help="Save/load path for the adapters.",
     )
     parser.add_argument(
         "--save-every",
@@ -184,6 +183,11 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
+    adapter_path = Path(args.adapter_path)
+    adapter_path.mkdir(parents=True, exist_ok=True)
+    save_config(vars(args), adapter_path / "adapter_config.json")
+    adapter_file = adapter_path / "adapters.safetensors"
+
     if args.train:
         print("Training")
         # init training args
@@ -194,7 +198,7 @@ def run(args, training_callback: TrainingCallback = None):
             steps_per_report=args.steps_per_report,
             steps_per_eval=args.steps_per_eval,
             steps_per_save=args.save_every,
-            adapter_file=args.adapter_file,
+            adapter_file=adapter_file,
             max_seq_length=args.max_seq_length,
             grad_checkpoint=args.grad_checkpoint,
         )
@@ -219,12 +223,12 @@ def run(args, training_callback: TrainingCallback = None):
         )
 
     # Load the LoRA adapter weights which we assume should exist by this point
-    if not Path(args.adapter_file).is_file():
+    if not adapter_file.is_file():
         raise ValueError(
-            f"Adapter file {args.adapter_file} missing. "
-            "Use --train to learn and save the adapters.npz."
+            f"Adapter file {adapter_file} missing. "
" + "Use --train to learn and save the adapters" ) - model.load_weights(args.adapter_file, strict=False) + model.load_weights(str(adapter_file), strict=False) if args.test: print("Testing") diff --git a/llms/mlx_lm/merge.py b/llms/mlx_lm/merge.py index c1abdb8a..9c88970e 100644 --- a/llms/mlx_lm/merge.py +++ b/llms/mlx_lm/merge.py @@ -2,7 +2,6 @@ import argparse import glob -import json import shutil from pathlib import Path from typing import Optional diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 2540abd2..e717f324 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -418,9 +418,9 @@ if __name__ == "__main__": help="The path to the MLX model weights, tokenizer, and config", ) parser.add_argument( - "--adapter-file", + "--adapter-path", type=str, - help="Optional path for the trained adapter weights.", + help="Optional path for the trained adapter weights and config.", ) parser.add_argument( "--host", @@ -445,7 +445,7 @@ if __name__ == "__main__": tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} MODEL, TOKENIZER = load( - args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config + args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config ) run(args.host, args.port) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index f0d8e0a4..b88c2a9c 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -4,6 +4,7 @@ import time from dataclasses import dataclass, field from functools import partial from pathlib import Path +from typing import Union import mlx.core as mx import mlx.nn as nn @@ -54,7 +55,7 @@ class TrainingArgs: default=2048, metadata={"help": "Maximum sequence length."} ) adapter_file: str = field( - default="adapter.npz", + default="adapters.safetensors", metadata={"help": "Save/load path for the trained adapter weights."}, ) grad_checkpoint: bool = field( @@ -172,18 +173,6 @@ def train( ): print(f"Starting training..., iters: {args.iters}") - def checkpoints_path(adapter_file) -> str: - checkpoints_path = Path("checkpoints") - if Path(adapter_file).parent: - checkpoints_path = Path(adapter_file).parent / "checkpoints" - - checkpoints_path.mkdir(parents=True, exist_ok=True) - - return str(checkpoints_path) - - # Create checkpoints directory if it does not exist - adapter_path = checkpoints_path(args.adapter_file) - if args.grad_checkpoint: grad_checkpoint(model.layers[0]) @@ -206,7 +195,7 @@ def train( # Main training loop start = time.perf_counter() for it, batch in zip( - range(args.iters), + range(1, args.iters + 1), iterate_batches( dataset=train_dataset, tokenizer=tokenizer, @@ -223,7 +212,7 @@ def train( n_tokens += toks.item() # Report training loss if needed - if ((it + 1) % args.steps_per_report == 0) or (it + 1 == args.iters): + if it % args.steps_per_report == 0 or it == args.iters: train_loss = np.mean(losses) stop = time.perf_counter() @@ -233,7 +222,7 @@ def train( trained_tokens += n_tokens peak_mem = mx.metal.get_peak_memory() / 2**30 print( - f"Iter {it + 1}: Train loss {train_loss:.3f}, " + f"Iter {it}: Train loss {train_loss:.3f}, " f"Learning Rate {learning_rate:.3e}, " f"It/sec {it_sec:.3f}, " f"Tokens/sec {tokens_sec:.3f}, " @@ -243,7 +232,7 @@ def train( if training_callback is not None: train_info = { - "iteration": it + 1, + "iteration": it, "train_loss": train_loss, "learning_rate": learning_rate, "iterations_per_second": it_sec, @@ -258,7 +247,7 @@ def train( start = time.perf_counter() # Report validation 
-        if it == 0 or ((it + 1) % args.steps_per_eval == 0) or (it + 1 == args.iters):
+        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
             val_loss = evaluate(
                 model=model,
@@ -272,14 +261,12 @@ def train(
             )
             val_time = time.perf_counter() - stop
             print(
-                f"Iter {it + 1}: "
-                f"Val loss {val_loss:.3f}, "
-                f"Val took {val_time:.3f}s"
+                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
             )
 
             if training_callback is not None:
                 val_info = {
-                    "iteration": it + 1,
+                    "iteration": it,
                     "val_loss": val_loss,
                     "val_time": val_time,
                 }
@@ -287,23 +274,26 @@ def train(
                 training_callback.on_val_loss_report(val_info)
 
             start = time.perf_counter()
 
-        # Save adapter weights if needed
-        if (it + 1) % args.steps_per_save == 0:
-            checkpoint_adapter_file = (
-                f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}"
+        # Save adapter weights
+        if it % args.steps_per_save == 0:
+            save_adapter(model, args.adapter_file)
+            checkpoint = (
+                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
+            )
+            save_adapter(model, checkpoint)
+            print(
+                f"Iter {it}: Saved adapter weights to "
+                f"{args.adapter_file} and {checkpoint}."
             )
-            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
-            print(f"Iter {it + 1}: Saved adapter weights to {checkpoint_adapter_file}.")
 
     # save final adapter weights
-    save_adapter(model=model, adapter_file=args.adapter_file)
+    save_adapter(model, args.adapter_file)
     print(f"Saved final adapter weights to {args.adapter_file}.")
 
 
 def save_adapter(
     model: nn.Module,
-    adapter_file: str,
+    adapter_file: Union[str, Path],
 ):
     flattened_tree = tree_flatten(model.trainable_parameters())
-
-    mx.savez(adapter_file, **dict(flattened_tree))
+    mx.save_safetensors(str(adapter_file), dict(flattened_tree))
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 40d42ee4..b098e1bc 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -1,4 +1,7 @@
-import os
+# Copyright © 2024 Apple Inc.
+import json
+import types
+from pathlib import Path
 from typing import Dict
 
 import mlx.core as mx
@@ -91,40 +94,28 @@ def linear_to_lora_layers(
         raise ValueError(f"Lora does not support {model.model_type}")
 
     for l in model.layers[num_layers - num_lora_layers :]:
-        modules = l.named_modules()
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         l.update_modules(tree_unflatten(lora_layers))
 
 
-def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
+def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
     """
     Apply LoRA layers to the model.
 
     Args:
         model (nn.Module): The neural network model.
-        adapter_file (str): Path to the adapter configuration file.
+        adapter_path (str): Path to the adapter configuration file.
 
     Returns:
         nn.Module: The updated model with LoRA layers applied.
""" - if not os.path.exists(adapter_file): - raise FileNotFoundError(f"The adapter file does not exist: {adapter_file}") - - adapters = list(mx.load(adapter_file).items()) - - linear_replacements = [] - lora_layers = set( - [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters] - ) - for name, module in model.named_modules(): - if name in lora_layers: - replacement_module = LoRALinear.from_linear(module) - linear_replacements.append((name, replacement_module)) - - model.update_modules(tree_unflatten(linear_replacements)) - - model.update(tree_unflatten(adapters)) - + adapter_path = Path(adapter_path) + if not adapter_path.exists(): + raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}") + with open(adapter_path / "adapter_config.json", "r") as fid: + config = types.SimpleNamespace(**json.load(fid)) + linear_to_lora_layers(model, config.lora_layers, config.lora_parameters) + model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False) return model diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index bf42a5d1..02ad2294 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -9,7 +9,7 @@ import shutil import time from pathlib import Path from textwrap import dedent -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -354,7 +354,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module: def load( path_or_hf_repo: str, tokenizer_config={}, - adapter_file: Optional[str] = None, + adapter_path: Optional[str] = None, lazy: bool = False, ) -> Tuple[nn.Module, PreTrainedTokenizer]: """ @@ -364,8 +364,8 @@ def load( path_or_hf_repo (Path): The path or the huggingface repository to load the model from. tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. Defaults to an empty dictionary. - adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model. - Defaults to None. + adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers + to the model. Default: ``None``. lazy (bool): If False eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` @@ -379,8 +379,8 @@ def load( model_path = get_model_path(path_or_hf_repo) model = load_model(model_path, lazy) - if adapter_file is not None: - model = apply_lora_layers(model, adapter_file) + if adapter_path is not None: + model = apply_lora_layers(model, adapter_path) model.eval() tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config) diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index 45e522d1..8df927d4 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.4.0" +__version__ = "0.6.0"