LoRA: Extract small function (#614)

* LoRA: Extract pre_processing_model function

* LoRA: Extract small functions (train_model, evaluate_model)

* move test case to test_tuner_utils.py

* nits

* nits

* remove extra param, validate at iteration 0

* version

* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
madroid 2024-06-02 21:38:42 +08:00 committed by GitHub
parent 81318ad4a8
commit c457a3f88b
10 changed files with 232 additions and 206 deletions

View File

@@ -60,8 +60,7 @@ lora_parameters:
# These will be applied for the last lora_layers
keys: ["self_attn.q_proj", "self_attn.v_proj"]
rank: 8
alpha: 16.0
scale: 10.0
scale: 20.0
dropout: 0.0
# Schedule can only be specified in a config file, uncomment to use.
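With alpha gone from lora_parameters, the low-rank update is scaled by scale alone; the old defaults multiplied out to 10.0 * 16 / 8 = 20.0, which is why the config default moves from 10.0 to 20.0. A minimal migration sketch for configs that used a non-default alpha or rank (the helper name is hypothetical):

def migrate_scale(old_scale: float, alpha: float, rank: int) -> float:
    # The layers previously applied scale * (alpha / r); fold that product
    # into the single `scale` value the new config expects.
    return old_scale * alpha / rank

print(migrate_scale(10.0, 16.0, 8))  # 20.0 -- matches the new default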

View File

@@ -10,11 +10,16 @@ import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml
from mlx.utils import tree_flatten
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import apply_lora_layers, build_schedule, linear_to_lora_layers
from .tuner.utils import (
apply_lora_layers,
build_schedule,
linear_to_lora_layers,
print_trainable_parameters,
)
from .utils import load, save_config
yaml_loader = yaml.SafeLoader
@@ -33,7 +38,6 @@ yaml_loader.add_implicit_resolver(
list("-+0123456789."),
)
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
@@ -150,63 +154,32 @@ def build_parser():
return parser
def print_trainable_parameters(model):
def nparams(m):
if isinstance(m, nn.QuantizedLinear):
return m.weight.size * (32 // m.bits)
return sum(v.size for _, v in tree_flatten(m.parameters()))
leaf_modules = tree_flatten(
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
)
total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
trainable_p = (
sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
)
print(
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
f"({trainable_p:.3f}M/{total_p:.3f}M)"
)
def run(args, training_callback: TrainingCallback = None):
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer = load(args.model)
def train_model(
args,
model: nn.Module,
tokenizer: TokenizerWrapper,
train_set,
valid_set,
training_callback: TrainingCallback = None,
):
# Freeze all layers
model.freeze()
adapter_path = Path(args.adapter_path)
adapter_file = adapter_path / "adapters.safetensors"
if args.test and not args.train:
# Allow testing without LoRA layers by providing empty path
if args.adapter_path != "":
apply_lora_layers(model, adapter_path)
elif args.train:
adapter_path.mkdir(parents=True, exist_ok=True)
save_config(vars(args), adapter_path / "adapter_config.json")
# Convert linear layers to lora layers and unfreeze in the process
linear_to_lora_layers(
model, args.lora_layers, args.lora_parameters, args.use_dora
)
print_trainable_parameters(model)
else:
raise ValueError("Must provide at least one of --train or --test")
print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args, tokenizer)
linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
# Resume training the given adapters.
if args.resume_adapter_file is not None:
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
print(f"Loading pretrained adapters from {resume_adapter_file}")
model.load_weights(args.resume_adapter_file, strict=False)
if args.train:
print("Training")
print_trainable_parameters(model)
adapter_path = Path(args.adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
@@ -223,9 +196,7 @@ def run(args, training_callback: TrainingCallback = None):
model.train()
opt = optim.Adam(
learning_rate=(
build_schedule(args.lr_schedule)
if args.lr_schedule
else args.learning_rate
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
# Train model
@@ -239,8 +210,8 @@ def run(args, training_callback: TrainingCallback = None):
training_callback=training_callback,
)
if args.test:
print("Testing")
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()
test_loss = evaluate(
@@ -257,6 +228,31 @@ def run(args, training_callback: TrainingCallback = None):
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
def run(args, training_callback: TrainingCallback = None):
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer = load(args.model)
print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args, tokenizer)
if args.test and not args.train:
# Allow testing without LoRA layers by providing empty path
if args.adapter_path != "":
apply_lora_layers(model, args.adapter_path)
elif args.train:
print("Training")
train_model(args, model, tokenizer, train_set, valid_set, training_callback)
else:
raise ValueError("Must provide at least one of --train or --test")
if args.test:
print("Testing")
evaluate_model(args, model, tokenizer, test_set)
def main():
parser = build_parser()
args = parser.parse_args()
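The refactored run() keeps its training_callback parameter while delegating to the new train_model() and evaluate_model() helpers, so custom reporting still plugs in at the top level. A minimal sketch; on_val_loss_report and its payload are visible in this diff, while on_train_loss_report is assumed to mirror it:

from mlx_lm import lora
from mlx_lm.tuner.trainer import TrainingCallback


class LoggingCallback(TrainingCallback):
    def on_train_loss_report(self, train_info):
        # Assumed counterpart of the validation hook, fed from train().
        print("train:", train_info)

    def on_val_loss_report(self, val_info):
        # val_info is {"iteration": ..., "val_loss": ..., "val_time": ...} per this diff.
        print("val:", val_info)


# args should be prepared the way main() prepares them (CLI flags merged with
# CONFIG_DEFAULTS) before handing them to run():
# lora.run(args, training_callback=LoggingCallback())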

View File

@@ -0,0 +1,2 @@
from .trainer import TrainingArgs, evaluate, train
from .utils import linear_to_lora_layers

View File

@@ -11,9 +11,8 @@ class DoRALinear(nn.Module):
def from_linear(
linear: nn.Linear,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
):
# TODO support quantized weights in DoRALinear
output_dims, input_dims = linear.weight.shape
@@ -23,7 +22,6 @@ class DoRALinear(nn.Module):
input_dims=input_dims,
output_dims=output_dims,
r=r,
alpha=alpha,
dropout=dropout,
scale=scale,
)
@@ -56,9 +54,8 @@ class DoRALinear(nn.Module):
input_dims: int,
output_dims: int,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
bias: bool = False,
):
super().__init__()
@@ -68,7 +65,7 @@ class DoRALinear(nn.Module):
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale * (alpha / r)
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
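DoRALinear drops alpha from both from_linear and its constructor, and self.scale is now used as-is rather than scale * (alpha / r); the new default of 20.0 reproduces the old 10.0 * 16 / 8 behaviour. A minimal sketch of the new call signature (layer sizes are arbitrary):

import mlx.nn as nn
from mlx_lm.tuner.dora import DoRALinear

base = nn.Linear(512, 512)
# Old call: DoRALinear.from_linear(base, r=8, alpha=16, scale=10.0)
# New call: pass the effective scale directly.
dora = DoRALinear.from_linear(base, r=8, dropout=0.0, scale=20.0)
print(dora.scale)  # 20.0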

View File

@@ -13,9 +13,8 @@ class LoRALinear(nn.Module):
def from_linear(
linear: nn.Linear,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
):
# TODO remove when input_dims and output_dims are attributes
# on linear and quantized linear
@@ -26,7 +25,6 @@ class LoRALinear(nn.Module):
input_dims=input_dims,
output_dims=output_dims,
r=r,
alpha=alpha,
dropout=dropout,
scale=scale,
)
@@ -74,9 +72,8 @@ class LoRALinear(nn.Module):
input_dims: int,
output_dims: int,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
bias: bool = False,
):
super().__init__()
@@ -87,7 +84,7 @@ class LoRALinear(nn.Module):
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale * (alpha / r)
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
@@ -109,16 +106,14 @@ class LoRASwitchLinear(nn.Module):
def from_linear(
linear: nn.Module,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
):
lora_lin = LoRASwitchLinear(
input_dims=linear.input_dims,
output_dims=linear.output_dims,
num_experts=linear.num_experts,
r=r,
alpha=alpha,
dropout=dropout,
scale=scale,
)
@@ -163,9 +158,8 @@ class LoRASwitchLinear(nn.Module):
output_dims: int,
num_experts: int,
r: int = 8,
alpha: float = 16,
dropout: float = 0.0,
scale: float = 10.0,
scale: float = 20.0,
bias: bool = False,
):
super().__init__()
@@ -176,7 +170,7 @@ class LoRASwitchLinear(nn.Module):
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale * (alpha / r)
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)

View File

@@ -29,9 +29,6 @@ def grad_checkpoint(layer):
@dataclass
class TrainingArgs:
lora_layers: int = field(
default=16, metadata={"help": "Number of layers to fine-tune"}
)
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
iters: int = field(default=100, metadata={"help": "Iterations to train for."})
val_batches: int = field(
@@ -211,6 +208,35 @@ def train(
train=True,
),
):
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
loss=loss,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
print(
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
)
if training_callback is not None:
val_info = {
"iteration": it,
"val_loss": val_loss,
"val_time": val_time,
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
lvalue, toks = step(batch)
mx.eval(state, lvalue, toks)
@@ -220,9 +246,9 @@ def train(
# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
train_loss = np.mean(losses)
stop = time.perf_counter()
train_loss = np.mean(losses)
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
@@ -253,34 +279,6 @@ def train(
n_tokens = 0
start = time.perf_counter()
# Report validation loss if needed
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
loss=loss,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
print(
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
)
if training_callback is not None:
val_info = {
"iteration": it,
"val_loss": val_loss,
"val_time": val_time,
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
# Save adapter weights
if it % args.steps_per_save == 0:
save_adapter(model, args.adapter_file)
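On the trainer side, lora_layers is removed from TrainingArgs (converting layers is the caller's job via linear_to_lora_layers), and the validation report now runs before the optimizer step, so the first reported validation loss is measured before any training. A minimal sketch of the slimmed-down TrainingArgs; every field shown is read off args somewhere in this file, the values are arbitrary:

from mlx_lm.tuner.trainer import TrainingArgs

# No lora_layers field any more; this object is passed straight to train().
training_args = TrainingArgs(
    batch_size=4,
    iters=100,
    val_batches=25,
    steps_per_eval=200,
    steps_per_save=100,
    adapter_file="adapters/adapters.safetensors",
    max_seq_length=2048,
)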

View File

@@ -7,7 +7,7 @@ from typing import Dict
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_unflatten
from mlx.utils import tree_flatten, tree_unflatten
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
from .dora import DoRALinear
@@ -48,7 +48,7 @@ def linear_to_lora_layers(
num_lora_layers (int): The number of blocks to convert to lora layers
starting from the last layer.
config (dict): More configuration parameters for LoRA, including the
rank, alpha, scale, and optional layer keys.
rank, scale, and optional layer keys.
use_dora (bool): If True, uses DoRA instead of LoRA.
Default: ``False``
"""
@@ -79,7 +79,6 @@ def linear_to_lora_layers(
return LoRALayer.from_linear(
layer,
r=config["rank"],
alpha=config["alpha"],
scale=config["scale"],
dropout=config["dropout"],
)
@@ -218,3 +217,22 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
if len(reset_layers) > 0:
model.update_modules(tree_unflatten(reset_layers))
return model
def print_trainable_parameters(model):
def nparams(m):
if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
return m.weight.size * (32 // m.bits)
return sum(v.size for _, v in tree_flatten(m.parameters()))
leaf_modules = tree_flatten(
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
)
total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
trainable_p = (
sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
)
print(
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
f"({trainable_p:.3f}M/{total_p:.3f}M)"
)
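print_trainable_parameters moves from lora.py into tuner/utils.py and now also counts nn.QuantizedEmbedding weights, so it can be reused outside the CLI. A minimal sketch, assuming a locally available model ("mlx_model" is just the default path from CONFIG_DEFAULTS) and the lora_parameters shown in the example config:

from mlx_lm.tuner.utils import linear_to_lora_layers, print_trainable_parameters
from mlx_lm.utils import load

model, _ = load("mlx_model")
model.freeze()
linear_to_lora_layers(
    model,
    num_lora_layers=16,
    config={
        "rank": 8,
        "scale": 20.0,
        "dropout": 0.0,
        "keys": ["self_attn.q_proj", "self_attn.v_proj"],
    },
)
print_trainable_parameters(model)  # prints "Trainable parameters: x.xxx% (a.aaaM/b.bbbM)"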

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.14.1"
__version__ = "0.14.2"

View File

@@ -6,7 +6,6 @@ import unittest
from io import StringIO
from unittest.mock import MagicMock
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten
from mlx_lm import lora, tuner
@@ -61,68 +60,6 @@ class TestLora(unittest.TestCase):
params["keys"] = ["self_attn.k_proj"]
check_config(params)
def test_quantized_print_trainable_parameters(self):
model = MagicMock()
quantized_linear = MagicMock(spec=nn.QuantizedLinear)
quantized_linear.weight = MagicMock(size=1e6)
quantized_linear.bits = 8
lora_linear = MagicMock(spec=LoRALinear)
lora_linear.weight = MagicMock(size=2e6)
lora_linear.parameters.return_value = [lora_linear.weight]
linear = MagicMock(spec=nn.Linear)
linear.weight = MagicMock(size=3e6)
linear.parameters.return_value = [linear.weight]
model.leaf_modules.return_value = {
"quantized_linear": quantized_linear,
"lora_linear": lora_linear,
"linear": linear,
}
model.trainable_parameters.return_value = {
"layer1.weight": MagicMock(size=1e6),
"layer3.weight": MagicMock(size=2e6),
}
expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
lora.print_trainable_parameters(model)
self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
self.capturedOutput.truncate(0)
self.capturedOutput.seek(0)
quantized_linear.weight = MagicMock(size=1e6)
quantized_linear.bits = 4
expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
lora.print_trainable_parameters(model)
self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
self.capturedOutput.truncate(0)
self.capturedOutput.seek(0)
def test_print_trainable_parameters(self):
model = MagicMock()
linear1 = MagicMock(spec=nn.Linear)
linear1.weight = MagicMock(size=1e6)
linear1.parameters.return_value = [linear1.weight]
linear2 = MagicMock(spec=nn.Linear)
linear2.weight = MagicMock(size=2e6)
linear2.parameters.return_value = [linear2.weight]
lora_linear = MagicMock(spec=LoRALinear)
lora_linear.weight = MagicMock(size=3e6)
lora_linear.parameters.return_value = [lora_linear.weight]
model.leaf_modules.return_value = {
"linear1": linear1,
"linear2": linear2,
"lora_linear": lora_linear,
}
model.trainable_parameters.return_value = {
"layer1.weight": MagicMock(size=1e6),
"layer3.weight": MagicMock(size=2e6),
}
expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
lora.print_trainable_parameters(model)
self.assertEqual(self.capturedOutput.getvalue(), expected_output)
class TestScheduleConfig(unittest.TestCase):
def test_join(self):

View File

@@ -0,0 +1,85 @@
# Copyright © 2024 Apple Inc.
import sys
import unittest
from io import StringIO
from unittest.mock import MagicMock
import mlx.nn as nn
from mlx_lm.tuner.lora import LoRALinear
from mlx_lm.tuner.utils import print_trainable_parameters
class TestTunerUtils(unittest.TestCase):
def setUp(self):
self.capturedOutput = StringIO()
sys.stdout = self.capturedOutput
def tearDown(self):
sys.stdout = sys.__stdout__
def test_quantized_print_trainable_parameters(self):
model = MagicMock()
quantized_linear = MagicMock(spec=nn.QuantizedLinear)
quantized_linear.weight = MagicMock(size=1e6)
quantized_linear.bits = 8
lora_linear = MagicMock(spec=LoRALinear)
lora_linear.weight = MagicMock(size=2e6)
lora_linear.parameters.return_value = [lora_linear.weight]
linear = MagicMock(spec=nn.Linear)
linear.weight = MagicMock(size=3e6)
linear.parameters.return_value = [linear.weight]
model.leaf_modules.return_value = {
"quantized_linear": quantized_linear,
"lora_linear": lora_linear,
"linear": linear,
}
model.trainable_parameters.return_value = {
"layer1.weight": MagicMock(size=1e6),
"layer3.weight": MagicMock(size=2e6),
}
expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
print_trainable_parameters(model)
self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
self.capturedOutput.truncate(0)
self.capturedOutput.seek(0)
quantized_linear.weight = MagicMock(size=1e6)
quantized_linear.bits = 4
expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
print_trainable_parameters(model)
self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
self.capturedOutput.truncate(0)
self.capturedOutput.seek(0)
def test_print_trainable_parameters(self):
model = MagicMock()
linear1 = MagicMock(spec=nn.Linear)
linear1.weight = MagicMock(size=1e6)
linear1.parameters.return_value = [linear1.weight]
linear2 = MagicMock(spec=nn.Linear)
linear2.weight = MagicMock(size=2e6)
linear2.parameters.return_value = [linear2.weight]
lora_linear = MagicMock(spec=LoRALinear)
lora_linear.weight = MagicMock(size=3e6)
lora_linear.parameters.return_value = [lora_linear.weight]
model.leaf_modules.return_value = {
"linear1": linear1,
"linear2": linear2,
"lora_linear": lora_linear,
}
model.trainable_parameters.return_value = {
"layer1.weight": MagicMock(size=1e6),
"layer3.weight": MagicMock(size=2e6),
}
expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
print_trainable_parameters(model)
self.assertEqual(self.capturedOutput.getvalue(), expected_output)
if __name__ == "__main__":
unittest.main()