LoRA: Extract small function (#614)

* LoRA: Extract pre_processing_model function

* LoRA: Extract small functions (train_model, evaluate_model)

* move test case to test_tuner_utils.py

* nits

* nits

* remove extra param, validate at it 0

* version

* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
madroid
2024-06-02 21:38:42 +08:00
committed by GitHub
parent 81318ad4a8
commit c457a3f88b
10 changed files with 232 additions and 206 deletions

View File

@@ -0,0 +1,2 @@
from .trainer import TrainingArgs, evaluate, train
from .utils import linear_to_lora_layers

View File

@@ -11,9 +11,8 @@ class DoRALinear(nn.Module):
def from_linear(
linear: nn.Linear,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
):
# TODO support quantized weights in DoRALinear
output_dims, input_dims = linear.weight.shape
@@ -23,7 +22,6 @@ class DoRALinear(nn.Module):
input_dims=input_dims,
output_dims=output_dims,
r=r,
alpha=alpha,
dropout=dropout,
scale=scale,
)
@@ -56,9 +54,8 @@ class DoRALinear(nn.Module):
input_dims: int,
output_dims: int,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
bias: bool = False,
):
super().__init__()
@@ -68,7 +65,7 @@ class DoRALinear(nn.Module):
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale * (alpha / r)
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)

View File

@@ -13,9 +13,8 @@ class LoRALinear(nn.Module):
def from_linear(
linear: nn.Linear,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
):
# TODO remove when input_dims and output_dims are attributes
# on linear and quantized linear
@@ -26,7 +25,6 @@ class LoRALinear(nn.Module):
input_dims=input_dims,
output_dims=output_dims,
r=r,
alpha=alpha,
dropout=dropout,
scale=scale,
)
@@ -74,9 +72,8 @@ class LoRALinear(nn.Module):
input_dims: int,
output_dims: int,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
bias: bool = False,
):
super().__init__()
@@ -87,7 +84,7 @@ class LoRALinear(nn.Module):
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale * (alpha / r)
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
@@ -109,16 +106,14 @@ class LoRASwitchLinear(nn.Module):
def from_linear(
linear: nn.Module,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
):
lora_lin = LoRASwitchLinear(
input_dims=linear.input_dims,
output_dims=linear.output_dims,
num_experts=linear.num_experts,
r=r,
alpha=alpha,
dropout=dropout,
scale=scale,
)
@@ -163,9 +158,8 @@ class LoRASwitchLinear(nn.Module):
output_dims: int,
num_experts: int,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
bias: bool = False,
):
super().__init__()
@@ -176,7 +170,7 @@ class LoRASwitchLinear(nn.Module):
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale * (alpha / r)
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)

View File

@@ -29,9 +29,6 @@ def grad_checkpoint(layer):
@dataclass
class TrainingArgs:
lora_layers: int = field(
default=16, metadata={"help": "Number of layers to fine-tune"}
)
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
iters: int = field(default=100, metadata={"help": "Iterations to train for."})
val_batches: int = field(
@@ -211,6 +208,35 @@ def train(
train=True,
),
):
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
loss=loss,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
print(
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
)
if training_callback is not None:
val_info = {
"iteration": it,
"val_loss": val_loss,
"val_time": val_time,
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
lvalue, toks = step(batch)
mx.eval(state, lvalue, toks)
@@ -220,9 +246,9 @@ def train(
# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
train_loss = np.mean(losses)
stop = time.perf_counter()
train_loss = np.mean(losses)
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
@@ -253,34 +279,6 @@ def train(
n_tokens = 0
start = time.perf_counter()
# Report validation loss if needed
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
loss=loss,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
print(
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
)
if training_callback is not None:
val_info = {
"iteration": it,
"val_loss": val_loss,
"val_time": val_time,
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
# Save adapter weights
if it % args.steps_per_save == 0:
save_adapter(model, args.adapter_file)

View File

@@ -7,7 +7,7 @@ from typing import Dict
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_unflatten
from mlx.utils import tree_flatten, tree_unflatten
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
from .dora import DoRALinear
@@ -48,7 +48,7 @@ def linear_to_lora_layers(
num_lora_layers (int): The number of blocks to convert to lora layers
starting from the last layer.
config (dict): More configuration parameters for LoRA, including the
rank, alpha, scale, and optional layer keys.
rank, scale, and optional layer keys.
use_dora (bool): If True, uses DoRA instead of LoRA.
Default: ``False``
"""
@@ -79,7 +79,6 @@ def linear_to_lora_layers(
return LoRALayer.from_linear(
layer,
r=config["rank"],
alpha=config["alpha"],
scale=config["scale"],
dropout=config["dropout"],
)
@@ -218,3 +217,22 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
if len(reset_layers) > 0:
model.update_modules(tree_unflatten(reset_layers))
return model
def print_trainable_parameters(model):
    """Print the share of the model's parameters that are trainable.

    Writes a single line to stdout of the form
    ``Trainable parameters: <pct>% (<trainable>M/<total>M)``, where both
    counts are expressed in millions of parameters.

    Args:
        model: A module exposing ``leaf_modules()`` and
            ``trainable_parameters()``.

    Returns:
        None. The result is only printed.
    """

    def nparams(m):
        """Return the parameter count of a single leaf module."""
        if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
            # The stored weight is scaled by 32 / bits to recover the
            # logical element count (presumably the quantized values are
            # packed into 32-bit words -- confirm against the MLX docs).
            # Multiply before dividing so bit widths that do not divide 32
            # evenly are not truncated. Only ``weight`` is counted for
            # quantized layers.
            return m.weight.size * 32 // m.bits
        return sum(v.size for _, v in tree_flatten(m.parameters()))

    leaf_modules = tree_flatten(
        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
    )
    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
    trainable_p = (
        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    )
    # Guard against a model without any parameters so the report does not
    # fail with ZeroDivisionError.
    percent = trainable_p * 100 / total_p if total_p else 0.0
    print(
        f"Trainable parameters: {percent:.3f}% "
        f"({trainable_p:.3f}M/{total_p:.3f}M)"
    )