mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)

commit c457a3f88b, parent 81318ad4a8

LoRA: Extract small function (#614)

* LoRA: Extract pre_processing_model function
* LoRA: Extract small functions(train_model,evaluate_model)
* move test case to test_tuner_utils.py
* nits
* nits
* remove extra param, validate at it 0
* version
* fix test

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -60,8 +60,7 @@ lora_parameters:
   # These will be applied for the last lora_layers
   keys: ["self_attn.q_proj", "self_attn.v_proj"]
   rank: 8
-  alpha: 16.0
-  scale: 10.0
+  scale: 20.0
   dropout: 0.0
 
 # Schedule can only be specified in a config file, uncomment to use.
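The config change above folds the old alpha into scale. As a quick sanity check (my arithmetic, not part of the diff): the layer implementations changed later in this commit previously applied an effective factor of scale * (alpha / r), so the old defaults are equivalent to the new single value:

# Old effective low-rank scaling vs. the new single `scale` value
# (numbers taken from the defaults shown in this diff).
old_scale, old_alpha, old_rank = 10.0, 16.0, 8
assert old_scale * (old_alpha / old_rank) == 20.0  # matches the new `scale: 20.0`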
@@ -10,11 +10,16 @@ import mlx.nn as nn
 import mlx.optimizers as optim
 import numpy as np
 import yaml
-from mlx.utils import tree_flatten
 
+from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import apply_lora_layers, build_schedule, linear_to_lora_layers
+from .tuner.utils import (
+    apply_lora_layers,
+    build_schedule,
+    linear_to_lora_layers,
+    print_trainable_parameters,
+)
 from .utils import load, save_config
 
 yaml_loader = yaml.SafeLoader
@@ -33,7 +38,6 @@ yaml_loader.add_implicit_resolver(
     list("-+0123456789."),
 )
 
-
 CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
@@ -150,111 +154,103 @@ def build_parser():
     return parser
 
 
-def print_trainable_parameters(model):
-    def nparams(m):
-        if isinstance(m, nn.QuantizedLinear):
-            return m.weight.size * (32 // m.bits)
-        return sum(v.size for _, v in tree_flatten(m.parameters()))
-
-    leaf_modules = tree_flatten(
-        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
-    )
-    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
-    trainable_p = (
-        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
-    )
-    print(
-        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
-        f"({trainable_p:.3f}M/{total_p:.3f}M)"
-    )
+def train_model(
+    args,
+    model: nn.Module,
+    tokenizer: TokenizerWrapper,
+    train_set,
+    valid_set,
+    training_callback: TrainingCallback = None,
+):
+    # Freeze all layers
+    model.freeze()
+
+    # Convert linear layers to lora layers and unfreeze in the process
+    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
+
+    # Resume training the given adapters.
+    if args.resume_adapter_file is not None:
+        print(f"Loading pretrained adapters from {resume_adapter_file}")
+        model.load_weights(args.resume_adapter_file, strict=False)
+
+    print_trainable_parameters(model)
+
+    adapter_path = Path(args.adapter_path)
+    adapter_path.mkdir(parents=True, exist_ok=True)
+    adapter_file = adapter_path / "adapters.safetensors"
+    save_config(vars(args), adapter_path / "adapter_config.json")
+
+    # init training args
+    training_args = TrainingArgs(
+        batch_size=args.batch_size,
+        iters=args.iters,
+        val_batches=args.val_batches,
+        steps_per_report=args.steps_per_report,
+        steps_per_eval=args.steps_per_eval,
+        steps_per_save=args.save_every,
+        adapter_file=adapter_file,
+        max_seq_length=args.max_seq_length,
+        grad_checkpoint=args.grad_checkpoint,
+    )
+
+    model.train()
+    opt = optim.Adam(
+        learning_rate=(
+            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+        )
+    )
+    # Train model
+    train(
+        model=model,
+        tokenizer=tokenizer,
+        args=training_args,
+        optimizer=opt,
+        train_dataset=train_set,
+        val_dataset=valid_set,
+        training_callback=training_callback,
+    )
+
+
+def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
+    model.eval()
+
+    test_loss = evaluate(
+        model=model,
+        dataset=test_set,
+        tokenizer=tokenizer,
+        batch_size=args.batch_size,
+        num_batches=args.test_batches,
+        max_seq_length=args.max_seq_length,
+    )
+
+    test_ppl = math.exp(test_loss)
+
+    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
 
 
 def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
     print("Loading pretrained model")
     model, tokenizer = load(args.model)
 
-    # Freeze all layers
-    model.freeze()
-
-    adapter_path = Path(args.adapter_path)
-    adapter_file = adapter_path / "adapters.safetensors"
+    print("Loading datasets")
+    train_set, valid_set, test_set = load_dataset(args, tokenizer)
 
     if args.test and not args.train:
         # Allow testing without LoRA layers by providing empty path
         if args.adapter_path != "":
-            apply_lora_layers(model, adapter_path)
+            apply_lora_layers(model, args.adapter_path)
+
     elif args.train:
-        adapter_path.mkdir(parents=True, exist_ok=True)
-        save_config(vars(args), adapter_path / "adapter_config.json")
-
-        # Convert linear layers to lora layers and unfreeze in the process
-        linear_to_lora_layers(
-            model, args.lora_layers, args.lora_parameters, args.use_dora
-        )
-        print_trainable_parameters(model)
+        print("Training")
+        train_model(args, model, tokenizer, train_set, valid_set, training_callback)
     else:
         raise ValueError("Must provide at least one of --train or --test")
 
-    print("Loading datasets")
-    train_set, valid_set, test_set = load_dataset(args, tokenizer)
-
-    # Resume training the given adapters.
-    if args.resume_adapter_file is not None:
-        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
-        model.load_weights(args.resume_adapter_file, strict=False)
-
-    if args.train:
-        print("Training")
-        # init training args
-        training_args = TrainingArgs(
-            batch_size=args.batch_size,
-            iters=args.iters,
-            val_batches=args.val_batches,
-            steps_per_report=args.steps_per_report,
-            steps_per_eval=args.steps_per_eval,
-            steps_per_save=args.save_every,
-            adapter_file=adapter_file,
-            max_seq_length=args.max_seq_length,
-            grad_checkpoint=args.grad_checkpoint,
-        )
-
-        model.train()
-        opt = optim.Adam(
-            learning_rate=(
-                build_schedule(args.lr_schedule)
-                if args.lr_schedule
-                else args.learning_rate
-            )
-        )
-        # Train model
-        train(
-            model=model,
-            tokenizer=tokenizer,
-            args=training_args,
-            optimizer=opt,
-            train_dataset=train_set,
-            val_dataset=valid_set,
-            training_callback=training_callback,
-        )
-
     if args.test:
         print("Testing")
-        model.eval()
-
-        test_loss = evaluate(
-            model=model,
-            dataset=test_set,
-            tokenizer=tokenizer,
-            batch_size=args.batch_size,
-            num_batches=args.test_batches,
-            max_seq_length=args.max_seq_length,
-        )
-
-        test_ppl = math.exp(test_loss)
-
-        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+        evaluate_model(args, model, tokenizer, test_set)
 
 
 def main():
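After this refactor, run() is just composition: it loads the model and the datasets, then hands off to train_model or evaluate_model. A minimal sketch of driving the extracted helpers directly (my own illustration, not part of the commit; the model id and the args fields below are assumptions based on the attributes the new functions read, and a real run normally builds args through build_parser(), which also applies CONFIG_DEFAULTS):

from types import SimpleNamespace

from mlx_lm import lora
from mlx_lm.tuner.datasets import load_dataset
from mlx_lm.utils import load

args = SimpleNamespace(
    model="mlx-community/Mistral-7B-Instruct-v0.2-4bit",  # hypothetical model id
    data="data",  # dataset directory, assumed to be what load_dataset expects
    train=True, test=True, seed=0,
    lora_layers=16,
    lora_parameters={"rank": 8, "scale": 20.0, "dropout": 0.0,
                     "keys": ["self_attn.q_proj", "self_attn.v_proj"]},
    resume_adapter_file=None, adapter_path="adapters",
    batch_size=4, iters=100, val_batches=25, steps_per_report=10,
    steps_per_eval=200, save_every=100, test_batches=500,
    max_seq_length=2048, grad_checkpoint=False,
    learning_rate=1e-5, lr_schedule=None,
)

model, tokenizer = load(args.model)
train_set, valid_set, test_set = load_dataset(args, tokenizer)
lora.train_model(args, model, tokenizer, train_set, valid_set)
lora.evaluate_model(args, model, tokenizer, test_set)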
@@ -0,0 +1,2 @@
+from .trainer import TrainingArgs, evaluate, train
+from .utils import linear_to_lora_layers
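The relative imports suggest this new two-line file is the tuner package's __init__.py; the diff listing here does not show the path, so that is an inference. If so, the re-exported names can be imported straight from the subpackage:

# Assumes the new file above is mlx_lm/tuner/__init__.py (inferred, not shown in the diff).
from mlx_lm.tuner import TrainingArgs, evaluate, linear_to_lora_layers, train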
@@ -11,9 +11,8 @@ class DoRALinear(nn.Module):
     def from_linear(
         linear: nn.Linear,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         # TODO support quantized weights in DoRALinear
         output_dims, input_dims = linear.weight.shape
@@ -23,7 +22,6 @@ class DoRALinear(nn.Module):
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -56,9 +54,8 @@ class DoRALinear(nn.Module):
         input_dims: int,
         output_dims: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -68,7 +65,7 @@ class DoRALinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
@@ -13,9 +13,8 @@ class LoRALinear(nn.Module):
     def from_linear(
         linear: nn.Linear,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         # TODO remove when input_dims and output_dims are attributes
         # on linear and quantized linear
@@ -26,7 +25,6 @@ class LoRALinear(nn.Module):
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -74,9 +72,8 @@ class LoRALinear(nn.Module):
         input_dims: int,
         output_dims: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -87,7 +84,7 @@ class LoRALinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
@@ -109,16 +106,14 @@ class LoRASwitchLinear(nn.Module):
     def from_linear(
         linear: nn.Module,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         lora_lin = LoRASwitchLinear(
             input_dims=linear.input_dims,
             output_dims=linear.output_dims,
             num_experts=linear.num_experts,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -163,9 +158,8 @@ class LoRASwitchLinear(nn.Module):
         output_dims: int,
         num_experts: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -176,7 +170,7 @@ class LoRASwitchLinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
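These hunks drop the separate alpha argument from DoRALinear, LoRALinear, and LoRASwitchLinear; the scaling factor is now passed in directly. A small sketch of the simplified signature (my illustration, assuming mlx_lm.tuner.lora as modified above; the layer size is arbitrary):

import mlx.nn as nn
from mlx_lm.tuner.lora import LoRALinear

# Wrap a plain linear layer in a LoRA adapter; `alpha` is gone and `scale`
# is stored as-is (previously self.scale = scale * (alpha / r)).
lora_lin = LoRALinear.from_linear(nn.Linear(512, 512), r=8, dropout=0.0, scale=20.0)
assert lora_lin.scale == 20.0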
@@ -29,9 +29,6 @@ def grad_checkpoint(layer):
 
 @dataclass
 class TrainingArgs:
-    lora_layers: int = field(
-        default=16, metadata={"help": "Number of layers to fine-tune"}
-    )
     batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
     iters: int = field(default=100, metadata={"help": "Iterations to train for."})
     val_batches: int = field(
@@ -211,6 +208,35 @@ def train(
             train=True,
         ),
     ):
+        # Report validation loss if needed, the first validation loss
+        # is always measured before any training.
+        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
+            stop = time.perf_counter()
+            val_loss = evaluate(
+                model=model,
+                dataset=val_dataset,
+                loss=loss,
+                tokenizer=tokenizer,
+                batch_size=args.batch_size,
+                num_batches=args.val_batches,
+                max_seq_length=args.max_seq_length,
+                iterate_batches=iterate_batches,
+            )
+            val_time = time.perf_counter() - stop
+            print(
+                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
+            )
+
+            if training_callback is not None:
+                val_info = {
+                    "iteration": it,
+                    "val_loss": val_loss,
+                    "val_time": val_time,
+                }
+                training_callback.on_val_loss_report(val_info)
+
+            start = time.perf_counter()
+
         lvalue, toks = step(batch)
         mx.eval(state, lvalue, toks)
 
@@ -220,9 +246,9 @@ def train(
 
         # Report training loss if needed
        if it % args.steps_per_report == 0 or it == args.iters:
-            train_loss = np.mean(losses)
-
             stop = time.perf_counter()
+
+            train_loss = np.mean(losses)
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
@@ -253,34 +279,6 @@ def train(
             n_tokens = 0
             start = time.perf_counter()
-
-        # Report validation loss if needed
-        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
-            val_loss = evaluate(
-                model=model,
-                dataset=val_dataset,
-                loss=loss,
-                tokenizer=tokenizer,
-                batch_size=args.batch_size,
-                num_batches=args.val_batches,
-                max_seq_length=args.max_seq_length,
-                iterate_batches=iterate_batches,
-            )
-            val_time = time.perf_counter() - stop
-            print(
-                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
-            )
-
-            if training_callback is not None:
-                val_info = {
-                    "iteration": it,
-                    "val_loss": val_loss,
-                    "val_time": val_time,
-                }
-                training_callback.on_val_loss_report(val_info)
-
-            start = time.perf_counter()
 
         # Save adapter weights
         if it % args.steps_per_save == 0:
             save_adapter(model, args.adapter_file)
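The trainer hunks move the validation report from the end of the loop body to the top, so iteration 1 now logs a validation loss computed before any weight update (the "validate at it 0" item in the commit message). A simplified, runnable sketch of the new ordering (my paraphrase with stand-in helpers, not the actual trainer code):

iters, steps_per_eval, steps_per_report = 20, 10, 5

def report_validation_loss(it): print(f"Iter {it}: Val loss ...")
def report_training_loss(it): print(f"Iter {it}: Train loss ...")
def step(it): return 0.0  # stand-in for the real forward/backward/update

for it in range(1, iters + 1):
    # Validation comes first, so at it == 1 it reflects the untrained adapters.
    if it == 1 or it % steps_per_eval == 0 or it == iters:
        report_validation_loss(it)
    loss_value = step(it)  # gradient step happens after the report
    if it % steps_per_report == 0 or it == iters:
        report_training_loss(it)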
@@ -7,7 +7,7 @@ from typing import Dict
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as opt
-from mlx.utils import tree_unflatten
+from mlx.utils import tree_flatten, tree_unflatten
 
 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 from .dora import DoRALinear
@@ -48,7 +48,7 @@ def linear_to_lora_layers(
         num_lora_layers (int): The number of blocks to convert to lora layers
             starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
-            rank, alpha, scale, and optional layer keys.
+            rank, scale, and optional layer keys.
         use_dora (bool): If True, uses DoRA instead of LoRA.
             Default: ``False``
     """
@@ -79,7 +79,6 @@ def linear_to_lora_layers(
         return LoRALayer.from_linear(
             layer,
             r=config["rank"],
-            alpha=config["alpha"],
             scale=config["scale"],
             dropout=config["dropout"],
         )
@@ -218,3 +217,22 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     if len(reset_layers) > 0:
         model.update_modules(tree_unflatten(reset_layers))
     return model
+
+
+def print_trainable_parameters(model):
+    def nparams(m):
+        if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
+            return m.weight.size * (32 // m.bits)
+        return sum(v.size for _, v in tree_flatten(m.parameters()))
+
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
+    trainable_p = (
+        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+    )
+    print(
+        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
+        f"({trainable_p:.3f}M/{total_p:.3f}M)"
+    )
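print_trainable_parameters now lives next to the other tuner utilities, and this copy also counts nn.QuantizedEmbedding weights, which the old copy in lora.py did not. A minimal usage sketch (my illustration; the model id is hypothetical, and the config dict mirrors the keys the function reads above):

from mlx_lm.tuner.utils import linear_to_lora_layers, print_trainable_parameters
from mlx_lm.utils import load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")  # hypothetical id
model.freeze()
linear_to_lora_layers(
    model,
    16,  # number of trailing blocks to convert
    {"rank": 8, "scale": 20.0, "dropout": 0.0,
     "keys": ["self_attn.q_proj", "self_attn.v_proj"]},
)
print_trainable_parameters(model)  # prints something like "Trainable parameters: 0.1xx% (...M/...M)"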
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.14.1"
+__version__ = "0.14.2"
@@ -6,7 +6,6 @@ import unittest
 from io import StringIO
 from unittest.mock import MagicMock
 
-import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 from mlx_lm import lora, tuner
@@ -61,68 +60,6 @@ class TestLora(unittest.TestCase):
         params["keys"] = ["self_attn.k_proj"]
         check_config(params)
-
-    def test_quantized_print_trainable_parameters(self):
-        model = MagicMock()
-        quantized_linear = MagicMock(spec=nn.QuantizedLinear)
-        quantized_linear.weight = MagicMock(size=1e6)
-        quantized_linear.bits = 8
-        lora_linear = MagicMock(spec=LoRALinear)
-        lora_linear.weight = MagicMock(size=2e6)
-        lora_linear.parameters.return_value = [lora_linear.weight]
-
-        linear = MagicMock(spec=nn.Linear)
-        linear.weight = MagicMock(size=3e6)
-        linear.parameters.return_value = [linear.weight]
-
-        model.leaf_modules.return_value = {
-            "quantized_linear": quantized_linear,
-            "lora_linear": lora_linear,
-            "linear": linear,
-        }
-
-        model.trainable_parameters.return_value = {
-            "layer1.weight": MagicMock(size=1e6),
-            "layer3.weight": MagicMock(size=2e6),
-        }
-        expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
-        lora.print_trainable_parameters(model)
-        self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
-        self.capturedOutput.truncate(0)
-        self.capturedOutput.seek(0)
-
-        quantized_linear.weight = MagicMock(size=1e6)
-        quantized_linear.bits = 4
-        expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
-        lora.print_trainable_parameters(model)
-        self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
-        self.capturedOutput.truncate(0)
-        self.capturedOutput.seek(0)
-
-    def test_print_trainable_parameters(self):
-        model = MagicMock()
-        linear1 = MagicMock(spec=nn.Linear)
-        linear1.weight = MagicMock(size=1e6)
-        linear1.parameters.return_value = [linear1.weight]
-        linear2 = MagicMock(spec=nn.Linear)
-        linear2.weight = MagicMock(size=2e6)
-        linear2.parameters.return_value = [linear2.weight]
-        lora_linear = MagicMock(spec=LoRALinear)
-        lora_linear.weight = MagicMock(size=3e6)
-        lora_linear.parameters.return_value = [lora_linear.weight]
-        model.leaf_modules.return_value = {
-            "linear1": linear1,
-            "linear2": linear2,
-            "lora_linear": lora_linear,
-        }
-
-        model.trainable_parameters.return_value = {
-            "layer1.weight": MagicMock(size=1e6),
-            "layer3.weight": MagicMock(size=2e6),
-        }
-        expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
-        lora.print_trainable_parameters(model)
-        self.assertEqual(self.capturedOutput.getvalue(), expected_output)
 
 
 class TestScheduleConfig(unittest.TestCase):
     def test_join(self):
llms/tests/test_tuner_utils.py (new file, 85 lines)
@@ -0,0 +1,85 @@
+# Copyright © 2024 Apple Inc.
+
+import sys
+import unittest
+from io import StringIO
+from unittest.mock import MagicMock
+
+import mlx.nn as nn
+from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.utils import print_trainable_parameters
+
+
+class TestTunerUtils(unittest.TestCase):
+    def setUp(self):
+        self.capturedOutput = StringIO()
+        sys.stdout = self.capturedOutput
+
+    def tearDown(self):
+        sys.stdout = sys.__stdout__
+
+    def test_quantized_print_trainable_parameters(self):
+        model = MagicMock()
+        quantized_linear = MagicMock(spec=nn.QuantizedLinear)
+        quantized_linear.weight = MagicMock(size=1e6)
+        quantized_linear.bits = 8
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=2e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+
+        linear = MagicMock(spec=nn.Linear)
+        linear.weight = MagicMock(size=3e6)
+        linear.parameters.return_value = [linear.weight]
+
+        model.leaf_modules.return_value = {
+            "quantized_linear": quantized_linear,
+            "lora_linear": lora_linear,
+            "linear": linear,
+        }
+
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
+        print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
+        self.capturedOutput.truncate(0)
+        self.capturedOutput.seek(0)
+
+        quantized_linear.weight = MagicMock(size=1e6)
+        quantized_linear.bits = 4
+        expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
+        print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
+        self.capturedOutput.truncate(0)
+        self.capturedOutput.seek(0)
+
+    def test_print_trainable_parameters(self):
+        model = MagicMock()
+        linear1 = MagicMock(spec=nn.Linear)
+        linear1.weight = MagicMock(size=1e6)
+        linear1.parameters.return_value = [linear1.weight]
+        linear2 = MagicMock(spec=nn.Linear)
+        linear2.weight = MagicMock(size=2e6)
+        linear2.parameters.return_value = [linear2.weight]
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=3e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+        model.leaf_modules.return_value = {
+            "linear1": linear1,
+            "linear2": linear2,
+            "lora_linear": lora_linear,
+        }
+
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
+        print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output)
+
+
+if __name__ == "__main__":
+    unittest.main()
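The expected strings in these tests follow directly from how print_trainable_parameters counts quantized weights (size * (32 // bits)). A quick check of the arithmetic (mine, not part of the test file):

# 8-bit: the quantized weight of size 1e6 counts as 1e6 * (32 // 8) = 4M parameters,
# so total = 4M + 2M + 3M = 9M, trainable = 1M + 2M = 3M -> 33.333%.
# 4-bit: 1e6 * (32 // 4) = 8M, total = 13M -> 3M / 13M = 23.077%.
print(f"{3e6 / (1e6 * (32 // 8) + 2e6 + 3e6):.3%}")  # 33.333%
print(f"{3e6 / (1e6 * (32 // 4) + 2e6 + 3e6):.3%}")  # 23.077%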