LoRA: Extract small function (#614)

* LoRA: Extract pre_processing_model function

* LoRA: Extract small functions (train_model, evaluate_model)

* move test case to test_tuner_utils.py

* nits

* nits

* remove extra param, validate at it 0

* version

* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
madroid 2024-06-02 21:38:42 +08:00 committed by GitHub
parent 81318ad4a8
commit c457a3f88b
10 changed files with 232 additions and 206 deletions

View File

@@ -60,8 +60,7 @@ lora_parameters:
   # These will be applied for the last lora_layers
   keys: ["self_attn.q_proj", "self_attn.v_proj"]
   rank: 8
-  alpha: 16.0
-  scale: 10.0
+  scale: 20.0
   dropout: 0.0

 # Schedule can only be specified in a config file, uncomment to use.
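Dropping `alpha` here is not a behavior change for the defaults: the tuner layers used to apply an effective scale of `scale * (alpha / rank)`, and 10.0 * (16.0 / 8) is exactly the new `scale` default of 20.0 (see the `self.scale` change in the layer classes below). A quick arithmetic check in plain Python:

# Old config parameterization: rank / alpha / scale
old_rank, old_alpha, old_scale = 8, 16.0, 10.0
effective_scale_old = old_scale * (old_alpha / old_rank)  # -> 20.0

# New config parameterization: the effective scale is written directly
new_scale = 20.0
assert effective_scale_old == new_scale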

View File

@@ -10,11 +10,16 @@ import mlx.nn as nn
 import mlx.optimizers as optim
 import numpy as np
 import yaml
-from mlx.utils import tree_flatten

+from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import apply_lora_layers, build_schedule, linear_to_lora_layers
+from .tuner.utils import (
+    apply_lora_layers,
+    build_schedule,
+    linear_to_lora_layers,
+    print_trainable_parameters,
+)
 from .utils import load, save_config

 yaml_loader = yaml.SafeLoader
@@ -33,7 +38,6 @@ yaml_loader.add_implicit_resolver(
     list("-+0123456789."),
 )

-
 CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
@@ -150,111 +154,103 @@ def build_parser():
     return parser


-def print_trainable_parameters(model):
-    def nparams(m):
-        if isinstance(m, nn.QuantizedLinear):
-            return m.weight.size * (32 // m.bits)
-        return sum(v.size for _, v in tree_flatten(m.parameters()))
-
-    leaf_modules = tree_flatten(
-        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
-    )
-    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
-    trainable_p = (
-        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
-    )
-    print(
-        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
-        f"({trainable_p:.3f}M/{total_p:.3f}M)"
-    )
+def train_model(
+    args,
+    model: nn.Module,
+    tokenizer: TokenizerWrapper,
+    train_set,
+    valid_set,
+    training_callback: TrainingCallback = None,
+):
+    # Freeze all layers
+    model.freeze()
+
+    # Convert linear layers to lora layers and unfreeze in the process
+    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
+
+    # Resume training the given adapters.
+    if args.resume_adapter_file is not None:
+        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+        model.load_weights(args.resume_adapter_file, strict=False)
+
+    print_trainable_parameters(model)
+
+    adapter_path = Path(args.adapter_path)
+    adapter_path.mkdir(parents=True, exist_ok=True)
+    adapter_file = adapter_path / "adapters.safetensors"
+    save_config(vars(args), adapter_path / "adapter_config.json")
+
+    # init training args
+    training_args = TrainingArgs(
+        batch_size=args.batch_size,
+        iters=args.iters,
+        val_batches=args.val_batches,
+        steps_per_report=args.steps_per_report,
+        steps_per_eval=args.steps_per_eval,
+        steps_per_save=args.save_every,
+        adapter_file=adapter_file,
+        max_seq_length=args.max_seq_length,
+        grad_checkpoint=args.grad_checkpoint,
+    )
+
+    model.train()
+    opt = optim.Adam(
+        learning_rate=(
+            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+        )
+    )
+    # Train model
+    train(
+        model=model,
+        tokenizer=tokenizer,
+        args=training_args,
+        optimizer=opt,
+        train_dataset=train_set,
+        val_dataset=valid_set,
+        training_callback=training_callback,
+    )
+
+
+def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
+    model.eval()
+
+    test_loss = evaluate(
+        model=model,
+        dataset=test_set,
+        tokenizer=tokenizer,
+        batch_size=args.batch_size,
+        num_batches=args.test_batches,
+        max_seq_length=args.max_seq_length,
+    )
+
+    test_ppl = math.exp(test_loss)
+
+    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


 def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)

     print("Loading pretrained model")
     model, tokenizer = load(args.model)

-    # Freeze all layers
-    model.freeze()
-
-    adapter_path = Path(args.adapter_path)
-    adapter_file = adapter_path / "adapters.safetensors"
+    print("Loading datasets")
+    train_set, valid_set, test_set = load_dataset(args, tokenizer)

     if args.test and not args.train:
         # Allow testing without LoRA layers by providing empty path
         if args.adapter_path != "":
-            apply_lora_layers(model, adapter_path)
-    elif args.train:
-        adapter_path.mkdir(parents=True, exist_ok=True)
-        save_config(vars(args), adapter_path / "adapter_config.json")
+            apply_lora_layers(model, args.adapter_path)

-        # Convert linear layers to lora layers and unfreeze in the process
-        linear_to_lora_layers(
-            model, args.lora_layers, args.lora_parameters, args.use_dora
-        )
-
-        print_trainable_parameters(model)
+    elif args.train:
+        print("Training")
+        train_model(args, model, tokenizer, train_set, valid_set, training_callback)
     else:
         raise ValueError("Must provide at least one of --train or --test")

-    print("Loading datasets")
-    train_set, valid_set, test_set = load_dataset(args, tokenizer)
-
-    # Resume training the given adapters.
-    if args.resume_adapter_file is not None:
-        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
-        model.load_weights(args.resume_adapter_file, strict=False)
-
-    if args.train:
-        print("Training")
-        # init training args
-        training_args = TrainingArgs(
-            batch_size=args.batch_size,
-            iters=args.iters,
-            val_batches=args.val_batches,
-            steps_per_report=args.steps_per_report,
-            steps_per_eval=args.steps_per_eval,
-            steps_per_save=args.save_every,
-            adapter_file=adapter_file,
-            max_seq_length=args.max_seq_length,
-            grad_checkpoint=args.grad_checkpoint,
-        )
-
-        model.train()
-        opt = optim.Adam(
-            learning_rate=(
-                build_schedule(args.lr_schedule)
-                if args.lr_schedule
-                else args.learning_rate
-            )
-        )
-        # Train model
-        train(
-            model=model,
-            tokenizer=tokenizer,
-            args=training_args,
-            optimizer=opt,
-            train_dataset=train_set,
-            val_dataset=valid_set,
-            training_callback=training_callback,
-        )
-
     if args.test:
         print("Testing")
-        model.eval()
-
-        test_loss = evaluate(
-            model=model,
-            dataset=test_set,
-            tokenizer=tokenizer,
-            batch_size=args.batch_size,
-            num_batches=args.test_batches,
-            max_seq_length=args.max_seq_length,
-        )
-
-        test_ppl = math.exp(test_loss)
-
-        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+        evaluate_model(args, model, tokenizer, test_set)


 def main():
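With training and evaluation factored out of `run`, the two stages can also be exercised independently of the CLI. The sketch below is illustrative only: it reuses `load` and `load_dataset` the same way `run` does, and builds a hypothetical `args` namespace from `CONFIG_DEFAULTS` plus a few overrides (values are placeholders, not recommendations).

from types import SimpleNamespace

from mlx_lm.lora import CONFIG_DEFAULTS, evaluate_model, train_model
from mlx_lm.tuner.datasets import load_dataset
from mlx_lm.utils import load

# Hypothetical argument namespace; assumes CONFIG_DEFAULTS supplies the
# remaining fields that train_model/evaluate_model read above.
args = SimpleNamespace(**CONFIG_DEFAULTS)
args.model = "path/to/model"  # placeholder model path or repo id
args.train = True
args.test = True

model, tokenizer = load(args.model)
train_set, valid_set, test_set = load_dataset(args, tokenizer)

train_model(args, model, tokenizer, train_set, valid_set)  # writes adapters.safetensors
evaluate_model(args, model, tokenizer, test_set)           # prints test loss and perplexity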

View File

@@ -0,0 +1,2 @@
+from .trainer import TrainingArgs, evaluate, train
+from .utils import linear_to_lora_layers
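Judging by the relative imports, this new 2-line file appears to be the tuner package's `__init__.py` (the exact path is not shown in this view). If so, the common tuning entry points become importable from the package directly:

# Assumes the new file is mlx_lm/tuner/__init__.py (inferred, not shown above).
from mlx_lm.tuner import TrainingArgs, evaluate, train, linear_to_lora_layers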

View File

@@ -11,9 +11,8 @@ class DoRALinear(nn.Module):
     def from_linear(
         linear: nn.Linear,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         # TODO support quantized weights in DoRALinear
         output_dims, input_dims = linear.weight.shape
@@ -23,7 +22,6 @@ class DoRALinear(nn.Module):
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -56,9 +54,8 @@ class DoRALinear(nn.Module):
         input_dims: int,
         output_dims: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -68,7 +65,7 @@ class DoRALinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)

         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale

         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)

View File

@@ -13,9 +13,8 @@ class LoRALinear(nn.Module):
     def from_linear(
         linear: nn.Linear,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         # TODO remove when input_dims and output_dims are attributes
         # on linear and quantized linear
@@ -26,7 +25,6 @@ class LoRALinear(nn.Module):
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -74,9 +72,8 @@ class LoRALinear(nn.Module):
         input_dims: int,
         output_dims: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -87,7 +84,7 @@ class LoRALinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)

         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale

         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
@@ -109,16 +106,14 @@ class LoRASwitchLinear(nn.Module):
     def from_linear(
         linear: nn.Module,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         lora_lin = LoRASwitchLinear(
             input_dims=linear.input_dims,
             output_dims=linear.output_dims,
             num_experts=linear.num_experts,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -163,9 +158,8 @@ class LoRASwitchLinear(nn.Module):
         output_dims: int,
         num_experts: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -176,7 +170,7 @@ class LoRASwitchLinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)

         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale

         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
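For code that constructs these layers directly, only the way the scale is expressed changes. A before/after sketch for `LoRALinear` (`layer` is a placeholder `nn.Linear`, not something from this diff):

import mlx.nn as nn
from mlx_lm.tuner.lora import LoRALinear

layer = nn.Linear(512, 512)  # placeholder layer for illustration

# Before this commit: rank/alpha/scale, applied internally as scale * (alpha / r)
# lora_lin = LoRALinear.from_linear(layer, r=8, alpha=16, dropout=0.0, scale=10.0)

# After this commit: the effective scale is passed directly
lora_lin = LoRALinear.from_linear(layer, r=8, dropout=0.0, scale=20.0)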

View File

@@ -29,9 +29,6 @@ def grad_checkpoint(layer):

 @dataclass
 class TrainingArgs:
-    lora_layers: int = field(
-        default=16, metadata={"help": "Number of layers to fine-tune"}
-    )
     batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
     iters: int = field(default=100, metadata={"help": "Iterations to train for."})
     val_batches: int = field(
@@ -211,6 +208,35 @@ def train(
             train=True,
         ),
     ):
+        # Report validation loss if needed, the first validation loss
+        # is always measured before any training.
+        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
+            stop = time.perf_counter()
+            val_loss = evaluate(
+                model=model,
+                dataset=val_dataset,
+                loss=loss,
+                tokenizer=tokenizer,
+                batch_size=args.batch_size,
+                num_batches=args.val_batches,
+                max_seq_length=args.max_seq_length,
+                iterate_batches=iterate_batches,
+            )
+            val_time = time.perf_counter() - stop
+            print(
+                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
+            )
+
+            if training_callback is not None:
+                val_info = {
+                    "iteration": it,
+                    "val_loss": val_loss,
+                    "val_time": val_time,
+                }
+                training_callback.on_val_loss_report(val_info)
+
+            start = time.perf_counter()
+
         lvalue, toks = step(batch)
         mx.eval(state, lvalue, toks)

@@ -220,9 +246,9 @@ def train(

         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            train_loss = np.mean(losses)
             stop = time.perf_counter()
+            train_loss = np.mean(losses)
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
@@ -253,34 +279,6 @@ def train(
             n_tokens = 0
             start = time.perf_counter()

-        # Report validation loss if needed
-        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
-            val_loss = evaluate(
-                model=model,
-                dataset=val_dataset,
-                loss=loss,
-                tokenizer=tokenizer,
-                batch_size=args.batch_size,
-                num_batches=args.val_batches,
-                max_seq_length=args.max_seq_length,
-                iterate_batches=iterate_batches,
-            )
-            val_time = time.perf_counter() - stop
-            print(
-                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
-            )
-
-            if training_callback is not None:
-                val_info = {
-                    "iteration": it,
-                    "val_loss": val_loss,
-                    "val_time": val_time,
-                }
-                training_callback.on_val_loss_report(val_info)
-
-            start = time.perf_counter()
-
         # Save adapter weights
         if it % args.steps_per_save == 0:
             save_adapter(model, args.adapter_file)
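Because the validation block now runs at the top of the loop body, the `it == 1` report is a true pre-training baseline rather than a post-first-step measurement. A callback can capture that baseline through `on_val_loss_report`, which the moved block still invokes; the subclass below is a minimal illustrative sketch, not part of this commit.

from mlx_lm.tuner.trainer import TrainingCallback


class BaselineTracker(TrainingCallback):
    """Keeps the first (pre-training) and most recent validation losses."""

    def __init__(self):
        self.baseline_val_loss = None
        self.latest_val_loss = None

    def on_val_loss_report(self, val_info):
        # val_info carries "iteration", "val_loss", and "val_time" (see above).
        if val_info["iteration"] == 1 and self.baseline_val_loss is None:
            self.baseline_val_loss = val_info["val_loss"]
        self.latest_val_loss = val_info["val_loss"]

Passing an instance as `training_callback` to `train` (or to the new `train_model`) exposes both numbers once training finishes.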

View File

@@ -7,7 +7,7 @@ from typing import Dict
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as opt
-from mlx.utils import tree_unflatten
+from mlx.utils import tree_flatten, tree_unflatten

 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 from .dora import DoRALinear
@@ -48,7 +48,7 @@ def linear_to_lora_layers(
         num_lora_layers (int): The number of blocks to convert to lora layers
             starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
-            rank, alpha, scale, and optional layer keys.
+            rank, scale, and optional layer keys.
         use_dora (bool): If True, uses DoRA instead of LoRA.
             Default: ``False``
     """
@@ -79,7 +79,6 @@ def linear_to_lora_layers(
         return LoRALayer.from_linear(
             layer,
             r=config["rank"],
-            alpha=config["alpha"],
             scale=config["scale"],
             dropout=config["dropout"],
         )
@@ -218,3 +217,22 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     if len(reset_layers) > 0:
         model.update_modules(tree_unflatten(reset_layers))
     return model
+
+
+def print_trainable_parameters(model):
+    def nparams(m):
+        if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
+            return m.weight.size * (32 // m.bits)
+        return sum(v.size for _, v in tree_flatten(m.parameters()))
+
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
+    trainable_p = (
+        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+    )
+    print(
+        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
+        f"({trainable_p:.3f}M/{total_p:.3f}M)"
+    )
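Since `print_trainable_parameters` now lives in `mlx_lm.tuner.utils` instead of the `lora` CLI module, it can be reused without importing the script. A quick usage sketch (the model path and the LoRA config values are placeholders):

from mlx_lm.tuner.utils import linear_to_lora_layers, print_trainable_parameters
from mlx_lm.utils import load

model, tokenizer = load("path/to/model")  # placeholder model path
model.freeze()

# Convert the last 4 blocks to LoRA layers, then report the trainable fraction.
linear_to_lora_layers(
    model,
    4,
    {"rank": 8, "scale": 20.0, "dropout": 0.0, "keys": ["self_attn.q_proj", "self_attn.v_proj"]},
)
print_trainable_parameters(model)  # prints e.g. "Trainable parameters: x.xxx% (x.xxxM/x.xxxM)"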

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.14.1"
+__version__ = "0.14.2"

View File

@@ -6,7 +6,6 @@ import unittest
 from io import StringIO
 from unittest.mock import MagicMock

-import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 from mlx_lm import lora, tuner
@@ -61,68 +60,6 @@ class TestLora(unittest.TestCase):
             params["keys"] = ["self_attn.k_proj"]
             check_config(params)

-    def test_quantized_print_trainable_parameters(self):
-        model = MagicMock()
-        quantized_linear = MagicMock(spec=nn.QuantizedLinear)
-        quantized_linear.weight = MagicMock(size=1e6)
-        quantized_linear.bits = 8
-        lora_linear = MagicMock(spec=LoRALinear)
-        lora_linear.weight = MagicMock(size=2e6)
-        lora_linear.parameters.return_value = [lora_linear.weight]
-
-        linear = MagicMock(spec=nn.Linear)
-        linear.weight = MagicMock(size=3e6)
-        linear.parameters.return_value = [linear.weight]
-
-        model.leaf_modules.return_value = {
-            "quantized_linear": quantized_linear,
-            "lora_linear": lora_linear,
-            "linear": linear,
-        }
-
-        model.trainable_parameters.return_value = {
-            "layer1.weight": MagicMock(size=1e6),
-            "layer3.weight": MagicMock(size=2e6),
-        }
-        expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
-        lora.print_trainable_parameters(model)
-        self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
-        self.capturedOutput.truncate(0)
-        self.capturedOutput.seek(0)
-
-        quantized_linear.weight = MagicMock(size=1e6)
-        quantized_linear.bits = 4
-        expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
-        lora.print_trainable_parameters(model)
-        self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
-        self.capturedOutput.truncate(0)
-        self.capturedOutput.seek(0)
-
-    def test_print_trainable_parameters(self):
-        model = MagicMock()
-        linear1 = MagicMock(spec=nn.Linear)
-        linear1.weight = MagicMock(size=1e6)
-        linear1.parameters.return_value = [linear1.weight]
-        linear2 = MagicMock(spec=nn.Linear)
-        linear2.weight = MagicMock(size=2e6)
-        linear2.parameters.return_value = [linear2.weight]
-        lora_linear = MagicMock(spec=LoRALinear)
-        lora_linear.weight = MagicMock(size=3e6)
-        lora_linear.parameters.return_value = [lora_linear.weight]
-
-        model.leaf_modules.return_value = {
-            "linear1": linear1,
-            "linear2": linear2,
-            "lora_linear": lora_linear,
-        }
-
-        model.trainable_parameters.return_value = {
-            "layer1.weight": MagicMock(size=1e6),
-            "layer3.weight": MagicMock(size=2e6),
-        }
-        expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
-        lora.print_trainable_parameters(model)
-        self.assertEqual(self.capturedOutput.getvalue(), expected_output)
-

 class TestScheduleConfig(unittest.TestCase):
     def test_join(self):

View File

@@ -0,0 +1,85 @@
+# Copyright © 2024 Apple Inc.
+
+import sys
+import unittest
+from io import StringIO
+from unittest.mock import MagicMock
+
+import mlx.nn as nn
+
+from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.utils import print_trainable_parameters
+
+
+class TestTunerUtils(unittest.TestCase):
+    def setUp(self):
+        self.capturedOutput = StringIO()
+        sys.stdout = self.capturedOutput
+
+    def tearDown(self):
+        sys.stdout = sys.__stdout__
+
+    def test_quantized_print_trainable_parameters(self):
+        model = MagicMock()
+        quantized_linear = MagicMock(spec=nn.QuantizedLinear)
+        quantized_linear.weight = MagicMock(size=1e6)
+        quantized_linear.bits = 8
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=2e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+
+        linear = MagicMock(spec=nn.Linear)
+        linear.weight = MagicMock(size=3e6)
+        linear.parameters.return_value = [linear.weight]
+
+        model.leaf_modules.return_value = {
+            "quantized_linear": quantized_linear,
+            "lora_linear": lora_linear,
+            "linear": linear,
+        }
+
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
+        print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
+        self.capturedOutput.truncate(0)
+        self.capturedOutput.seek(0)
+
+        quantized_linear.weight = MagicMock(size=1e6)
+        quantized_linear.bits = 4
+        expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
+        print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
+        self.capturedOutput.truncate(0)
+        self.capturedOutput.seek(0)
+
+    def test_print_trainable_parameters(self):
+        model = MagicMock()
+        linear1 = MagicMock(spec=nn.Linear)
+        linear1.weight = MagicMock(size=1e6)
+        linear1.parameters.return_value = [linear1.weight]
+        linear2 = MagicMock(spec=nn.Linear)
+        linear2.weight = MagicMock(size=2e6)
+        linear2.parameters.return_value = [linear2.weight]
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=3e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+
+        model.leaf_modules.return_value = {
+            "linear1": linear1,
+            "linear2": linear2,
+            "lora_linear": lora_linear,
+        }
+
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
+        print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output)
+
+
+if __name__ == "__main__":
+    unittest.main()