Save lora config (#636)

* lora config

* comments

* version bump
Awni Hannun 2024-04-02 13:52:53 -07:00 committed by GitHub
parent d661440dbb
commit 2bd64b78cf
10 changed files with 73 additions and 90 deletions

llms/mlx_lm/LORA.md

@@ -65,11 +65,11 @@ mistralai/Mistral-7B-v0.1`.
 If `--model` points to a quantized model, then the training will use QLoRA,
 otherwise it will use regular LoRA.
 
-By default, the adapter weights are saved in `adapters.npz`. You can specify
-the output location with `--adapter-file`.
+By default, the adapter config and weights are saved in `adapters/`. You can
+specify the output location with `--adapter-path`.
 
 You can resume fine-tuning with an existing adapter with
-`--resume-adapter-file <path_to_adapters.npz>`.
+`--resume-adapter-file <path_to_adapters.safetensors>`.
 
 ### Evaluate
@@ -78,7 +78,7 @@ To compute test set perplexity use:
 ```shell
 python -m mlx_lm.lora \
     --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
+    --adapter-path <path_to_adapters> \
     --data <path_to_data> \
     --test
 ```
@@ -90,7 +90,7 @@ For generation use `mlx_lm.generate`:
 ```shell
 python -m mlx_lm.generate \
     --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
+    --adapter-path <path_to_adapters> \
     --prompt "<your_model_prompt>"
 ```
@@ -115,7 +115,7 @@ To generate the fused model run:
 python -m mlx_lm.fuse --model <path_to_model>
 ```
 
-This will by default load the adapters from `adapters.npz`, and save the fused
+This will by default load the adapters from `adapters/`, and save the fused
 model in the path `lora_fused_model/`. All of these are configurable.
 
 To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
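After this change a training run produces a directory of adapter artifacts rather than a single `adapters.npz` file. A minimal sketch of the resulting layout and of loading it from the Python API (the model name is an example; the paths are the new defaults):

```python
from mlx_lm import load, generate

# Layout written by `python -m mlx_lm.lora --train` after this commit:
#   adapters/adapter_config.json               <- training arguments (save_config)
#   adapters/adapters.safetensors              <- final adapter weights
#   adapters/0000100_adapters.safetensors, ... <- periodic checkpoints
#
# `adapter_path` names the directory, not a weights file.
model, tokenizer = load("mistralai/Mistral-7B-v0.1", adapter_path="adapters")
print(generate(model, tokenizer, prompt="hello"))
```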

llms/mlx_lm/fuse.py

@@ -1,6 +1,5 @@
 import argparse
-import glob
 import json
 import shutil
 from pathlib import Path
@@ -31,10 +30,10 @@ def parse_arguments() -> argparse.Namespace:
         help="The path to save the fused model.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        default="adapters.npz",
-        help="Path to the trained adapter weights (npz or safetensors).",
+        default="adapters",
+        help="Path to the trained adapter weights and config.",
     )
     parser.add_argument(
         "--hf-path",
@@ -75,7 +74,7 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
 
     model.freeze()
-    model = apply_lora_layers(model, args.adapter_file)
+    model = apply_lora_layers(model, args.adapter_path)
 
     fused_linears = [
         (n, m.to_linear())

llms/mlx_lm/generate.py

@@ -24,9 +24,9 @@ def setup_arg_parser():
         help="The path to the local model directory or Hugging Face repo.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Optional path for the trained adapter weights.",
+        help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--trust-remote-code",
@@ -110,7 +110,7 @@ def main(args):
         tokenizer_config["eos_token"] = args.eos_token
 
     model, tokenizer = load(
-        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
     if args.use_default_chat_template:

llms/mlx_lm/lora.py

@@ -1,7 +1,6 @@
 # Copyright © 2024 Apple Inc.
 
 import argparse
-import json
 import math
 import re
 import types
@@ -16,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import build_schedule, linear_to_lora_layers
-from .utils import load
+from .utils import load, save_config
 
 yaml_loader = yaml.SafeLoader
 yaml_loader.add_implicit_resolver(
@@ -48,7 +47,7 @@ CONFIG_DEFAULTS = {
     "steps_per_report": 10,
     "steps_per_eval": 200,
     "resume_adapter_file": None,
-    "adapter_file": "adapters.npz",
+    "adapter_path": "adapters",
     "save_every": 100,
     "test": False,
     "test_batches": 500,
@@ -102,12 +101,12 @@ def build_parser():
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        help="Load path to resume training with the given adapter weights.",
+        help="Load path to resume training with the given adapters.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Save/load path for the trained adapter weights.",
+        help="Save/load path for the adapters.",
     )
     parser.add_argument(
         "--save-every",
@@ -184,6 +183,11 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
+    adapter_path = Path(args.adapter_path)
+    adapter_path.mkdir(parents=True, exist_ok=True)
+    save_config(vars(args), adapter_path / "adapter_config.json")
+    adapter_file = adapter_path / "adapters.safetensors"
+
     if args.train:
         print("Training")
         # init training args
@@ -194,7 +198,7 @@ def run(args, training_callback: TrainingCallback = None):
             steps_per_report=args.steps_per_report,
             steps_per_eval=args.steps_per_eval,
             steps_per_save=args.save_every,
-            adapter_file=args.adapter_file,
+            adapter_file=adapter_file,
             max_seq_length=args.max_seq_length,
             grad_checkpoint=args.grad_checkpoint,
         )
@@ -219,12 +223,12 @@ def run(args, training_callback: TrainingCallback = None):
         )
 
     # Load the LoRA adapter weights which we assume should exist by this point
-    if not Path(args.adapter_file).is_file():
+    if not adapter_file.is_file():
         raise ValueError(
-            f"Adapter file {args.adapter_file} missing. "
-            "Use --train to learn and save the adapters.npz."
+            f"Adapter file {adapter_file} missing. "
+            "Use --train to learn and save the adapters."
         )
-    model.load_weights(args.adapter_file, strict=False)
+    model.load_weights(str(adapter_file), strict=False)
 
     if args.test:
         print("Testing")

llms/mlx_lm/merge.py

@@ -2,7 +2,6 @@
 import argparse
-import glob
 import json
 import shutil
 from pathlib import Path
 from typing import Optional

llms/mlx_lm/server.py

@@ -418,9 +418,9 @@ if __name__ == "__main__":
         help="The path to the MLX model weights, tokenizer, and config",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Optional path for the trained adapter weights.",
+        help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--host",
@@ -445,7 +445,7 @@ if __name__ == "__main__":
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
 
     MODEL, TOKENIZER = load(
-        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
     run(args.host, args.port)

llms/mlx_lm/tuner/trainer.py

@@ -4,6 +4,7 @@ import time
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
+from typing import Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -54,7 +55,7 @@ class TrainingArgs:
         default=2048, metadata={"help": "Maximum sequence length."}
     )
     adapter_file: str = field(
-        default="adapter.npz",
+        default="adapters.safetensors",
         metadata={"help": "Save/load path for the trained adapter weights."},
     )
     grad_checkpoint: bool = field(
@@ -172,18 +173,6 @@ def train(
 ):
     print(f"Starting training..., iters: {args.iters}")
 
-    def checkpoints_path(adapter_file) -> str:
-        checkpoints_path = Path("checkpoints")
-        if Path(adapter_file).parent:
-            checkpoints_path = Path(adapter_file).parent / "checkpoints"
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-        return str(checkpoints_path)
-
-    # Create checkpoints directory if it does not exist
-    adapter_path = checkpoints_path(args.adapter_file)
-
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])
@@ -206,7 +195,7 @@ def train(
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
-        range(args.iters),
+        range(1, args.iters + 1),
         iterate_batches(
             dataset=train_dataset,
             tokenizer=tokenizer,
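The loop is now 1-indexed, which lets every `it + 1` in the hunks below become a plain `it`. A quick standalone check (plain Python, not part of the diff) that the reporting schedule is unchanged:

```python
iters, steps_per_report = 10, 4

# Old: 0-indexed loop, shifted by one inside each condition.
old = [it + 1 for it in range(iters)
       if (it + 1) % steps_per_report == 0 or it + 1 == iters]

# New: 1-indexed loop, conditions read `it` directly.
new = [it for it in range(1, iters + 1)
       if it % steps_per_report == 0 or it == iters]

assert old == new == [4, 8, 10]
```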
@@ -223,7 +212,7 @@ def train(
         n_tokens += toks.item()
 
         # Report training loss if needed
-        if ((it + 1) % args.steps_per_report == 0) or (it + 1 == args.iters):
+        if it % args.steps_per_report == 0 or it == args.iters:
             train_loss = np.mean(losses)
             stop = time.perf_counter()
@@ -233,7 +222,7 @@ def train(
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 2**30
             print(
-                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
+                f"Iter {it}: Train loss {train_loss:.3f}, "
                 f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
@@ -243,7 +232,7 @@ def train(
 
             if training_callback is not None:
                 train_info = {
-                    "iteration": it + 1,
+                    "iteration": it,
                     "train_loss": train_loss,
                     "learning_rate": learning_rate,
                     "iterations_per_second": it_sec,
@@ -258,7 +247,7 @@ def train(
             start = time.perf_counter()
 
         # Report validation loss if needed
-        if it == 0 or ((it + 1) % args.steps_per_eval == 0) or (it + 1 == args.iters):
+        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
             val_loss = evaluate(
                 model=model,
@@ -272,14 +261,12 @@ def train(
             )
             val_time = time.perf_counter() - stop
             print(
-                f"Iter {it + 1}: "
-                f"Val loss {val_loss:.3f}, "
-                f"Val took {val_time:.3f}s"
+                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
             )
 
             if training_callback is not None:
                 val_info = {
-                    "iteration": it + 1,
+                    "iteration": it,
                     "val_loss": val_loss,
                     "val_time": val_time,
                 }
@@ -287,23 +274,26 @@ def train(
             start = time.perf_counter()
 
-        # Save adapter weights if needed
-        if (it + 1) % args.steps_per_save == 0:
-            checkpoint_adapter_file = (
-                f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}"
-            )
-            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
-            print(f"Iter {it + 1}: Saved adapter weights to {checkpoint_adapter_file}.")
+        # Save adapter weights
+        if it % args.steps_per_save == 0:
+            save_adapter(model, args.adapter_file)
+            checkpoint = (
+                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
+            )
+            save_adapter(model, checkpoint)
+            print(
+                f"Iter {it}: Saved adapter weights to "
+                f"{args.adapter_file} and {checkpoint}."
+            )
 
     # save final adapter weights
-    save_adapter(model=model, adapter_file=args.adapter_file)
+    save_adapter(model, args.adapter_file)
     print(f"Saved final adapter weights to {args.adapter_file}.")
 
 
 def save_adapter(
     model: nn.Module,
-    adapter_file: str,
+    adapter_file: Union[str, Path],
 ):
     flattened_tree = tree_flatten(model.trainable_parameters())
-    mx.savez(adapter_file, **dict(flattened_tree))
+    mx.save_safetensors(str(adapter_file), dict(flattened_tree))
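With `mx.savez` replaced by `mx.save_safetensors`, adapter weights now round-trip through the safetensors format. A small sketch of the save/load symmetry (the tensor name and shape are illustrative, not from this diff):

```python
import mlx.core as mx

# save_adapter flattens the trainable parameters into a dict like this.
weights = {"model.layers.0.self_attn.q_proj.lora_a": mx.zeros((4096, 8))}
mx.save_safetensors("adapters.safetensors", weights)

# mx.load infers the format from the extension and returns the same dict,
# which model.load_weights(..., strict=False) consumes on the other side.
restored = mx.load("adapters.safetensors")
assert restored["model.layers.0.self_attn.q_proj.lora_a"].shape == (4096, 8)
```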

llms/mlx_lm/tuner/utils.py

@@ -1,4 +1,7 @@
-import os
+# Copyright © 2024 Apple Inc.
+import json
+import types
+from pathlib import Path
 from typing import Dict
 
 import mlx.core as mx
@@ -91,40 +94,28 @@ def linear_to_lora_layers(
             raise ValueError(f"Lora does not support {model.model_type}")
 
     for l in model.layers[num_layers - num_lora_layers :]:
-        modules = l.named_modules()
-        lora_layers = [(k, to_lora(m)) for k, m in modules if k in keys]
+        lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         l.update_modules(tree_unflatten(lora_layers))
 
 
-def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
+def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
     """
     Apply LoRA layers to the model.
 
     Args:
         model (nn.Module): The neural network model.
-        adapter_file (str): Path to the adapter configuration file.
+        adapter_path (str): Path to the adapter configuration file.
 
     Returns:
         nn.Module: The updated model with LoRA layers applied.
     """
-    if not os.path.exists(adapter_file):
-        raise FileNotFoundError(f"The adapter file does not exist: {adapter_file}")
-
-    adapters = list(mx.load(adapter_file).items())
-    linear_replacements = []
-    lora_layers = set(
-        [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
-    )
-    for name, module in model.named_modules():
-        if name in lora_layers:
-            replacement_module = LoRALinear.from_linear(module)
-            linear_replacements.append((name, replacement_module))
-
-    model.update_modules(tree_unflatten(linear_replacements))
-    model.update(tree_unflatten(adapters))
-
+    adapter_path = Path(adapter_path)
+    if not adapter_path.exists():
+        raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
+    with open(adapter_path / "adapter_config.json", "r") as fid:
+        config = types.SimpleNamespace(**json.load(fid))
+    linear_to_lora_layers(model, config.lora_layers, config.lora_parameters)
+    model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
     return model
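The rewritten `apply_lora_layers` rebuilds the LoRA structure from the saved config instead of inferring it from weight names. Only two fields are consumed at load time; a sketch of reading them back (the example values are assumptions, not from this diff):

```python
import json
import types

with open("adapters/adapter_config.json") as fid:
    config = types.SimpleNamespace(**json.load(fid))

# save_config(vars(args), ...) stores every CLI/YAML argument, but the
# loader only needs these two to reconstruct the adapted layers.
print(config.lora_layers)      # e.g. 16
print(config.lora_parameters)  # e.g. {"rank": 8, "alpha": 16, ...}
```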

llms/mlx_lm/utils.py

@@ -9,7 +9,7 @@ import shutil
 import time
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -354,7 +354,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
-    adapter_file: Optional[str] = None,
+    adapter_path: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
@@ -364,8 +364,8 @@ def load(
         path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
-        adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
-            Defaults to None.
+        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+            to the model. Default: ``None``.
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
@@ -379,8 +379,8 @@ def load(
     model_path = get_model_path(path_or_hf_repo)
 
     model = load_model(model_path, lazy)
-    if adapter_file is not None:
-        model = apply_lora_layers(model, adapter_file)
+    if adapter_path is not None:
+        model = apply_lora_layers(model, adapter_path)
         model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
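For callers of the Python API the change is only the keyword name. A usage sketch with the renamed parameter (the repo name is an example):

```python
from mlx_lm.utils import load

# Without adapters: behaves exactly as before.
model, tokenizer = load("mistralai/Mistral-7B-v0.1")

# With adapters: pass the adapter directory; load() applies the LoRA
# layers via apply_lora_layers and only then calls model.eval().
model, tokenizer = load("mistralai/Mistral-7B-v0.1", adapter_path="adapters")
```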

llms/mlx_lm/version.py

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.4.0"
+__version__ = "0.6.0"