Mirror of https://github.com/ml-explore/mlx-examples.git
Commit 2bd64b78cf (parent d661440dbb)
@@ -65,11 +65,11 @@ mistralai/Mistral-7B-v0.1`.
 If `--model` points to a quantized model, then the training will use QLoRA,
 otherwise it will use regular LoRA.
 
-By default, the adapter weights are saved in `adapters.npz`. You can specify
-the output location with `--adapter-file`.
+By default, the adapter config and weights are saved in `adapters/`. You can
+specify the output location with `--adapter-path`.
 
 You can resume fine-tuning with an existing adapter with
-`--resume-adapter-file <path_to_adapters.npz>`.
+`--resume-adapter-file <path_to_adapters.safetensors>`.
 
 ### Evaluate
 
@@ -78,7 +78,7 @@ To compute test set perplexity use:
 ```shell
 python -m mlx_lm.lora \
     --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
+    --adapter-path <path_to_adapters> \
     --data <path_to_data> \
     --test
 ```
@@ -90,7 +90,7 @@ For generation use `mlx_lm.generate`:
 
 ```shell
 python -m mlx_lm.generate \
     --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
+    --adapter-path <path_to_adapters> \
     --prompt "<your_model_prompt>"
 ```
@@ -115,7 +115,7 @@ To generate the fused model run:
 python -m mlx_lm.fuse --model <path_to_model>
 ```
 
-This will by default load the adapters from `adapters.npz`, and save the fused
+This will by default load the adapters from `adapters/`, and save the fused
 model in the path `lora_fused_model/`. All of these are configurable.
 
 To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
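For reference, the renamed `--adapter-path` flag mirrors the Python API change in this commit; a minimal sketch, assuming the top-level `load` and `generate` exports of `mlx_lm` and placeholder model/prompt values:

```python
# Minimal sketch: load a base model together with trained adapters and
# generate text. `load` takes `adapter_path`, a directory holding
# adapter_config.json and adapters.safetensors (written by mlx_lm.lora).
from mlx_lm import load, generate

model, tokenizer = load(
    "mistralai/Mistral-7B-v0.1",  # placeholder; any supported repo or local path
    adapter_path="adapters",      # default output directory of mlx_lm.lora
)
print(generate(model, tokenizer, prompt="Why is the sky blue?"))
```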
@@ -1,6 +1,5 @@
 import argparse
 import glob
-import json
 import shutil
 from pathlib import Path
 
@@ -31,10 +30,10 @@ def parse_arguments() -> argparse.Namespace:
         help="The path to save the fused model.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        default="adapters.npz",
-        help="Path to the trained adapter weights (npz or safetensors).",
+        default="adapters",
+        help="Path to the trained adapter weights and config.",
     )
     parser.add_argument(
         "--hf-path",
@@ -75,7 +74,7 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
 
     model.freeze()
-    model = apply_lora_layers(model, args.adapter_file)
+    model = apply_lora_layers(model, args.adapter_path)
 
     fused_linears = [
         (n, m.to_linear())
@@ -24,9 +24,9 @@ def setup_arg_parser():
         help="The path to the local model directory or Hugging Face repo.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Optional path for the trained adapter weights.",
+        help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--trust-remote-code",
@@ -110,7 +110,7 @@ def main(args):
         tokenizer_config["eos_token"] = args.eos_token
 
     model, tokenizer = load(
-        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
     if args.use_default_chat_template:
@@ -1,7 +1,6 @@
 # Copyright © 2024 Apple Inc.
 
 import argparse
-import json
 import math
 import re
 import types
@@ -16,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import build_schedule, linear_to_lora_layers
-from .utils import load
+from .utils import load, save_config
 
 yaml_loader = yaml.SafeLoader
 yaml_loader.add_implicit_resolver(
@@ -48,7 +47,7 @@ CONFIG_DEFAULTS = {
     "steps_per_report": 10,
     "steps_per_eval": 200,
     "resume_adapter_file": None,
-    "adapter_file": "adapters.npz",
+    "adapter_path": "adapters",
     "save_every": 100,
     "test": False,
     "test_batches": 500,
@@ -102,12 +101,12 @@ def build_parser():
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        help="Load path to resume training with the given adapter weights.",
+        help="Load path to resume training with the given adapters.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Save/load path for the trained adapter weights.",
+        help="Save/load path for the adapters.",
     )
     parser.add_argument(
         "--save-every",
@@ -184,6 +183,11 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
+    adapter_path = Path(args.adapter_path)
+    adapter_path.mkdir(parents=True, exist_ok=True)
+    save_config(vars(args), adapter_path / "adapter_config.json")
+    adapter_file = adapter_path / "adapters.safetensors"
+
     if args.train:
         print("Training")
         # init training args
@@ -194,7 +198,7 @@ def run(args, training_callback: TrainingCallback = None):
             steps_per_report=args.steps_per_report,
             steps_per_eval=args.steps_per_eval,
             steps_per_save=args.save_every,
-            adapter_file=args.adapter_file,
+            adapter_file=adapter_file,
             max_seq_length=args.max_seq_length,
             grad_checkpoint=args.grad_checkpoint,
         )
@@ -219,12 +223,12 @@ def run(args, training_callback: TrainingCallback = None):
         )
 
     # Load the LoRA adapter weights which we assume should exist by this point
-    if not Path(args.adapter_file).is_file():
+    if not adapter_file.is_file():
         raise ValueError(
-            f"Adapter file {args.adapter_file} missing. "
-            "Use --train to learn and save the adapters.npz."
+            f"Adapter file {adapter_file} missing. "
+            "Use --train to learn and save the adapters"
         )
-    model.load_weights(args.adapter_file, strict=False)
+    model.load_weights(str(adapter_file), strict=False)
 
     if args.test:
         print("Testing")
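With these changes, a training run leaves behind a directory rather than a single `.npz` file; a quick inspection sketch, assuming the default `adapters/` output and the checkpoint naming introduced below in `trainer.py`:

```python
# Sketch: inspect the adapter directory written by `python -m mlx_lm.lora --train`.
from pathlib import Path

import mlx.core as mx

adapter_path = Path("adapters")
print(sorted(p.name for p in adapter_path.iterdir()))
# e.g. ['0000100_adapters.safetensors', 'adapter_config.json', 'adapters.safetensors']

# The safetensors file maps dotted parameter names to LoRA weights.
weights = mx.load(str(adapter_path / "adapters.safetensors"))
print(list(weights)[:4])  # parameter names ending in .lora_a / .lora_b
```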
@@ -2,7 +2,6 @@
 
 import argparse
 import glob
-import json
 import shutil
 from pathlib import Path
 from typing import Optional
@@ -418,9 +418,9 @@ if __name__ == "__main__":
         help="The path to the MLX model weights, tokenizer, and config",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Optional path for the trained adapter weights.",
+        help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--host",
@@ -445,7 +445,7 @@ if __name__ == "__main__":
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
 
     MODEL, TOKENIZER = load(
-        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
     run(args.host, args.port)
@@ -4,6 +4,7 @@ import time
 from dataclasses import dataclass, field
 from functools import partial
+from pathlib import Path
 from typing import Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -54,7 +55,7 @@ class TrainingArgs:
         default=2048, metadata={"help": "Maximum sequence length."}
     )
     adapter_file: str = field(
-        default="adapter.npz",
+        default="adapters.safetensors",
         metadata={"help": "Save/load path for the trained adapter weights."},
     )
     grad_checkpoint: bool = field(
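The changed default can be sanity-checked directly; a tiny sketch, assuming the remaining `TrainingArgs` fields also carry dataclass defaults:

```python
# Sketch: the trainer now defaults to a safetensors adapter file.
from mlx_lm.tuner.trainer import TrainingArgs

args = TrainingArgs()
print(args.adapter_file)  # adapters.safetensors
```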
@@ -172,18 +173,6 @@ def train(
 ):
     print(f"Starting training..., iters: {args.iters}")
 
-    def checkpoints_path(adapter_file) -> str:
-        checkpoints_path = Path("checkpoints")
-        if Path(adapter_file).parent:
-            checkpoints_path = Path(adapter_file).parent / "checkpoints"
-
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        return str(checkpoints_path)
-
-    # Create checkpoints directory if it does not exist
-    adapter_path = checkpoints_path(args.adapter_file)
-
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])
 
@@ -206,7 +195,7 @@ def train(
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
-        range(args.iters),
+        range(1, args.iters + 1),
         iterate_batches(
             dataset=train_dataset,
             tokenizer=tokenizer,
@@ -223,7 +212,7 @@ def train(
         n_tokens += toks.item()
 
         # Report training loss if needed
-        if ((it + 1) % args.steps_per_report == 0) or (it + 1 == args.iters):
+        if it % args.steps_per_report == 0 or it == args.iters:
             train_loss = np.mean(losses)
 
             stop = time.perf_counter()
@@ -233,7 +222,7 @@ def train(
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 2**30
             print(
-                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
+                f"Iter {it}: Train loss {train_loss:.3f}, "
                 f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
@@ -243,7 +232,7 @@ def train(
 
             if training_callback is not None:
                 train_info = {
-                    "iteration": it + 1,
+                    "iteration": it,
                     "train_loss": train_loss,
                     "learning_rate": learning_rate,
                     "iterations_per_second": it_sec,
@@ -258,7 +247,7 @@ def train(
             start = time.perf_counter()
 
         # Report validation loss if needed
-        if it == 0 or ((it + 1) % args.steps_per_eval == 0) or (it + 1 == args.iters):
+        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
             val_loss = evaluate(
                 model=model,
@@ -272,14 +261,12 @@ def train(
             )
             val_time = time.perf_counter() - stop
             print(
-                f"Iter {it + 1}: "
-                f"Val loss {val_loss:.3f}, "
-                f"Val took {val_time:.3f}s"
+                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
             )
 
             if training_callback is not None:
                 val_info = {
-                    "iteration": it + 1,
+                    "iteration": it,
                     "val_loss": val_loss,
                     "val_time": val_time,
                 }
@@ -287,23 +274,26 @@ def train(
 
             start = time.perf_counter()
 
-        # Save adapter weights if needed
-        if (it + 1) % args.steps_per_save == 0:
-            checkpoint_adapter_file = (
-                f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}"
+        # Save adapter weights
+        if it % args.steps_per_save == 0:
+            save_adapter(model, args.adapter_file)
+            checkpoint = (
+                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
             )
-            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
-            print(f"Iter {it + 1}: Saved adapter weights to {checkpoint_adapter_file}.")
+            save_adapter(model, checkpoint)
+            print(
+                f"Iter {it}: Saved adapter weights to "
+                f"{args.adapter_file} and {checkpoint}."
+            )
 
     # save final adapter weights
-    save_adapter(model=model, adapter_file=args.adapter_file)
+    save_adapter(model, args.adapter_file)
     print(f"Saved final adapter weights to {args.adapter_file}.")
 
 
 def save_adapter(
     model: nn.Module,
-    adapter_file: str,
+    adapter_file: Union[str, Path],
 ):
     flattened_tree = tree_flatten(model.trainable_parameters())
-    mx.savez(adapter_file, **dict(flattened_tree))
+    mx.save_safetensors(str(adapter_file), dict(flattened_tree))
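Because the trainer now writes safetensors instead of `.npz`, adapters trained before this change need a one-time conversion; a hedged migration sketch (the config values are illustrative, not taken from this diff, and must match how the old adapters were trained):

```python
# Sketch: convert a pre-0.6.0 adapters.npz into the new adapters/ layout.
import json
from pathlib import Path

import mlx.core as mx

old_weights = mx.load("adapters.npz")  # dict of parameter name -> array
out = Path("adapters")
out.mkdir(parents=True, exist_ok=True)
mx.save_safetensors(str(out / "adapters.safetensors"), dict(old_weights))

# apply_lora_layers (below) reads lora_layers and lora_parameters from
# adapter_config.json; the values here are placeholders, not from this diff.
config = {"lora_layers": 16, "lora_parameters": {"rank": 8, "scale": 20.0, "dropout": 0.0}}
with open(out / "adapter_config.json", "w") as fid:
    json.dump(config, fid, indent=4)
```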
@@ -1,4 +1,7 @@
-import os
+# Copyright © 2024 Apple Inc.
+import json
+import types
+from pathlib import Path
 from typing import Dict
 
 import mlx.core as mx
@@ -91,40 +94,28 @@ def linear_to_lora_layers(
         raise ValueError(f"Lora does not support {model.model_type}")
 
     for l in model.layers[num_layers - num_lora_layers :]:
-        modules = l.named_modules()
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         l.update_modules(tree_unflatten(lora_layers))
 
 
-def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
+def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
     """
     Apply LoRA layers to the model.
 
     Args:
         model (nn.Module): The neural network model.
-        adapter_file (str): Path to the adapter configuration file.
+        adapter_path (str): Path to the adapter configuration file.
 
     Returns:
         nn.Module: The updated model with LoRA layers applied.
     """
-    if not os.path.exists(adapter_file):
-        raise FileNotFoundError(f"The adapter file does not exist: {adapter_file}")
-
-    adapters = list(mx.load(adapter_file).items())
-
-    linear_replacements = []
-    lora_layers = set(
-        [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
-    )
-    for name, module in model.named_modules():
-        if name in lora_layers:
-            replacement_module = LoRALinear.from_linear(module)
-            linear_replacements.append((name, replacement_module))
-
-    model.update_modules(tree_unflatten(linear_replacements))
-
-    model.update(tree_unflatten(adapters))
-
+    adapter_path = Path(adapter_path)
+    if not adapter_path.exists():
+        raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
+    with open(adapter_path / "adapter_config.json", "r") as fid:
+        config = types.SimpleNamespace(**json.load(fid))
+    linear_to_lora_layers(model, config.lora_layers, config.lora_parameters)
+    model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
     return model
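The rewritten `apply_lora_layers` can also be driven without going through `load`; a small sketch, assuming the `get_model_path`/`load_model` helpers from `mlx_lm.utils` and placeholder paths:

```python
# Sketch: apply a trained adapter directory to an already-loaded model.
from mlx_lm.tuner.utils import apply_lora_layers
from mlx_lm.utils import get_model_path, load_model

model = load_model(get_model_path("mistralai/Mistral-7B-v0.1"))
model.freeze()  # mirror fuse.py: freeze base weights before applying adapters
model = apply_lora_layers(model, "adapters")
model.eval()
```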
@@ -9,7 +9,7 @@
 import shutil
 import time
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -354,7 +354,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
-    adapter_file: Optional[str] = None,
+    adapter_path: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
@@ -364,8 +364,8 @@ def load(
         path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
-        adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
-            Defaults to None.
+        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+            to the model. Default: ``None``.
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
@@ -379,8 +379,8 @@ def load(
     model_path = get_model_path(path_or_hf_repo)
 
     model = load_model(model_path, lazy)
-    if adapter_file is not None:
-        model = apply_lora_layers(model, adapter_file)
+    if adapter_path is not None:
+        model = apply_lora_layers(model, adapter_path)
         model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.4.0"
+__version__ = "0.6.0"