Save lora config (#636)

* lora config

* comments

* version bump
Awni Hannun 2024-04-02 13:52:53 -07:00 committed by GitHub
parent d661440dbb
commit 2bd64b78cf
10 changed files with 73 additions and 90 deletions

View File

@@ -65,11 +65,11 @@ mistralai/Mistral-7B-v0.1`.
 If `--model` points to a quantized model, then the training will use QLoRA,
 otherwise it will use regular LoRA.
 
-By default, the adapter weights are saved in `adapters.npz`. You can specify
-the output location with `--adapter-file`.
+By default, the adapter config and weights are saved in `adapters/`. You can
+specify the output location with `--adapter-path`.
 
 You can resume fine-tuning with an existing adapter with
-`--resume-adapter-file <path_to_adapters.npz>`.
+`--resume-adapter-file <path_to_adapters.safetensors>`.
 
 ### Evaluate
@@ -78,7 +78,7 @@ To compute test set perplexity use:
 ```shell
 python -m mlx_lm.lora \
     --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
+    --adapter-path <path_to_adapters> \
     --data <path_to_data> \
     --test
 ```
@@ -90,7 +90,7 @@ For generation use `mlx_lm.generate`:
 ```shell
 python -m mlx_lm.generate \
     --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
+    --adapter-path <path_to_adapters> \
     --prompt "<your_model_prompt>"
 ```
@@ -115,7 +115,7 @@ To generate the fused model run:
 python -m mlx_lm.fuse --model <path_to_model>
 ```
 
-This will by default load the adapters from `adapters.npz`, and save the fused
+This will by default load the adapters from `adapters/`, and save the fused
 model in the path `lora_fused_model/`. All of these are configurable.
 
 To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
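The flag rename documented above (`--adapter-file` → `--adapter-path`) applies to the Python API as well. A minimal sketch of loading the adapters written to the new default directory, assuming a prior training run and access to the base model:

```python
from mlx_lm import load, generate

# Apply the adapters saved under the new default directory; after this
# change "adapters/" holds both adapter_config.json and adapters.safetensors.
model, tokenizer = load("mistralai/Mistral-7B-v0.1", adapter_path="adapters")
print(generate(model, tokenizer, prompt="Why is the sky blue?"))
```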

View File

@@ -1,6 +1,5 @@
 import argparse
 import glob
-import json
 import shutil
 from pathlib import Path
@@ -31,10 +30,10 @@ def parse_arguments() -> argparse.Namespace:
         help="The path to save the fused model.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        default="adapters.npz",
-        help="Path to the trained adapter weights (npz or safetensors).",
+        default="adapters",
+        help="Path to the trained adapter weights and config.",
     )
     parser.add_argument(
         "--hf-path",
@@ -75,7 +74,7 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
 
     model.freeze()
-    model = apply_lora_layers(model, args.adapter_file)
+    model = apply_lora_layers(model, args.adapter_path)
 
     fused_linears = [
         (n, m.to_linear())

View File

@@ -24,9 +24,9 @@ def setup_arg_parser():
         help="The path to the local model directory or Hugging Face repo.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Optional path for the trained adapter weights.",
+        help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--trust-remote-code",
@@ -110,7 +110,7 @@ def main(args):
         tokenizer_config["eos_token"] = args.eos_token
 
     model, tokenizer = load(
-        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
     if args.use_default_chat_template:

View File

@@ -1,7 +1,6 @@
 # Copyright © 2024 Apple Inc.
 
 import argparse
-import json
 import math
 import re
 import types
@@ -16,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import build_schedule, linear_to_lora_layers
-from .utils import load
+from .utils import load, save_config
 
 yaml_loader = yaml.SafeLoader
 yaml_loader.add_implicit_resolver(
@@ -48,7 +47,7 @@ CONFIG_DEFAULTS = {
     "steps_per_report": 10,
     "steps_per_eval": 200,
     "resume_adapter_file": None,
-    "adapter_file": "adapters.npz",
+    "adapter_path": "adapters",
     "save_every": 100,
     "test": False,
     "test_batches": 500,
@@ -102,12 +101,12 @@ def build_parser():
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        help="Load path to resume training with the given adapter weights.",
+        help="Load path to resume training with the given adapters.",
    )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Save/load path for the trained adapter weights.",
+        help="Save/load path for the adapters.",
     )
     parser.add_argument(
         "--save-every",
@@ -184,6 +183,11 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
+    adapter_path = Path(args.adapter_path)
+    adapter_path.mkdir(parents=True, exist_ok=True)
+    save_config(vars(args), adapter_path / "adapter_config.json")
+    adapter_file = adapter_path / "adapters.safetensors"
+
     if args.train:
         print("Training")
         # init training args
@@ -194,7 +198,7 @@ def run(args, training_callback: TrainingCallback = None):
             steps_per_report=args.steps_per_report,
             steps_per_eval=args.steps_per_eval,
             steps_per_save=args.save_every,
-            adapter_file=args.adapter_file,
+            adapter_file=adapter_file,
             max_seq_length=args.max_seq_length,
             grad_checkpoint=args.grad_checkpoint,
         )
@@ -219,12 +223,12 @@ def run(args, training_callback: TrainingCallback = None):
         )
 
     # Load the LoRA adapter weights which we assume should exist by this point
-    if not Path(args.adapter_file).is_file():
+    if not adapter_file.is_file():
         raise ValueError(
-            f"Adapter file {args.adapter_file} missing. "
-            "Use --train to learn and save the adapters.npz."
+            f"Adapter file {adapter_file} missing. "
+            "Use --train to learn and save the adapters"
         )
-    model.load_weights(args.adapter_file, strict=False)
+    model.load_weights(str(adapter_file), strict=False)
 
     if args.test:
         print("Testing")

View File

@@ -2,7 +2,6 @@
 import argparse
 import glob
-import json
 import shutil
 from pathlib import Path
 from typing import Optional

View File

@@ -418,9 +418,9 @@ if __name__ == "__main__":
         help="The path to the MLX model weights, tokenizer, and config",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Optional path for the trained adapter weights.",
+        help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--host",
@@ -445,7 +445,7 @@
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
     MODEL, TOKENIZER = load(
-        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
     run(args.host, args.port)

View File

@@ -4,6 +4,7 @@ import time
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
+from typing import Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -54,7 +55,7 @@ class TrainingArgs:
         default=2048, metadata={"help": "Maximum sequence length."}
     )
     adapter_file: str = field(
-        default="adapter.npz",
+        default="adapters.safetensors",
         metadata={"help": "Save/load path for the trained adapter weights."},
     )
     grad_checkpoint: bool = field(
@@ -172,18 +173,6 @@ def train(
 ):
     print(f"Starting training..., iters: {args.iters}")
 
-    def checkpoints_path(adapter_file) -> str:
-        checkpoints_path = Path("checkpoints")
-        if Path(adapter_file).parent:
-            checkpoints_path = Path(adapter_file).parent / "checkpoints"
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-        return str(checkpoints_path)
-
-    # Create checkpoints directory if it does not exist
-    adapter_path = checkpoints_path(args.adapter_file)
-
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])
@@ -206,7 +195,7 @@ def train(
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
-        range(args.iters),
+        range(1, args.iters + 1),
         iterate_batches(
             dataset=train_dataset,
             tokenizer=tokenizer,
@@ -223,7 +212,7 @@ def train(
         n_tokens += toks.item()
 
         # Report training loss if needed
-        if ((it + 1) % args.steps_per_report == 0) or (it + 1 == args.iters):
+        if it % args.steps_per_report == 0 or it == args.iters:
             train_loss = np.mean(losses)
             stop = time.perf_counter()
@@ -233,7 +222,7 @@ def train(
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 2**30
             print(
-                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
+                f"Iter {it}: Train loss {train_loss:.3f}, "
                 f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
@@ -243,7 +232,7 @@ def train(
             if training_callback is not None:
                 train_info = {
-                    "iteration": it + 1,
+                    "iteration": it,
                     "train_loss": train_loss,
                     "learning_rate": learning_rate,
                     "iterations_per_second": it_sec,
@@ -258,7 +247,7 @@ def train(
             start = time.perf_counter()
 
         # Report validation loss if needed
-        if it == 0 or ((it + 1) % args.steps_per_eval == 0) or (it + 1 == args.iters):
+        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
             val_loss = evaluate(
                 model=model,
@@ -272,14 +261,12 @@ def train(
             )
             val_time = time.perf_counter() - stop
             print(
-                f"Iter {it + 1}: "
-                f"Val loss {val_loss:.3f}, "
-                f"Val took {val_time:.3f}s"
+                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
             )
 
             if training_callback is not None:
                 val_info = {
-                    "iteration": it + 1,
+                    "iteration": it,
                     "val_loss": val_loss,
                     "val_time": val_time,
                 }
@@ -287,23 +274,26 @@ def train(
             start = time.perf_counter()
 
-        # Save adapter weights if needed
-        if (it + 1) % args.steps_per_save == 0:
-            checkpoint_adapter_file = (
-                f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}"
-            )
-            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
-            print(f"Iter {it + 1}: Saved adapter weights to {checkpoint_adapter_file}.")
+        # Save adapter weights
+        if it % args.steps_per_save == 0:
+            save_adapter(model, args.adapter_file)
+            checkpoint = (
+                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
+            )
+            save_adapter(model, checkpoint)
+            print(
+                f"Iter {it}: Saved adapter weights to "
+                f"{args.adapter_file} and {checkpoint}."
+            )
 
     # save final adapter weights
-    save_adapter(model=model, adapter_file=args.adapter_file)
+    save_adapter(model, args.adapter_file)
     print(f"Saved final adapter weights to {args.adapter_file}.")
 
 
 def save_adapter(
     model: nn.Module,
-    adapter_file: str,
+    adapter_file: Union[str, Path],
 ):
     flattened_tree = tree_flatten(model.trainable_parameters())
-    mx.savez(adapter_file, **dict(flattened_tree))
+    mx.save_safetensors(str(adapter_file), dict(flattened_tree))
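Two behavioral changes are bundled into this trainer diff: the loop counter is now 1-based (so the report, eval, and save conditions drop their `it + 1` arithmetic), and periodic checkpoints are written as zero-padded safetensors files beside the main adapter file instead of into a separate `checkpoints/` directory. A small sketch of the new checkpoint naming, with illustrative values:

```python
from pathlib import Path

adapter_file = Path("adapters/adapters.safetensors")
steps_per_save = 100  # illustrative; matches the CLI's --save-every default

# With `it` running from 1 to iters, checkpoints land beside the main file:
for it in range(1, 301):
    if it % steps_per_save == 0:
        checkpoint = adapter_file.parent / f"{it:07d}_adapters.safetensors"
        print(checkpoint)  # adapters/0000100_adapters.safetensors, ...
```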

View File

@@ -1,4 +1,7 @@
-import os
+# Copyright © 2024 Apple Inc.
+
+import json
+import types
+from pathlib import Path
 from typing import Dict
 
 import mlx.core as mx
@@ -91,40 +94,28 @@ def linear_to_lora_layers(
         raise ValueError(f"Lora does not support {model.model_type}")
 
     for l in model.layers[num_layers - num_lora_layers :]:
-        modules = l.named_modules()
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         l.update_modules(tree_unflatten(lora_layers))
 
 
-def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
+def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
     """
     Apply LoRA layers to the model.
 
     Args:
         model (nn.Module): The neural network model.
-        adapter_file (str): Path to the adapter configuration file.
+        adapter_path (str): Path to the adapter configuration file.
 
     Returns:
         nn.Module: The updated model with LoRA layers applied.
     """
-    if not os.path.exists(adapter_file):
-        raise FileNotFoundError(f"The adapter file does not exist: {adapter_file}")
-
-    adapters = list(mx.load(adapter_file).items())
-
-    linear_replacements = []
-    lora_layers = set(
-        [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
-    )
-    for name, module in model.named_modules():
-        if name in lora_layers:
-            replacement_module = LoRALinear.from_linear(module)
-            linear_replacements.append((name, replacement_module))
-
-    model.update_modules(tree_unflatten(linear_replacements))
-    model.update(tree_unflatten(adapters))
-
+    adapter_path = Path(adapter_path)
+    if not adapter_path.exists():
+        raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
+    with open(adapter_path / "adapter_config.json", "r") as fid:
+        config = types.SimpleNamespace(**json.load(fid))
+    linear_to_lora_layers(model, config.lora_layers, config.lora_parameters)
+    model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
     return model
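`apply_lora_layers` now takes a directory rather than a weights file: it rebuilds the LoRA layers from `adapter_config.json` before loading `adapters.safetensors`, so the config must carry at least the two fields read back here. An illustrative hand-written config; the field values are assumptions, not taken from this commit (the real file contains all of the saved CLI arguments):

```python
import json
from pathlib import Path

adapter_path = Path("adapters")
adapter_path.mkdir(parents=True, exist_ok=True)

# Only lora_layers and lora_parameters are consumed by apply_lora_layers;
# the values below are illustrative defaults, not confirmed by the diff.
config = {
    "lora_layers": 16,
    "lora_parameters": {"rank": 8, "alpha": 16.0, "dropout": 0.0, "scale": 10.0},
}
with open(adapter_path / "adapter_config.json", "w") as fid:
    json.dump(config, fid, indent=4)
```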

View File

@@ -9,7 +9,7 @@ import shutil
 import time
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -354,7 +354,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
-    adapter_file: Optional[str] = None,
+    adapter_path: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
@@ -364,8 +364,8 @@ def load(
         path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
-        adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
-            Defaults to None.
+        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+            to the model. Default: ``None``.
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
@@ -379,8 +379,8 @@ def load(
     model_path = get_model_path(path_or_hf_repo)
 
     model = load_model(model_path, lazy)
-    if adapter_file is not None:
-        model = apply_lora_layers(model, adapter_file)
+    if adapter_path is not None:
+        model = apply_lora_layers(model, adapter_path)
         model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
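The `load` changes thread the rename through the high-level API. Equivalently, the lower-level helpers in this file can be composed by hand (a sketch assuming the module paths shown above):

```python
from mlx_lm.tuner.utils import apply_lora_layers
from mlx_lm.utils import get_model_path, load_model

# Equivalent to load(..., adapter_path=...): fetch the base model, then
# graft the LoRA layers from the adapter directory on top of it.
model = load_model(get_model_path("mistralai/Mistral-7B-v0.1"))
model = apply_lora_layers(model, "adapters")
model.eval()
```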

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.4.0"
+__version__ = "0.6.0"