Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
parent d661440dbb
commit 2bd64b78cf
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -65,11 +65,11 @@ mistralai/Mistral-7B-v0.1`.
 If `--model` points to a quantized model, then the training will use QLoRA,
 otherwise it will use regular LoRA.
 
-By default, the adapter weights are saved in `adapters.npz`. You can specify
-the output location with `--adapter-file`.
+By default, the adapter config and weights are saved in `adapters/`. You can
+specify the output location with `--adapter-path`.
 
 You can resume fine-tuning with an existing adapter with
-`--resume-adapter-file <path_to_adapters.npz>`.
+`--resume-adapter-file <path_to_adapters.safetensors>`.
 
 ### Evaluate
 
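
Taken together, training and then resuming under the new layout looks roughly like this (a sketch; the data path and model name are placeholders, and the flags are the ones documented above):

```shell
# writes adapters/adapter_config.json and adapters/adapters.safetensors
python -m mlx_lm.lora \
    --model mistralai/Mistral-7B-v0.1 \
    --train \
    --data <path_to_data> \
    --adapter-path adapters

# resume from the previously saved weights
python -m mlx_lm.lora \
    --model mistralai/Mistral-7B-v0.1 \
    --train \
    --data <path_to_data> \
    --resume-adapter-file adapters/adapters.safetensors
```
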
@@ -78,7 +78,7 @@ To compute test set perplexity use:
 ```shell
 python -m mlx_lm.lora \
     --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
+    --adapter-path <path_to_adapters> \
     --data <path_to_data> \
     --test
 ```
@@ -90,7 +90,7 @@ For generation use `mlx_lm.generate`:
 ```shell
 python -m mlx_lm.generate \
     --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
+    --adapter-path <path_to_adapters> \
     --prompt "<your_model_prompt>"
 ```
 
@@ -115,7 +115,7 @@ To generate the fused model run:
 python -m mlx_lm.fuse --model <path_to_model>
 ```
 
-This will by default load the adapters from `adapters.npz`, and save the fused
+This will by default load the adapters from `adapters/`, and save the fused
 model in the path `lora_fused_model/`. All of these are configurable.
 
 To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
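
A fuse-and-upload invocation under the new defaults might therefore look like the following sketch (the repo names are placeholders, and only flags named in this diff are used):

```shell
python -m mlx_lm.fuse \
    --model mistralai/Mistral-7B-v0.1 \
    --adapter-path adapters \
    --upload-repo <your_username>/<repo_name> \
    --hf-path mistralai/Mistral-7B-v0.1
```
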
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -1,6 +1,5 @@
 import argparse
 import glob
-import json
 import shutil
 from pathlib import Path
 
@@ -31,10 +30,10 @@ def parse_arguments() -> argparse.Namespace:
         help="The path to save the fused model.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        default="adapters.npz",
-        help="Path to the trained adapter weights (npz or safetensors).",
+        default="adapters",
+        help="Path to the trained adapter weights and config.",
     )
     parser.add_argument(
         "--hf-path",
@@ -75,7 +74,7 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
 
     model.freeze()
-    model = apply_lora_layers(model, args.adapter_file)
+    model = apply_lora_layers(model, args.adapter_path)
 
     fused_linears = [
         (n, m.to_linear())
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -24,9 +24,9 @@ def setup_arg_parser():
         help="The path to the local model directory or Hugging Face repo.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Optional path for the trained adapter weights.",
+        help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--trust-remote-code",
@@ -110,7 +110,7 @@ def main(args):
         tokenizer_config["eos_token"] = args.eos_token
 
     model, tokenizer = load(
-        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
     if args.use_default_chat_template:
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -1,7 +1,6 @@
 # Copyright © 2024 Apple Inc.
 
 import argparse
-import json
 import math
 import re
 import types
@@ -16,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import build_schedule, linear_to_lora_layers
-from .utils import load
+from .utils import load, save_config
 
 yaml_loader = yaml.SafeLoader
 yaml_loader.add_implicit_resolver(
@@ -48,7 +47,7 @@ CONFIG_DEFAULTS = {
     "steps_per_report": 10,
     "steps_per_eval": 200,
     "resume_adapter_file": None,
-    "adapter_file": "adapters.npz",
+    "adapter_path": "adapters",
     "save_every": 100,
     "test": False,
     "test_batches": 500,
@@ -102,12 +101,12 @@ def build_parser():
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        help="Load path to resume training with the given adapter weights.",
+        help="Load path to resume training with the given adapters.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Save/load path for the trained adapter weights.",
+        help="Save/load path for the adapters.",
     )
     parser.add_argument(
         "--save-every",
@@ -184,6 +183,11 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
+    adapter_path = Path(args.adapter_path)
+    adapter_path.mkdir(parents=True, exist_ok=True)
+    save_config(vars(args), adapter_path / "adapter_config.json")
+    adapter_file = adapter_path / "adapters.safetensors"
+
     if args.train:
         print("Training")
         # init training args
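
The block added to `run` above gives each training run one self-contained directory. An illustrative layout (the zero-padded checkpoint name comes from the trainer change below; iteration counts are examples):

```
adapters/
├── adapter_config.json            # written by save_config(vars(args), ...)
├── 0000100_adapters.safetensors   # periodic checkpoint, every --save-every iters
└── adapters.safetensors           # latest weights, overwritten at each save
```
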
@@ -194,7 +198,7 @@ def run(args, training_callback: TrainingCallback = None):
         steps_per_report=args.steps_per_report,
         steps_per_eval=args.steps_per_eval,
         steps_per_save=args.save_every,
-        adapter_file=args.adapter_file,
+        adapter_file=adapter_file,
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint,
     )
@@ -219,12 +223,12 @@ def run(args, training_callback: TrainingCallback = None):
         )
 
     # Load the LoRA adapter weights which we assume should exist by this point
-    if not Path(args.adapter_file).is_file():
+    if not adapter_file.is_file():
         raise ValueError(
-            f"Adapter file {args.adapter_file} missing. "
-            "Use --train to learn and save the adapters.npz."
+            f"Adapter file {adapter_file} missing. "
+            "Use --train to learn and save the adapters"
         )
-    model.load_weights(args.adapter_file, strict=False)
+    model.load_weights(str(adapter_file), strict=False)
 
     if args.test:
         print("Testing")
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -2,7 +2,6 @@
 
 import argparse
 import glob
-import json
 import shutil
 from pathlib import Path
 from typing import Optional
@@ -418,9 +418,9 @@ if __name__ == "__main__":
         help="The path to the MLX model weights, tokenizer, and config",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Optional path for the trained adapter weights.",
+        help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--host",
@@ -445,7 +445,7 @@ if __name__ == "__main__":
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
 
     MODEL, TOKENIZER = load(
        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
     run(args.host, args.port)
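
Since the flag is renamed consistently across the CLIs, serving an adapted model is, schematically:

```shell
python -m mlx_lm.server --model <path_to_model> --adapter-path adapters
```
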
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -4,6 +4,7 @@ import time
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
+from typing import Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -54,7 +55,7 @@ class TrainingArgs:
         default=2048, metadata={"help": "Maximum sequence length."}
     )
     adapter_file: str = field(
-        default="adapter.npz",
+        default="adapters.safetensors",
         metadata={"help": "Save/load path for the trained adapter weights."},
     )
     grad_checkpoint: bool = field(
@@ -172,18 +173,6 @@ def train(
 ):
     print(f"Starting training..., iters: {args.iters}")
 
-    def checkpoints_path(adapter_file) -> str:
-        checkpoints_path = Path("checkpoints")
-        if Path(adapter_file).parent:
-            checkpoints_path = Path(adapter_file).parent / "checkpoints"
-
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        return str(checkpoints_path)
-
-    # Create checkpoints directory if it does not exist
-    adapter_path = checkpoints_path(args.adapter_file)
-
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])
 
@@ -206,7 +195,7 @@ def train(
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
-        range(args.iters),
+        range(1, args.iters + 1),
         iterate_batches(
             dataset=train_dataset,
             tokenizer=tokenizer,
@@ -223,7 +212,7 @@ def train(
         n_tokens += toks.item()
 
         # Report training loss if needed
-        if ((it + 1) % args.steps_per_report == 0) or (it + 1 == args.iters):
+        if it % args.steps_per_report == 0 or it == args.iters:
             train_loss = np.mean(losses)
 
             stop = time.perf_counter()
@@ -233,7 +222,7 @@ def train(
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 2**30
             print(
-                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
+                f"Iter {it}: Train loss {train_loss:.3f}, "
                 f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
@@ -243,7 +232,7 @@ def train(
 
             if training_callback is not None:
                 train_info = {
-                    "iteration": it + 1,
+                    "iteration": it,
                     "train_loss": train_loss,
                     "learning_rate": learning_rate,
                     "iterations_per_second": it_sec,
@@ -258,7 +247,7 @@ def train(
             start = time.perf_counter()
 
         # Report validation loss if needed
-        if it == 0 or ((it + 1) % args.steps_per_eval == 0) or (it + 1 == args.iters):
+        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
             val_loss = evaluate(
                 model=model,
@@ -272,14 +261,12 @@ def train(
             )
             val_time = time.perf_counter() - stop
             print(
-                f"Iter {it + 1}: "
-                f"Val loss {val_loss:.3f}, "
-                f"Val took {val_time:.3f}s"
+                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
             )
 
             if training_callback is not None:
                 val_info = {
-                    "iteration": it + 1,
+                    "iteration": it,
                     "val_loss": val_loss,
                     "val_time": val_time,
                 }
@@ -287,23 +274,26 @@ def train(
 
             start = time.perf_counter()
 
-        # Save adapter weights if needed
-        if (it + 1) % args.steps_per_save == 0:
-            checkpoint_adapter_file = (
-                f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}"
+        # Save adapter weights
+        if it % args.steps_per_save == 0:
+            save_adapter(model, args.adapter_file)
+            checkpoint = (
+                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
+            )
+            save_adapter(model, checkpoint)
+            print(
+                f"Iter {it}: Saved adapter weights to "
+                f"{args.adapter_file} and {checkpoint}."
             )
-            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
-            print(f"Iter {it + 1}: Saved adapter weights to {checkpoint_adapter_file}.")
-
 
     # save final adapter weights
-    save_adapter(model=model, adapter_file=args.adapter_file)
+    save_adapter(model, args.adapter_file)
     print(f"Saved final adapter weights to {args.adapter_file}.")
 
 
 def save_adapter(
     model: nn.Module,
-    adapter_file: str,
+    adapter_file: Union[str, Path],
 ):
     flattened_tree = tree_flatten(model.trainable_parameters())
-    mx.savez(adapter_file, **dict(flattened_tree))
+    mx.save_safetensors(str(adapter_file), dict(flattened_tree))
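
For readers unfamiliar with the format switch, a minimal sketch of the round trip that `save_adapter` and `load_weights(..., strict=False)` now rely on (the parameter name and shape are illustrative):

```python
import mlx.core as mx

# save_adapter flattens the trainable parameters into a flat
# {name: array} dict and writes it with mx.save_safetensors
weights = {"layers.0.self_attn.q_proj.lora_a": mx.zeros((4096, 8))}
mx.save_safetensors("adapters/adapters.safetensors", weights)

# mx.load reads it back as the same flat dict, which
# model.load_weights(..., strict=False) can consume directly
restored = mx.load("adapters/adapters.safetensors")
```
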
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -1,4 +1,7 @@
-import os
+# Copyright © 2024 Apple Inc.
+import json
+import types
+from pathlib import Path
 from typing import Dict
 
 import mlx.core as mx
@@ -91,40 +94,28 @@ def linear_to_lora_layers(
         raise ValueError(f"Lora does not support {model.model_type}")
 
     for l in model.layers[num_layers - num_lora_layers :]:
-        modules = l.named_modules()
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         l.update_modules(tree_unflatten(lora_layers))
 
 
-def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
+def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
     """
     Apply LoRA layers to the model.
 
     Args:
         model (nn.Module): The neural network model.
-        adapter_file (str): Path to the adapter configuration file.
+        adapter_path (str): Path to the adapter configuration file.
 
     Returns:
         nn.Module: The updated model with LoRA layers applied.
     """
-    if not os.path.exists(adapter_file):
-        raise FileNotFoundError(f"The adapter file does not exist: {adapter_file}")
-
-    adapters = list(mx.load(adapter_file).items())
-
-    linear_replacements = []
-    lora_layers = set(
-        [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
-    )
-    for name, module in model.named_modules():
-        if name in lora_layers:
-            replacement_module = LoRALinear.from_linear(module)
-            linear_replacements.append((name, replacement_module))
-
-    model.update_modules(tree_unflatten(linear_replacements))
-
-    model.update(tree_unflatten(adapters))
-
+    adapter_path = Path(adapter_path)
+    if not adapter_path.exists():
+        raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
+    with open(adapter_path / "adapter_config.json", "r") as fid:
+        config = types.SimpleNamespace(**json.load(fid))
+    linear_to_lora_layers(model, config.lora_layers, config.lora_parameters)
+    model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
    return model
 
 
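
Because `save_config(vars(args), ...)` dumps the whole argument namespace, `adapter_config.json` is a flat JSON object. A trimmed, illustrative example follows; only `lora_layers` and `lora_parameters` are read back by `apply_lora_layers`, and the keys inside `lora_parameters` are assumptions rather than something shown in this diff:

```json
{
  "model": "mistralai/Mistral-7B-v0.1",
  "lora_layers": 16,
  "lora_parameters": {"rank": 8, "alpha": 16.0, "dropout": 0.0, "scale": 10.0},
  "learning_rate": 1e-05,
  "iters": 1000
}
```
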
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -9,7 +9,7 @@ import shutil
 import time
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -354,7 +354,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
-    adapter_file: Optional[str] = None,
+    adapter_path: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
@@ -364,8 +364,8 @@ def load(
         path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
-        adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
-            Defaults to None.
+        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+            to the model. Default: ``None``.
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
@@ -379,8 +379,8 @@ def load(
     model_path = get_model_path(path_or_hf_repo)
 
     model = load_model(model_path, lazy)
-    if adapter_file is not None:
-        model = apply_lora_layers(model, adapter_file)
+    if adapter_path is not None:
+        model = apply_lora_layers(model, adapter_path)
         model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
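
On the Python side the renamed keyword threads straight through `load`; a minimal sketch (the model id and adapter directory are illustrative):

```python
from mlx_lm import load

# apply_lora_layers reads adapter_config.json and adapters.safetensors
# from the given directory before the weights are evaluated
model, tokenizer = load(
    "mistralai/Mistral-7B-v0.1",
    adapter_path="adapters",
)
```
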
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.4.0"
+__version__ = "0.6.0"