mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
LoRA: Extract small function (#614)
* LoRA: Extract pre_processing_model function * LoRA: Extract small functions(train_model,evaluate_model) * move test case to test_tuner_utils.py * nits * nits * remove extra param, validate at it 0 * version * fix test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
from .trainer import TrainingArgs, evaluate, train
|
||||
from .utils import linear_to_lora_layers
|
||||
|
@@ -11,9 +11,8 @@ class DoRALinear(nn.Module):
|
||||
def from_linear(
|
||||
linear: nn.Linear,
|
||||
r: int = 8,
|
||||
alpha: float = 16,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 10.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
# TODO support quantized weights in DoRALinear
|
||||
output_dims, input_dims = linear.weight.shape
|
||||
@@ -23,7 +22,6 @@ class DoRALinear(nn.Module):
|
||||
input_dims=input_dims,
|
||||
output_dims=output_dims,
|
||||
r=r,
|
||||
alpha=alpha,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
@@ -56,9 +54,8 @@ class DoRALinear(nn.Module):
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
r: int = 8,
|
||||
alpha: float = 16,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 10.0,
|
||||
scale: float = 20.0,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -68,7 +65,7 @@ class DoRALinear(nn.Module):
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale * (alpha / r)
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
|
@@ -13,9 +13,8 @@ class LoRALinear(nn.Module):
|
||||
def from_linear(
|
||||
linear: nn.Linear,
|
||||
r: int = 8,
|
||||
alpha: float = 16,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 10.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
# TODO remove when input_dims and output_dims are attributes
|
||||
# on linear and quantized linear
|
||||
@@ -26,7 +25,6 @@ class LoRALinear(nn.Module):
|
||||
input_dims=input_dims,
|
||||
output_dims=output_dims,
|
||||
r=r,
|
||||
alpha=alpha,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
@@ -74,9 +72,8 @@ class LoRALinear(nn.Module):
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
r: int = 8,
|
||||
alpha: float = 16,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 10.0,
|
||||
scale: float = 20.0,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -87,7 +84,7 @@ class LoRALinear(nn.Module):
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale * (alpha / r)
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
@@ -109,16 +106,14 @@ class LoRASwitchLinear(nn.Module):
|
||||
def from_linear(
|
||||
linear: nn.Module,
|
||||
r: int = 8,
|
||||
alpha: float = 16,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 10.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
lora_lin = LoRASwitchLinear(
|
||||
input_dims=linear.input_dims,
|
||||
output_dims=linear.output_dims,
|
||||
num_experts=linear.num_experts,
|
||||
r=r,
|
||||
alpha=alpha,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
@@ -163,9 +158,8 @@ class LoRASwitchLinear(nn.Module):
|
||||
output_dims: int,
|
||||
num_experts: int,
|
||||
r: int = 8,
|
||||
alpha: float = 16,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 10.0,
|
||||
scale: float = 20.0,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -176,7 +170,7 @@ class LoRASwitchLinear(nn.Module):
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale * (alpha / r)
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
|
@@ -29,9 +29,6 @@ def grad_checkpoint(layer):
|
||||
|
||||
@dataclass
|
||||
class TrainingArgs:
|
||||
lora_layers: int = field(
|
||||
default=16, metadata={"help": "Number of layers to fine-tune"}
|
||||
)
|
||||
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
|
||||
iters: int = field(default=100, metadata={"help": "Iterations to train for."})
|
||||
val_batches: int = field(
|
||||
@@ -211,6 +208,35 @@ def train(
|
||||
train=True,
|
||||
),
|
||||
):
|
||||
# Report validation loss if needed, the first validation loss
|
||||
# is always measured before any training.
|
||||
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
val_loss = evaluate(
|
||||
model=model,
|
||||
dataset=val_dataset,
|
||||
loss=loss,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.val_batches,
|
||||
max_seq_length=args.max_seq_length,
|
||||
iterate_batches=iterate_batches,
|
||||
)
|
||||
val_time = time.perf_counter() - stop
|
||||
print(
|
||||
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
val_info = {
|
||||
"iteration": it,
|
||||
"val_loss": val_loss,
|
||||
"val_time": val_time,
|
||||
}
|
||||
training_callback.on_val_loss_report(val_info)
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
lvalue, toks = step(batch)
|
||||
mx.eval(state, lvalue, toks)
|
||||
|
||||
@@ -220,9 +246,9 @@ def train(
|
||||
|
||||
# Report training loss if needed
|
||||
if it % args.steps_per_report == 0 or it == args.iters:
|
||||
train_loss = np.mean(losses)
|
||||
|
||||
stop = time.perf_counter()
|
||||
|
||||
train_loss = np.mean(losses)
|
||||
learning_rate = optimizer.learning_rate.item()
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
@@ -253,34 +279,6 @@ def train(
|
||||
n_tokens = 0
|
||||
start = time.perf_counter()
|
||||
|
||||
# Report validation loss if needed
|
||||
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
val_loss = evaluate(
|
||||
model=model,
|
||||
dataset=val_dataset,
|
||||
loss=loss,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.val_batches,
|
||||
max_seq_length=args.max_seq_length,
|
||||
iterate_batches=iterate_batches,
|
||||
)
|
||||
val_time = time.perf_counter() - stop
|
||||
print(
|
||||
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
val_info = {
|
||||
"iteration": it,
|
||||
"val_loss": val_loss,
|
||||
"val_time": val_time,
|
||||
}
|
||||
training_callback.on_val_loss_report(val_info)
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights
|
||||
if it % args.steps_per_save == 0:
|
||||
save_adapter(model, args.adapter_file)
|
||||
|
@@ -7,7 +7,7 @@ from typing import Dict
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as opt
|
||||
from mlx.utils import tree_unflatten
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
|
||||
from .dora import DoRALinear
|
||||
@@ -48,7 +48,7 @@ def linear_to_lora_layers(
|
||||
num_lora_layers (int): The number of blocks to convert to lora layers
|
||||
starting from the last layer.
|
||||
config (dict): More configuration parameters for LoRA, including the
|
||||
rank, alpha, scale, and optional layer keys.
|
||||
rank, scale, and optional layer keys.
|
||||
use_dora (bool): If True, uses DoRA instead of LoRA.
|
||||
Default: ``False``
|
||||
"""
|
||||
@@ -79,7 +79,6 @@ def linear_to_lora_layers(
|
||||
return LoRALayer.from_linear(
|
||||
layer,
|
||||
r=config["rank"],
|
||||
alpha=config["alpha"],
|
||||
scale=config["scale"],
|
||||
dropout=config["dropout"],
|
||||
)
|
||||
@@ -218,3 +217,22 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
|
||||
if len(reset_layers) > 0:
|
||||
model.update_modules(tree_unflatten(reset_layers))
|
||||
return model
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
def nparams(m):
|
||||
if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
|
||||
return m.weight.size * (32 // m.bits)
|
||||
return sum(v.size for _, v in tree_flatten(m.parameters()))
|
||||
|
||||
leaf_modules = tree_flatten(
|
||||
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
|
||||
)
|
||||
total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
|
||||
trainable_p = (
|
||||
sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
|
||||
)
|
||||
print(
|
||||
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
|
||||
f"({trainable_p:.3f}M/{total_p:.3f}M)"
|
||||
)
|
||||
|
Reference in New Issue
Block a user