Some improvements to LoRA (#528)

* set cache_limit

* remove set cache_limit

* cleanup

* add gradient checkpointing

* fix sort

* mokey patch call for checkpoint

* fix example config
This commit is contained in:
Awni Hannun 2024-03-12 20:02:03 -07:00 committed by GitHub
parent e56d9015ef
commit 39084e81c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 68 additions and 25 deletions

View File

@ -1,6 +1,5 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"
# Whether or not to train (boolean)
train: true
@ -49,6 +48,9 @@ test_batches: 500
# Maximum sequence length.
max_seq_length: 2048
# Use gradient checkpointing to reduce memory use.
grad_checkpoint: false
# LoRA parameters can only be specified in a config file
lora_parameters:
# The layer keys to apply LoRA to.

View File

@ -145,7 +145,12 @@ def build_parser():
default=None,
help="A YAML configuration file with the training options",
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
parser.add_argument(
"--grad-checkpoint",
action="store_true",
help="Use gradient checkpointing to reduce memory use.",
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
return parser
@ -222,6 +227,7 @@ def run(args, training_callback: TrainingCallback = None):
steps_per_save=args.save_every,
adapter_file=args.adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
)
if args.train:
print("Training")

View File

@ -2,6 +2,7 @@
import time
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
import mlx.core as mx
@ -10,6 +11,22 @@ import numpy as np
from mlx.utils import tree_flatten
def grad_checkpoint(layer):
"""
Update all instances of type(layer) to use gradient checkpointing.
"""
fn = type(layer).__call__
def checkpointed_fn(model, *args, **kwargs):
def inner_fn(params, *args, **kwargs):
model.update(params)
return fn(model, *args, **kwargs)
return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
type(layer).__call__ = checkpointed_fn
@dataclass
class TrainingArgs:
lora_layers: int = field(
@ -40,6 +57,10 @@ class TrainingArgs:
default="adapter.npz",
metadata={"help": "Save/load path for the trained adapter weights."},
)
grad_checkpoint: bool = field(
default=False,
metadata={"help": "Use gradient checkpointing to reduce memory use."},
)
def default_loss(model, inputs, targets, lengths):
@ -56,16 +77,19 @@ def default_loss(model, inputs, targets, lengths):
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
# Make the batches:
batch_idx = [
idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
]
while True:
# Shuffle indices
indices = np.arange(len(dataset))
indices = np.random.permutation(indices)
# Collect batches from dataset
for i in range(0, len(indices) - batch_size + 1, batch_size):
indices = np.random.permutation(len(batch_idx))
for i in indices:
# Encode batch
batch = [
tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size)
]
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
@ -75,8 +99,11 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
"Consider pre-splitting your data to save memory."
)
# Pad to the max length
max_length_in_batch = min(max(lengths), max_seq_length)
# Pad to the nearest multiple of 8 or the maximum length
pad_to = 8
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
for j in range(batch_size):
@ -157,7 +184,21 @@ def train(
# Create checkpoints directory if it does not exist
adapter_path = checkpoints_path(args.adapter_file)
# Create value and grad function for loss
if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(batch):
# Forward and backward pass
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
# Model update
optimizer.update(model, grad)
return lvalue, toks
loss_value_and_grad = nn.value_and_grad(model, loss)
losses = []
@ -175,13 +216,8 @@ def train(
train=True,
),
):
# Forward and backward pass
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
# Model update
optimizer.update(model, grad)
mx.eval(model.parameters(), optimizer.state, lvalue)
lvalue, toks = step(batch)
mx.eval(state, lvalue, toks)
# Record loss
losses.append(lvalue.item())
@ -196,12 +232,14 @@ def train(
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 2**30
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}"
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB"
)
if training_callback is not None:

View File

@ -35,9 +35,6 @@ def linear_to_lora_layers(
lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"]
)
# If the lora_parameters are set, we assume the keys
# are correct for the given model
keys = config.get("keys", None)
if keys is not None:
keys = set(keys)
@ -53,7 +50,7 @@ def linear_to_lora_layers(
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
if model.model_type == "mixtral":
keys.add(["block_sparse_moe.gate"])
keys.add("block_sparse_moe.gate")
elif model.model_type == "olmo":
keys = set(["att_proj"])
elif model.model_type == "phi-msft":