mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
feat: move lora into mlx-lm (#337)
* feat: Add LoRA and QLoRA training to mlx-lm. Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
0
llms/mlx_lm/tuner/__init__.py
Normal file
0
llms/mlx_lm/tuner/__init__.py
Normal file
88
llms/mlx_lm/tuner/lora.py
Normal file
88
llms/mlx_lm/tuner/lora.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
    """A linear layer with an additive low-rank (LoRA) update.

    Wraps an ``nn.Linear`` (or ``nn.QuantizedLinear``) and adds a trainable
    rank-``rank`` correction ``scale * (x @ lora_a) @ lora_b`` to its output.
    The wrapped layer's own weights are left untouched, so only the small
    ``lora_a`` / ``lora_b`` matrices need to be trained and saved.
    """

    @staticmethod
    def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0):
        """Wrap an existing (possibly quantized) linear layer in a LoRALinear.

        The returned module keeps ``linear`` as its frozen base layer and adds
        freshly initialized low-rank weights of the given ``rank``.
        """
        # TODO remove when input_dims and output_dims are attributes
        # on linear and quantized linear
        output_dims, input_dims = linear.weight.shape
        if isinstance(linear, nn.QuantizedLinear):
            # Quantized weights pack (32 // bits) values into each stored
            # element, so recover the true input dimension.
            input_dims *= 32 // linear.bits
        lora_lin = LoRALinear(
            input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale
        )
        lora_lin.linear = linear
        return lora_lin

    def to_linear(self):
        """Fuse the low-rank update into the base weights.

        Returns a plain ``nn.Linear`` whose weight is
        ``W + scale * lora_b.T @ lora_a.T``. If the base layer was quantized,
        the fused layer is re-quantized with the same group size and bit width.
        """
        linear = self.linear
        bias = "bias" in linear
        weight = linear.weight
        is_quantized = isinstance(linear, nn.QuantizedLinear)

        # Use the same type as the linear weight if not quantized
        dtype = weight.dtype

        if is_quantized:
            # Dequantize to float16 to do the fusion arithmetic; the result is
            # re-quantized below. NOTE(review): float16 is hard-coded here
            # regardless of the scales' dtype — confirm this is intended.
            dtype = mx.float16
            weight = mx.dequantize(
                weight,
                linear.scales,
                linear.biases,
                linear.group_size,
                linear.bits,
            )
        output_dims, input_dims = weight.shape
        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)

        # lora_a is (input_dims, rank) and lora_b is (rank, output_dims), so
        # lora_b.T @ lora_a.T has the (output_dims, input_dims) weight layout.
        lora_b = (self.scale * self.lora_b.T).astype(dtype)
        lora_a = self.lora_a.T.astype(dtype)
        fused_linear.weight = weight + lora_b @ lora_a
        if bias:
            fused_linear.bias = linear.bias

        if is_quantized:
            fused_linear = nn.QuantizedLinear.from_linear(
                fused_linear,
                linear.group_size,
                linear.bits,
            )

        return fused_linear

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        rank: int = 8,
        bias: bool = False,
        scale: float = 20.0,
    ):
        """Create a LoRA layer over a fresh ``nn.Linear`` base.

        Args:
            input_dims: Input feature dimension.
            output_dims: Output feature dimension.
            rank: Rank of the low-rank update matrices.
            bias: Whether the base linear layer has a bias term.
            scale: Multiplier applied to the low-rank update at call time.
        """
        super().__init__()

        # Regular linear layer weights
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)

        # Scale for low-rank update
        self.scale = scale

        # Low rank lora weights: lora_a is uniform in +-1/sqrt(input_dims);
        # lora_b starts at zero so the initial update is exactly zero.
        scale = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, rank),
        )
        self.lora_b = mx.zeros(shape=(rank, output_dims))

    def __call__(self, x):
        """Return ``linear(x) + scale * (x @ lora_a) @ lora_b``."""
        # Cast the input to match the base layer's compute dtype (the scales'
        # dtype for a quantized layer, the weight's dtype otherwise).
        dtype = self.linear.weight.dtype
        if isinstance(self.linear, nn.QuantizedLinear):
            dtype = self.linear.scales.dtype
        y = self.linear(x.astype(dtype))
        z = (x @ self.lora_a) @ self.lora_b
        return y + self.scale * z
|
||||
204
llms/mlx_lm/tuner/trainer.py
Normal file
204
llms/mlx_lm/tuner/trainer.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
|
||||
@dataclass
class TrainingArgs:
    """Hyper-parameters for the LoRA fine-tuning loop in :func:`train`.

    Each field's ``metadata["help"]`` string describes its meaning; defaults
    match the values used by the training entry point.
    """

    lora_layers: int = field(
        default=16, metadata={"help": "Number of layers to fine-tune"}
    )
    batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
    iters: int = field(default=100, metadata={"help": "Iterations to train for."})
    val_batches: int = field(
        default=25,
        metadata={
            "help": "Number of validation batches, -1 uses the entire validation set."
        },
    )
    steps_per_report: int = field(
        default=10,
        metadata={"help": "Number of training steps between loss reporting."},
    )
    steps_per_eval: int = field(
        default=200, metadata={"help": "Number of training steps between validations."}
    )
    steps_per_save: int = field(
        default=100, metadata={"help": "Save the model every number steps"}
    )
    max_seq_length: int = field(
        default=2048, metadata={"help": "Maximum sequence length."}
    )
    adapter_file: str = field(
        default="adapter.npz",
        metadata={"help": "Save/load path for the trained adapter weights."},
    )
