Some improvements to LoRA (#528)

* set cache_limit

* remove set cache_limit

* cleanup

* add gradient checkpointing

* fix sort

* monkey patch call for checkpoint

* fix example config
Awni Hannun
2024-03-12 20:02:03 -07:00
committed by GitHub
parent e56d9015ef
commit 39084e81c2
4 changed files with 68 additions and 25 deletions

@@ -2,6 +2,7 @@
 import time
 from dataclasses import dataclass, field
+from functools import partial
 from pathlib import Path
 import mlx.core as mx
@@ -10,6 +11,22 @@ import numpy as np
 from mlx.utils import tree_flatten


+def grad_checkpoint(layer):
+    """
+    Update all instances of type(layer) to use gradient checkpointing.
+    """
+    fn = type(layer).__call__
+
+    def checkpointed_fn(model, *args, **kwargs):
+        def inner_fn(params, *args, **kwargs):
+            model.update(params)
+            return fn(model, *args, **kwargs)
+
+        return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
+
+    type(layer).__call__ = checkpointed_fn
+
+
 @dataclass
 class TrainingArgs:
     lora_layers: int = field(
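
A minimal sketch of what the monkey patch does in isolation may help; the nn.Linear layer and loss_fn below are illustrative stand-ins, not part of this change:

import mlx.core as mx
import mlx.nn as nn

layer = nn.Linear(4, 4)
grad_checkpoint(layer)  # patches nn.Linear.__call__, so every instance recomputes
                        # its activations in the backward pass instead of storing them

def loss_fn(model, x):
    return model(x).sum()

x = mx.random.normal((2, 4))
loss, grads = nn.value_and_grad(layer, loss_fn)(layer, x)
mx.eval(loss, grads)
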
@@ -40,6 +57,10 @@ class TrainingArgs:
         default="adapter.npz",
         metadata={"help": "Save/load path for the trained adapter weights."},
     )
+    grad_checkpoint: bool = field(
+        default=False,
+        metadata={"help": "Use gradient checkpointing to reduce memory use."},
+    )


 def default_loss(model, inputs, targets, lengths):
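
Since grad_checkpoint is just another dataclass field, it can be switched on directly when building the arguments; only fields visible in this diff are set here, everything else keeps its default:

args = TrainingArgs(
    lora_layers=16,
    grad_checkpoint=True,  # trade extra recomputation for lower peak memory
)
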
@@ -56,16 +77,19 @@ def default_loss(model, inputs, targets, lengths)
 def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+    # Sort by length:
+    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
+
+    # Make the batches:
+    batch_idx = [
+        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
+    ]
     while True:
-        # Shuffle indices
-        indices = np.arange(len(dataset))
-        indices = np.random.permutation(indices)
-
-        # Collect batches from dataset
-        for i in range(0, len(indices) - batch_size + 1, batch_size):
+        indices = np.random.permutation(len(batch_idx))
+        for i in indices:
             # Encode batch
-            batch = [
-                tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size)
-            ]
+            batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
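
A toy walk-through of the new batching, using character counts as a stand-in for tokenized lengths (batch_size=2):

dataset = ["a b c", "a", "a b c d e", "a b"]
batch_size = 2

# Sort once by length, then chunk the sorted indices into fixed batches
idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))  # [1, 3, 0, 2]
batch_idx = [
    idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
]
# batch_idx == [[1, 3], [0, 2]]: similar-length examples share a batch, so each
# batch pads to a shorter maximum length than random grouping would.

Only the order of whole batches is shuffled each pass through the data, so the length grouping survives shuffling.
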
@@ -75,8 +99,11 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
                     "Consider pre-splitting your data to save memory."
                 )
-            # Pad to the max length
-            max_length_in_batch = min(max(lengths), max_seq_length)
+            # Pad to the nearest multiple of 8 or the maximum length
+            pad_to = 8
+            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
+            max_length_in_batch = min(max_length_in_batch, max_seq_length)
             batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
             for j in range(batch_size):
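
The round-up is plain integer arithmetic; for example, a longest sequence of 45 tokens pads to 48, still capped by max_seq_length (assumed 2048 here):

pad_to = 8
max_len = 45
padded = pad_to * ((max_len + pad_to - 1) // pad_to)  # (45 + 7) // 8 = 6 -> 6 * 8 = 48
padded = min(padded, 2048)                            # never exceed max_seq_length
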
@@ -157,7 +184,21 @@ def train(
     # Create checkpoints directory if it does not exist
     adapter_path = checkpoints_path(args.adapter_file)
     # Create value and grad function for loss
+    if args.grad_checkpoint:
+        grad_checkpoint(model.layers[0])
+
+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(batch):
+        # Forward and backward pass
+        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
+
+        # Model update
+        optimizer.update(model, grad)
+
+        return lvalue, toks
+
     loss_value_and_grad = nn.value_and_grad(model, loss)
     losses = []
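
The compiled step follows the usual MLX pattern of threading the mutable model and optimizer state through mx.compile's inputs/outputs. A self-contained sketch with a toy model (these names are illustrative, not the trainer's):

from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(8, 8)
optimizer = optim.SGD(learning_rate=1e-2)

def loss_fn(model, x):
    return (model(x) ** 2).mean()

loss_and_grad = nn.value_and_grad(model, loss_fn)
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(x):
    # Gradient computation and the optimizer update both run inside the compiled graph
    loss, grads = loss_and_grad(model, x)
    optimizer.update(model, grads)
    return loss

x = mx.random.normal((4, 8))
loss = step(x)
mx.eval(state, loss)  # materialize the updated parameters, optimizer state, and loss
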
@@ -175,13 +216,8 @@ def train(
             train=True,
         ),
     ):
-        # Forward and backward pass
-        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
-
-        # Model update
-        optimizer.update(model, grad)
-        mx.eval(model.parameters(), optimizer.state, lvalue)
+        lvalue, toks = step(batch)
+        mx.eval(state, lvalue, toks)

         # Record loss
         losses.append(lvalue.item())
@@ -196,12 +232,14 @@ def train(
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
+            peak_mem = mx.metal.get_peak_memory() / 2**30
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                 f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}"
+                f"Trained Tokens {trained_tokens}, "
+                f"Peak mem {peak_mem:.3f} GB"
             )
             if training_callback is not None:
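
The new readout divides the byte count from mx.metal.get_peak_memory() by 2**30 to report GiB; a standalone sketch:

import mlx.core as mx

a = mx.random.normal((4096, 4096))
mx.eval(a @ a)  # run some work so the allocator actually records a peak

peak_gb = mx.metal.get_peak_memory() / 2**30  # bytes -> GiB
print(f"Peak mem {peak_gb:.3f} GB")
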