Some improvements to LoRA (#528)

* set cache_limit

* remove set cache_limit

* cleanup

* add gradient checkpointing

* fix sort

* monkey patch call for checkpoint

* fix example config
Awni Hannun 2024-03-12 20:02:03 -07:00 committed by GitHub
parent e56d9015ef
commit 39084e81c2
4 changed files with 68 additions and 25 deletions


@@ -1,6 +1,5 @@
 # The path to the local model directory or Hugging Face repo.
 model: "mlx_model"
 # Whether or not to train (boolean)
 train: true
@@ -49,6 +48,9 @@ test_batches: 500
 # Maximum sequence length.
 max_seq_length: 2048

+# Use gradient checkpointing to reduce memory use.
+grad_checkpoint: false
+
 # LoRA parameters can only be specified in a config file
 lora_parameters:
   # The layer keys to apply LoRA to.
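The new grad_checkpoint key defaults to false and mirrors the CLI flag added below. As a minimal sketch (not part of the commit; the file name is hypothetical), the value is just a plain boolean once the YAML is loaded:

    import yaml  # pyyaml

    # hypothetical path to an example config like the one above
    with open("lora_config.yaml") as f:
        config = yaml.safe_load(f)

    # False unless the config (or --grad-checkpoint on the CLI) turns it on
    print(config.get("grad_checkpoint", False))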


@@ -145,7 +145,12 @@ def build_parser():
         default=None,
         help="A YAML configuration file with the training options",
     )
-    parser.add_argument("--seed", type=int, help="The PRNG seed")
+    parser.add_argument(
+        "--grad-checkpoint",
+        action="store_true",
+        help="Use gradient checkpointing to reduce memory use.",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
     return parser
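For illustration (not part of the diff; the import path is an assumption), the new flag is off by default and --seed now has a concrete default instead of None:

    from mlx_lm.lora import build_parser  # assumed import path

    parser = build_parser()
    args = parser.parse_args(["--model", "mlx_model"])
    print(args.grad_checkpoint, args.seed)   # -> False 0

    args = parser.parse_args(["--model", "mlx_model", "--grad-checkpoint"])
    print(args.grad_checkpoint)              # -> True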
@@ -222,6 +227,7 @@ def run(args, training_callback: TrainingCallback = None):
         steps_per_save=args.save_every,
         adapter_file=args.adapter_file,
         max_seq_length=args.max_seq_length,
+        grad_checkpoint=args.grad_checkpoint,
     )
     if args.train:
         print("Training")


@@ -2,6 +2,7 @@
 import time
 from dataclasses import dataclass, field
+from functools import partial
 from pathlib import Path

 import mlx.core as mx
@@ -10,6 +11,22 @@ import numpy as np
 from mlx.utils import tree_flatten


+def grad_checkpoint(layer):
+    """
+    Update all instances of type(layer) to use gradient checkpointing.
+    """
+    fn = type(layer).__call__
+
+    def checkpointed_fn(model, *args, **kwargs):
+        def inner_fn(params, *args, **kwargs):
+            model.update(params)
+            return fn(model, *args, **kwargs)
+
+        return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
+
+    type(layer).__call__ = checkpointed_fn
+
+
 @dataclass
 class TrainingArgs:
     lora_layers: int = field(
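For background (not part of the diff): mx.checkpoint wraps a function so that its intermediate arrays are recomputed during the backward pass instead of being kept alive, which is what the monkey patch above applies to every instance of the layer's class. A minimal standalone sketch:

    import mlx.core as mx

    def f(x):
        # intermediates computed here are rematerialized in the backward pass
        return (mx.maximum(x, 0.0) ** 2).sum()

    grad_fn = mx.grad(mx.checkpoint(f))
    print(grad_fn(mx.ones(4)))  # gradient is 2 * x elementwise -> all 2s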
@@ -40,6 +57,10 @@ class TrainingArgs:
         default="adapter.npz",
         metadata={"help": "Save/load path for the trained adapter weights."},
     )
+    grad_checkpoint: bool = field(
+        default=False,
+        metadata={"help": "Use gradient checkpointing to reduce memory use."},
+    )


 def default_loss(model, inputs, targets, lengths):
@@ -56,16 +77,19 @@ def default_loss(model, inputs, targets, lengths):
 def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+    # Sort by length:
+    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
+
+    # Make the batches:
+    batch_idx = [
+        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
+    ]
     while True:
-        # Shuffle indices
-        indices = np.arange(len(dataset))
-        indices = np.random.permutation(indices)
-
-        # Collect batches from dataset
-        for i in range(0, len(indices) - batch_size + 1, batch_size):
+        indices = np.random.permutation(len(batch_idx))
+        for i in indices:
             # Encode batch
-            batch = [
-                tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size)
-            ]
+            batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
             lengths = [len(x) for x in batch]

             if max(lengths) > max_seq_length:
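The rewrite sorts the dataset by length once, carves fixed batches out of the sorted order, and then shuffles only the batch order each epoch, so sequences within a batch have similar lengths and padding is reduced. A toy sketch of the same idea (data and sizes are made up):

    import numpy as np

    dataset = ["a", "a b", "a b c", "a b c d", "a b c d e", "a b c d e f"]
    batch_size = 2

    idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
    batch_idx = [
        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]

    for i in np.random.permutation(len(batch_idx)):
        batch = [dataset[j] for j in batch_idx[i]]
        print(batch)  # members of each batch have similar lengths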
@@ -75,8 +99,11 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
                     "Consider pre-splitting your data to save memory."
                 )

-            # Pad to the max length
-            max_length_in_batch = min(max(lengths), max_seq_length)
+            # Pad to the nearest multiple of 8 or the maximum length
+            pad_to = 8
+            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
+            max_length_in_batch = min(max_length_in_batch, max_seq_length)
+
             batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)

             for j in range(batch_size):
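Padding up to a multiple of 8 (presumably to keep batch shapes more uniform) replaces padding to the exact batch maximum, while max_seq_length still caps the result. A quick worked example of the rounding (pure arithmetic, not from the repo):

    pad_to, max_seq_length = 8, 2048
    for longest in (3, 45, 48, 5000):
        padded = pad_to * ((longest + pad_to - 1) // pad_to)
        print(longest, "->", min(padded, max_seq_length))
    # 3 -> 8, 45 -> 48, 48 -> 48, 5000 -> 2048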
@@ -157,7 +184,21 @@ def train(
     # Create checkpoints directory if it does not exist
     adapter_path = checkpoints_path(args.adapter_file)

-    # Create value and grad function for loss
+    if args.grad_checkpoint:
+        grad_checkpoint(model.layers[0])
+
+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(batch):
+        # Forward and backward pass
+        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
+
+        # Model update
+        optimizer.update(model, grad)
+
+        return lvalue, toks
+
     loss_value_and_grad = nn.value_and_grad(model, loss)

     losses = []
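The training step is now wrapped in mx.compile, with the model and optimizer state passed as inputs/outputs so the compiled graph tracks their updates. A minimal self-contained sketch of the same pattern (the toy model, SGD, and MSE loss are assumptions, not from the repo):

    from functools import partial

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(4, 1)
    optimizer = optim.SGD(learning_rate=0.1)

    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)

    loss_and_grad = nn.value_and_grad(model, loss_fn)
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(x, y):
        loss, grads = loss_and_grad(model, x, y)
        optimizer.update(model, grads)
        return loss

    x, y = mx.random.normal((8, 4)), mx.random.normal((8, 1))
    loss = step(x, y)
    mx.eval(state, loss)  # materialize the lazy parameter/optimizer updates

Evaluating state alongside the loss is what replaces the old mx.eval(model.parameters(), optimizer.state, lvalue) call in the loop below.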
@@ -175,13 +216,8 @@
             train=True,
         ),
     ):
-        # Forward and backward pass
-        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
-
-        # Model update
-        optimizer.update(model, grad)
-        mx.eval(model.parameters(), optimizer.state, lvalue)
+        lvalue, toks = step(batch)
+        mx.eval(state, lvalue, toks)

         # Record loss
         losses.append(lvalue.item())
@@ -196,12 +232,14 @@
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
+            peak_mem = mx.metal.get_peak_memory() / 2**30
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                 f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}"
+                f"Trained Tokens {trained_tokens}, "
+                f"Peak mem {peak_mem:.3f} GB"
             )

             if training_callback is not None:
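mx.metal.get_peak_memory() reports bytes, so dividing by 2**30 converts to GiB (printed here as GB). A standalone sketch, not from the repo:

    import mlx.core as mx

    mx.eval(mx.random.normal((1024, 1024)) @ mx.random.normal((1024, 1024)))
    peak_gb = mx.metal.get_peak_memory() / 2**30  # bytes -> GiB
    print(f"Peak mem {peak_gb:.3f} GB")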


@@ -35,9 +35,6 @@ def linear_to_lora_layers(
         lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"]
     )

-    # If the lora_parameters are set, we assume the keys
-    # are correct for the given model
     keys = config.get("keys", None)
     if keys is not None:
         keys = set(keys)
@@ -53,7 +50,7 @@
     ]:
         keys = set(["self_attn.q_proj", "self_attn.v_proj"])
         if model.model_type == "mixtral":
-            keys.add(["block_sparse_moe.gate"])
+            keys.add("block_sparse_moe.gate")
     elif model.model_type == "olmo":
         keys = set(["att_proj"])
     elif model.model_type == "phi-msft":
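The last hunk fixes a real bug: set.add takes a single hashable element, so passing a list raises a TypeError. For illustration:

    keys = {"self_attn.q_proj", "self_attn.v_proj"}
    keys.add("block_sparse_moe.gate")      # ok
    # keys.add(["block_sparse_moe.gate"])  # TypeError: unhashable type: 'list'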