mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
LoRA on all linear transformer block layers (#546)
* Add --lora-all-linear option to apply LoRA to all linear transformer block layers
* Moved to YAML config and added specification of rank & alpha
* Nits in config, more tests
* Nit
* Run tests for PRs
---------
Co-authored-by: Awni Hannun <awni@apple.com>
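The diff below is against the LoRA fine-tuning script: it adds a lora_parameters entry to CONFIG_DEFAULTS and passes it through to linear_to_lora_layers. As a minimal sketch of how those defaults might be overridden from a YAML training config, assuming PyYAML and a plain dict merge (the YAML contents, the merge step, and the override values are illustrative; only the rank/alpha/dropout/scale keys come from this commit):

# Sketch: overriding the new "lora_parameters" defaults from a YAML config.
# The YAML contents and the merge logic are assumptions for illustration;
# only the keys (rank, alpha, dropout, scale) appear in this commit.
import yaml  # PyYAML, assumed available as in the LoRA example

# Defaults added by this commit.
CONFIG_DEFAULTS = {
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}

# Hypothetical user config: only the values differ from the defaults above.
config_text = """
lora_parameters:
  rank: 16
  alpha: 32
  dropout: 0.05
  scale: 10.0
"""

args = dict(CONFIG_DEFAULTS)
args.update(yaml.safe_load(config_text))
print(args["lora_parameters"])
# -> {'rank': 16, 'alpha': 32, 'dropout': 0.05, 'scale': 10.0}

In the script itself these values end up in args.lora_parameters, which run() then forwards to linear_to_lora_layers, as the last hunk shows.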
@@ -1,3 +1,5 @@
+# Copyright © 2024 Apple Inc.
+
 import argparse
 import json
 import math
@@ -49,6 +51,7 @@ CONFIG_DEFAULTS = {
     "test": False,
     "test_batches": 500,
     "max_seq_length": 2048,
+    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
 
 
@@ -58,7 +61,6 @@ def build_parser():
         "--model",
         help="The path to the local model directory or Hugging Face repo.",
     )
     # Generation args
     parser.add_argument(
         "--max-tokens",
         "-m",
@@ -196,7 +198,7 @@ def run(args, training_callback: TrainingCallback = None):
     # Freeze all layers
     model.freeze()
     # Convert linear layers to lora layers and unfreeze in the process
-    linear_to_lora_layers(model, args.lora_layers)
+    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
 
     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")
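For orientation, a minimal end-to-end sketch of the updated call path in run(): freeze the base model, then hand the lora_parameters dict to linear_to_lora_layers along with the number of layers to adapt. The import paths, model repo, and layer count below are assumptions; only the three-argument call and the parameter-count report mirror the diff above.

# Sketch of the updated LoRA setup; import paths, the model repo name, and
# the layer count are assumptions, while the three-argument call to
# linear_to_lora_layers mirrors the change in this commit.
from mlx.utils import tree_flatten
from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")  # hypothetical repo

lora_parameters = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
lora_layers = 16  # number of transformer blocks to convert

model.freeze()  # freeze all base weights first
linear_to_lora_layers(model, lora_layers, lora_parameters)  # adds trainable LoRA weights

p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")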