|
||||
|
||||
|
||||
def default_loss(model, inputs, targets, lengths):
|
||||
logits, _ = model(inputs)
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
|
||||
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
||||
ntoks = length_mask.sum()
|
||||
ce = ce.sum() / ntoks
|
||||
|
||||
return ce, ntoks
|
||||
|
||||
|
||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
while True:
|
||||
# Shuffle indices
|
||||
indices = np.arange(len(dataset))
|
||||
indices = np.random.permutation(indices)
|
||||
# Collect batches from dataset
|
||||
for i in range(0, len(indices) - batch_size + 1, batch_size):
|
||||
# Encode batch
|
||||
batch = [
|
||||
tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size)
|
||||
]
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
# Check if any sequence is longer than max_seq_length
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
"[WARNING] Some sequences are longer than 2048 tokens. "
|
||||
"Consider pre-splitting your data to save memory."
|
||||
)
|
||||
|
||||
# Pad to the max length
|
||||
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
|
||||
|
||||
for j in range(batch_size):
|
||||
batch_arr[j, : lengths[j]] = batch[j]
|
||||
batch = mx.array(batch_arr)
|
||||
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
|
||||
|
||||
if not train:
|
||||
break
|
||||
|
||||
|
||||
def evaluate(
|
||||
model,
|
||||
dataset,
|
||||
tokenizer,
|
||||
batch_size,
|
||||
num_batches,
|
||||
max_seq_length=2048,
|
||||
loss: callable = default_loss,
|
||||
):
|
||||
all_losses = []
|
||||
ntokens = 0
|
||||
for it, batch in zip(
|
||||
range(num_batches),
|
||||
iterate_batches(
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=batch_size,
|
||||
max_seq_length=max_seq_length,
|
||||
),
|
||||
):
|
||||
losses, toks = loss(model, *batch)
|
||||
all_losses.append((losses * toks).item())
|
||||
ntokens += toks.item()
|
||||
|
||||
return np.sum(all_losses) / ntokens
|
||||
|
||||
|
||||
def train(
|
||||
model,
|
||||
tokenizer,
|
||||
optimizer,
|
||||
train_dataset,
|
||||
val_dataset,
|
||||
args: TrainingArgs = TrainingArgs(),
|
||||
loss: callable = default_loss,
|
||||
):
|
||||
# Create value and grad function for loss
|
||||
loss_value_and_grad = nn.value_and_grad(model, loss)
|
||||
|
||||
losses = []
|
||||
n_tokens = 0
|
||||
print("Starting training..., iters:", args.iters)
|
||||
# Main training loop
|
||||
start = time.perf_counter()
|
||||
for it, batch in zip(
|
||||
range(args.iters),
|
||||
iterate_batches(
|
||||
dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
max_seq_length=args.max_seq_length,
|
||||
train=True,
|
||||
),
|
||||
):
|
||||
# Forward and backward pass
|
||||
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
||||
|
||||
# Model update
|
||||
optimizer.update(model, grad)
|
||||
|
||||
mx.eval(model.parameters(), optimizer.state, lvalue)
|
||||
|
||||
# Record loss
|
||||
losses.append(lvalue.item())
|
||||
n_tokens += toks.item()
|
||||
|
||||
# Report training loss if needed
|
||||
if (it + 1) % args.steps_per_report == 0:
|
||||
train_loss = np.mean(losses)
|
||||
|
||||
stop = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
|
||||
f"It/sec {args.steps_per_report / (stop - start):.3f}, "
|
||||
f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
|
||||
)
|
||||
losses = []
|
||||
n_tokens = 0
|
||||
start = time.perf_counter()
|
||||
|
||||
# Report validation loss if needed
|
||||
if it == 0 or (it + 1) % args.steps_per_eval == 0:
|
||||
stop = time.perf_counter()
|
||||
val_loss = evaluate(
|
||||
model=model,
|
||||
dataset=val_dataset,
|
||||
loss=loss,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.val_batches,
|
||||
)
|
||||
print(
|
||||
f"Iter {it + 1}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val took {(time.perf_counter() - stop):.3f}s"
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights if needed
|
||||
if (it + 1) % args.steps_per_save == 0:
|
||||
save_adapter(model=model, adapter_file=args.adapter_file)
|
||||
print(
|
||||
f"Iter {it + 1}: Saved adapter weights to {os.path.join(args.adapter_file)}."
|
||||
)
|
||||
# save final adapter weights
|
||||
save_adapter(model=model, adapter_file=args.adapter_file)
|
||||
print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
|
||||
|
||||
|
||||
def save_adapter(
|
||||
model: nn.Module,
|
||||
adapter_file: str,
|
||||
):
|
||||
flattened_tree = tree_flatten(model.trainable_parameters())
|
||||
|
||||
mx.savez(adapter_file, **dict(flattened_tree))
|
||||
22
llms/mlx_lm/tuner/utils.py
Normal file
22
llms/mlx_lm/tuner/utils.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from .lora import LoRALinear
|
||||
|
||||
|
||||
def apply_lora_layers(model, adapter_file: str):
|
||||
adapters = list(mx.load(adapter_file).items())
|
||||
linear_replacements = {}
|
||||
lora_layers = set(
|
||||
[name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
|
||||
)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if name in lora_layers:
|
||||
replacement_module = LoRALinear.from_linear(module)
|
||||
linear_replacements[name] = replacement_module
|
||||
|
||||
model.update_modules(tree_unflatten(list(linear_replacements.items())))
|
||||
|
||||
model.update(tree_unflatten(adapters))
|
||||
return model
|
||||
Reference in New Issue
Block a user