Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Adding full finetuning (#903)
* Adding full model weights fine-tuning
* Updating the LORA.md and ACKNOWLEDGMENTS.md files
* Removing --use-dora and --full-training and adding --fine-tune-type
* Some clean up
* Reformatting and fixing DoRA training
* Updated CONFIG_DEFAULTS
* Update config example
* Update in the config example file
* Update LORA.md
* Merge and commit
* Adding argument for DoRA linear layer
* Clean up
* Clean up in the example yaml file
* Fix
* Final fix before sending
* Small addition to the md file
* Fix for loading the fully trained model by saving all the files and configs correctly
* Clean up
* Removing the unnecessary files
* Changing lora layers back to 16
* Removed max file size
* Nits
* Resolve merge
* Some consistency changes

Co-authored-by: Awni Hannun <awni@apple.com>
Parent: 7ec2021bb9
Commit: 50e5ca81a8
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
@@ -16,7 +16,7 @@ conda install -c conda-forge mlx-lm

 The `mlx-lm` package also has:

-- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
+- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
 - [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
 - [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)

@@ -57,6 +57,9 @@ mlx_lm.lora \
     --iters 600
 ```

+To fine-tune the full model weights, add the `--fine-tune-type full` flag.
+Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
+
 The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
 when using `--train` and a path to a `test.jsonl` when using `--test`. For more
 details on the data format see the section on [Data](#Data).
@@ -67,8 +70,8 @@ mistralai/Mistral-7B-v0.1`.
 If `--model` points to a quantized model, then the training will use QLoRA,
 otherwise it will use regular LoRA.

-By default, the adapter config and weights are saved in `adapters/`. You can
-specify the output location with `--adapter-path`.
+By default, the adapter config and learned weights are saved in `adapters/`.
+You can specify the output location with `--adapter-path`.

 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file <path_to_adapters.safetensors>`.
@@ -118,7 +121,7 @@ mlx_lm.fuse --model <path_to_model>
 ```

 This will by default load the adapters from `adapters/`, and save the fused
-model in the path `lora_fused_model/`. All of these are configurable.
+model in the path `fused_model/`. All of these are configurable.

 To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
 to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
@@ -141,7 +144,7 @@ mlx_lm.fuse \
     --export-gguf
 ```

-This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
+This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
 can specify the file name with `--gguf-path`.

 ## Data
@@ -301,7 +304,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:
    setting this to `2` or `1` will reduce memory consumption. This may slow
    things down a little, but will also reduce the memory use.

-3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
+3. Reduce the number of layers to fine-tune with `--num-layers`. The default
    is `16`, so you can try `8` or `4`. This reduces the amount of memory
    needed for back propagation. It may also reduce the quality of the
    fine-tuned model if you are fine-tuning with a lot of data.
@@ -323,7 +326,7 @@ mlx_lm.lora \
     --model mistralai/Mistral-7B-v0.1 \
     --train \
     --batch-size 1 \
-    --lora-layers 4 \
+    --num-layers 4 \
     --data wikisql
 ```

@@ -333,4 +336,5 @@ tokens-per-second, using the MLX Example
 data set.

 [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
+
 [^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
@@ -1,8 +1,12 @@
 # The path to the local model directory or Hugging Face repo.
 model: "mlx_model"

 # Whether or not to train (boolean)
 train: true

+# The fine-tuning method: "lora", "dora", or "full".
+fine_tune_type: lora
+
 # Directory with {train, valid, test}.jsonl files
 data: "/path/to/training/data"
@@ -51,9 +55,6 @@ max_seq_length: 2048
 # Use gradient checkpointing to reduce memory use.
 grad_checkpoint: false

-# Use DoRA instead of LoRA.
-use_dora: false
-
 # LoRA parameters can only be specified in a config file
 lora_parameters:
   # The layer keys to apply LoRA to.
@@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten
 from .gguf import convert_to_gguf
 from .tuner.dora import DoRAEmbedding, DoRALinear
 from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
-from .tuner.utils import apply_lora_layers, dequantize
+from .tuner.utils import dequantize, load_adapters
 from .utils import (
     fetch_from_hub,
     get_model_path,
@@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
     )
     parser.add_argument(
         "--save-path",
-        default="lora_fused_model",
+        default="fused_model",
         help="The path to save the fused model.",
     )
     parser.add_argument(
@@ -77,16 +77,13 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)

     model.freeze()
-    model = apply_lora_layers(model, args.adapter_path)
+    model = load_adapters(model, args.adapter_path)

     fused_linears = [
-        (n, m.fuse())
-        for n, m in model.named_modules()
-        if isinstance(
-            m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
-        )
+        (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
     ]

-    model.update_modules(tree_unflatten(fused_linears))
+    if fused_linears:
+        model.update_modules(tree_unflatten(fused_linears))

     if args.de_quantize:
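The switch above from an `isinstance` check to `hasattr(m, "fuse")` is plain duck typing: anything that knows how to fold its adapter update into the base weights gets fused, without the fuse script having to enumerate every adapter class. A minimal, framework-free sketch of the pattern; `PlainLinear` and `TinyLoRALinear` are made-up stand-ins, not the repo's implementation:

```python
# Hypothetical stand-ins to illustrate the hasattr-based fusing pattern.
class PlainLinear:
    def __init__(self, weight):
        self.weight = weight


class TinyLoRALinear:
    """A toy adapter layer: a base weight plus a low-rank update."""

    def __init__(self, weight, delta):
        self.weight = weight
        self.delta = delta

    def fuse(self):
        # Fold the adapter update into the base weight and return a plain layer.
        return PlainLinear(self.weight + self.delta)


modules = [
    ("layers.0.proj", TinyLoRALinear(1.0, 0.25)),
    ("layers.1.proj", PlainLinear(2.0)),
]

# Same filter the fuse script now uses: fuse anything exposing a `fuse` method.
fused = [(name, m.fuse()) for name, m in modules if hasattr(m, "fuse")]
print([(name, m.weight) for name, m in fused])  # [('layers.0.proj', 1.25)]
```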
@@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import (
-    apply_lora_layers,
     build_schedule,
     linear_to_lora_layers,
+    load_adapters,
     print_trainable_parameters,
 )
 from .utils import load, save_config
@@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver(
 CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
+    "fine_tune_type": "lora",
     "data": "data/",
     "seed": 0,
-    "lora_layers": 16,
+    "num_layers": 16,
     "batch_size": 4,
     "iters": 1000,
     "val_batches": 25,
@@ -58,7 +59,6 @@ CONFIG_DEFAULTS = {
     "max_seq_length": 2048,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
-    "use_dora": False,
 }

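With `fine_tune_type` and `num_layers` now part of `CONFIG_DEFAULTS`, any key left unset on the command line or in the YAML config falls back to these values. A rough sketch of that precedence (CLI over YAML file over defaults); the `resolve_config` helper and the trimmed-down defaults dict are illustrative, not the actual merge code:

```python
import types

import yaml  # PyYAML; the training entry point already parses YAML configs

# Trimmed-down stand-in for the defaults shown in the hunk above.
CONFIG_DEFAULTS = {"model": "mlx_model", "fine_tune_type": "lora", "num_layers": 16}


def resolve_config(cli_args, config_path=None):
    """Hypothetical merge: CLI values win, then the YAML file, then the defaults."""
    file_config = {}
    if config_path is not None:
        with open(config_path) as fid:
            file_config = yaml.safe_load(fid) or {}

    merged = dict(CONFIG_DEFAULTS)
    merged.update(file_config)
    merged.update({k: v for k, v in cli_args.items() if v is not None})
    return types.SimpleNamespace(**merged)


# Only the fine-tuning type is overridden on the "command line" here.
args = resolve_config({"fine_tune_type": "full", "num_layers": None})
print(args.fine_tune_type, args.num_layers)  # full 16
```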
@@ -82,7 +82,14 @@ def build_parser():
         help="Directory with {train, valid, test}.jsonl files",
     )
     parser.add_argument(
-        "--lora-layers",
+        "--fine-tune-type",
+        type=str,
+        choices=["lora", "dora", "full"],
+        default="lora",
+        help="Type of fine-tuning to perform: lora, dora, or full.",
+    )
+    parser.add_argument(
+        "--num-layers",
         type=int,
         help="Number of layers to fine-tune. Default is 16, use -1 for all.",
     )
@@ -107,12 +114,12 @@ def build_parser():
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        help="Load path to resume training with the given adapters.",
+        help="Load path to resume training from the given fine-tuned weights.",
     )
     parser.add_argument(
         "--adapter-path",
         type=str,
-        help="Save/load path for the adapters.",
+        help="Save/load path for the fine-tuned weights.",
     )
     parser.add_argument(
         "--save-every",
@@ -148,9 +155,6 @@ def build_parser():
         default=None,
     )
     parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
-    parser.add_argument(
-        "--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
-    )
     return parser

@@ -162,21 +166,31 @@ def train_model(
     valid_set,
     training_callback: TrainingCallback = None,
 ):
-    # Freeze all layers
     model.freeze()
+    if args.fine_tune_type == "full":
+        for l in model.layers[-min(args.num_layers, 0) :]:
+            l.unfreeze()
+    elif args.fine_tune_type in ["lora", "dora"]:
+        # Convert linear layers to lora/dora layers and unfreeze in the process
+        linear_to_lora_layers(
+            model,
+            args.num_layers,
+            args.lora_parameters,
+            use_dora=(args.fine_tune_type == "dora"),
+        )
+    else:
+        raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")

-    # Convert linear layers to lora layers and unfreeze in the process
-    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora)
-
-    # Resume training the given adapters.
+    # Resume from weights if provided
     if args.resume_adapter_file is not None:
-        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+        print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)

     print_trainable_parameters(model)

     adapter_path = Path(args.adapter_path)
     adapter_path.mkdir(parents=True, exist_ok=True)

     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
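The branch above is a freeze-then-selectively-unfreeze pattern: everything is frozen first, then either whole transformer blocks are unfrozen (full fine-tuning) or linear layers are replaced by LoRA/DoRA wrappers that register their own trainable parameters. A minimal sketch of the full-fine-tuning side with a toy MLX model; `ToyModel`, its sizes, and the choice of two unfrozen blocks are made up for illustration:

```python
import mlx.nn as nn
from mlx.utils import tree_flatten


class ToyModel(nn.Module):
    """A stand-in for a transformer: just a list of blocks."""

    def __init__(self, dims=8, num_blocks=4):
        super().__init__()
        self.layers = [nn.Linear(dims, dims) for _ in range(num_blocks)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel()
model.freeze()  # start with nothing trainable

# "Full" fine-tuning of the last two blocks: unfreeze them wholesale.
num_layers = 2
for layer in model.layers[-num_layers:]:
    layer.unfreeze()

trainable = [name for name, _ in tree_flatten(model.trainable_parameters())]
print(trainable)  # only layers.2.* and layers.3.* remain trainable
```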
@@ -240,7 +254,7 @@ def run(args, training_callback: TrainingCallback = None):
     if args.test and not args.train:
         # Allow testing without LoRA layers by providing empty path
         if args.adapter_path != "":
-            apply_lora_layers(model, args.adapter_path)
+            load_adapters(model, args.adapter_path)

     elif args.train:
         print("Training")
@@ -1,5 +1,7 @@
 # Copyright © 2024 Apple Inc.

+import glob
+import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -285,24 +287,18 @@ def train(

         # Save adapter weights
         if it % args.steps_per_save == 0:
-            save_adapter(model, args.adapter_file)
+            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+            mx.save_safetensors(str(args.adapter_file), adapter_weights)
             checkpoint = (
                 Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
             )
-            save_adapter(model, checkpoint)
+            mx.save_safetensors(str(checkpoint), adapter_weights)
             print(
                 f"Iter {it}: Saved adapter weights to "
                 f"{args.adapter_file} and {checkpoint}."
             )

-    # save final adapter weights
-    save_adapter(model, args.adapter_file)
-    print(f"Saved final adapter weights to {args.adapter_file}.")
-
-
-def save_adapter(
-    model: nn.Module,
-    adapter_file: Union[str, Path],
-):
-    flattened_tree = tree_flatten(model.trainable_parameters())
-    mx.save_safetensors(str(adapter_file), dict(flattened_tree))
+    # Save final weights
+    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+    mx.save_safetensors(str(args.adapter_file), adapter_weights)
+    print(f"Saved final weights to {args.adapter_file}.")
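Inlining the old `save_adapter` helper makes the checkpointing recipe explicit: flatten only the trainable parameters into a flat dict, write it with `mx.save_safetensors`, and rely on `load_weights(..., strict=False)` to accept the partial file when resuming. A small self-contained sketch of that round trip; the toy model and the file name are placeholders:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
model.freeze()
model.layers[-1].unfreeze()  # pretend only the last layer was fine-tuned

# Save just the trainable subset, as the checkpointing code above does.
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors("adapters_example.safetensors", adapter_weights)

# Later: load the partial weights back; strict=False tolerates the missing keys.
model.load_weights("adapters_example.safetensors", strict=False)
```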
@@ -36,7 +36,7 @@ def build_schedule(schedule_config: Dict):

 def linear_to_lora_layers(
     model: nn.Module,
-    num_lora_layers: int,
+    num_layers: int,
     config: Dict,
     use_dora: bool = False,
 ):
@@ -45,22 +45,17 @@ def linear_to_lora_layers(

     Args:
         model (nn.Module): The neural network model.
-        num_lora_layers (int): The number of blocks to convert to lora layers
+        num_layers (int): The number of blocks to convert to lora layers
             starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
             rank, scale, and optional layer keys.
         use_dora (bool): If True, uses DoRA instead of LoRA.
             Default: ``False``
     """
-    num_layers = len(model.layers)
-
-    if num_lora_layers < 0:
-        num_lora_layers = num_layers
-
-    if num_lora_layers > num_layers:
+    if num_layers > len(model.layers):
         raise ValueError(
-            f"Requested {num_lora_layers} LoRA layers "
-            f"but the model only has {num_layers} layers."
+            f"Requested {num_layers} LoRA layers "
+            f"but the model only has {len(model.layers)} layers."
         )

     def to_lora(layer):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Lora does not support {model.model_type}")
|
raise ValueError(f"Lora does not support {model.model_type}")
|
||||||
|
|
||||||
for l in model.layers[num_layers - num_lora_layers :]:
|
for l in model.layers[-min(num_layers, 0) :]:
|
||||||
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
|
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
|
||||||
if lora_layers:
|
if lora_layers:
|
||||||
l.update_modules(tree_unflatten(lora_layers))
|
l.update_modules(tree_unflatten(lora_layers))
|
||||||
@@ -161,9 +156,9 @@ def linear_to_lora_layers(
     model.update_modules(tree_unflatten(lora_modules))


-def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
+def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
     """
-    Apply LoRA layers to the model.
+    Load any fine-tuned adapters / layers.

     Args:
         model (nn.Module): The neural network model.
|
|||||||
raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
|
raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
|
||||||
with open(adapter_path / "adapter_config.json", "r") as fid:
|
with open(adapter_path / "adapter_config.json", "r") as fid:
|
||||||
config = types.SimpleNamespace(**json.load(fid))
|
config = types.SimpleNamespace(**json.load(fid))
|
||||||
|
fine_tune_type = getattr(config, "fine_tune_type", "lora")
|
||||||
|
if fine_tune_type != "full":
|
||||||
linear_to_lora_layers(
|
linear_to_lora_layers(
|
||||||
model,
|
model,
|
||||||
config.lora_layers,
|
config.num_layers,
|
||||||
config.lora_parameters,
|
config.lora_parameters,
|
||||||
getattr(config, "use_dora", False),
|
use_dora=(fine_tune_type == "dora"),
|
||||||
)
|
)
|
||||||
model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
|
model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
|
||||||
return model
|
return model
|
||||||
|
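Reading `adapter_config.json` into a `SimpleNamespace` and falling back with `getattr` is what keeps adapter folders written before this change (which have no `fine_tune_type` key) loadable. A tiny sketch of that backward-compatibility pattern; the example config contents are invented:

```python
import json
import types

# An "old" adapter config that predates the fine_tune_type key.
old_config = json.loads('{"lora_parameters": {"rank": 8, "scale": 10.0}}')
config = types.SimpleNamespace(**old_config)

# getattr with a default keeps old configs working: missing key -> "lora".
fine_tune_type = getattr(config, "fine_tune_type", "lora")
print(fine_tune_type)  # lora
```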
@@ -21,8 +21,8 @@ from transformers import PreTrainedTokenizer
 from .models.base import KVCache, RotatingKVCache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
-from .tuner.utils import apply_lora_layers
 from .tuner.utils import dequantize as dequantize_model
+from .tuner.utils import load_adapters

 # Constants
 MODEL_REMAPPING = {
@@ -515,7 +515,7 @@ def load(

     model = load_model(model_path, lazy, model_config)
     if adapter_path is not None:
-        model = apply_lora_layers(model, adapter_path)
+        model = load_adapters(model, adapter_path)
         model.eval()
     tokenizer = load_tokenizer(model_path, tokenizer_config)
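For inference the rename is invisible: `mlx_lm.load` still accepts an `adapter_path` and now routes it through `load_adapters`, which handles LoRA, DoRA, and fully fine-tuned weights alike. A usage sketch, assuming `mlx-lm` is installed and an adapter directory produced by the trainer exists at `adapters/`; the repo name and prompt are placeholders:

```python
from mlx_lm import generate, load

# Load a base model and apply the fine-tuned weights saved in adapters/.
model, tokenizer = load(
    "mlx-community/Mistral-7B-v0.1-4bit",  # placeholder repo name
    adapter_path="adapters",
)

text = generate(model, tokenizer, prompt="Write a haiku about fine-tuning.")
print(text)
```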