mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 09:48:54 +08:00
LoRA: Extract small function (#614)
* LoRA: Extract pre_processing_model function
* LoRA: Extract small functions (train_model, evaluate_model)
* Move test case to test_tuner_utils.py
* nits
* nits
* Remove extra param, validate at iteration 0
* version
* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -7,7 +7,7 @@ from typing import Dict
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as opt
-from mlx.utils import tree_unflatten
+from mlx.utils import tree_flatten, tree_unflatten

 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 from .dora import DoRALinear
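Note: tree_flatten is imported here because the print_trainable_parameters helper added below walks the parameter tree. A minimal sketch of what these two utilities do, using a hypothetical toy tree (not from the repo):

import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten

# tree_flatten turns a nested structure of arrays into a flat list of
# ("dotted.path", value) pairs; tree_unflatten inverts the mapping.
params = {"layers": [{"w": mx.zeros((2, 2))}, {"w": mx.ones((2, 2))}]}
flat = tree_flatten(params)
print([k for k, _ in flat])  # ['layers.0.w', 'layers.1.w']
rebuilt = tree_unflatten(flat)  # back to the nested form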
@@ -48,7 +48,7 @@ def linear_to_lora_layers(
         num_lora_layers (int): The number of blocks to convert to lora layers
             starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
-            rank, alpha, scale, and optional layer keys.
+            rank, scale, and optional layer keys.
         use_dora (bool): If True, uses DoRA instead of LoRA.
             Default: ``False``
     """
@@ -79,7 +79,6 @@ def linear_to_lora_layers(
         return LoRALayer.from_linear(
             layer,
             r=config["rank"],
-            alpha=config["alpha"],
             scale=config["scale"],
             dropout=config["dropout"],
         )
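With alpha dropped from the call, the magnitude of the low-rank update is controlled by config["scale"] alone. A hedged sketch of a config matching the keys this call site now reads (the values are illustrative, not from the repo):

lora_config = {
    "rank": 8,       # low-rank dimension r
    "scale": 20.0,   # scaling applied to the low-rank update
    "dropout": 0.0,  # dropout on the LoRA branch
    # "keys": [...]  # optional: which projections to convert
}
# e.g. linear_to_lora_layers(model, num_lora_layers, lora_config)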
@@ -218,3 +217,22 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     if len(reset_layers) > 0:
         model.update_modules(tree_unflatten(reset_layers))
     return model
+
+
+def print_trainable_parameters(model):
+    def nparams(m):
+        if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
+            return m.weight.size * (32 // m.bits)
+        return sum(v.size for _, v in tree_flatten(m.parameters()))
+
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
+    trainable_p = (
+        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+    )
+    print(
+        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
+        f"({trainable_p:.3f}M/{total_p:.3f}M)"
+    )
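A possible end-to-end use of the new helper, assuming it ships in mlx_lm.tuner.utils (consistent with the test_tuner_utils.py mentioned in the commit message); the model name and config values are illustrative, not from this commit:

from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers, print_trainable_parameters

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
model.freeze()  # freeze the base weights; the LoRA layers added next stay trainable
linear_to_lora_layers(model, 8, {"rank": 8, "scale": 20.0, "dropout": 0.0})
print_trainable_parameters(model)  # prints "Trainable parameters: <pct>% (<trainable>M/<total>M)"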