diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml
index 32099e0d..b616aaf4 100644
--- a/llms/mlx_lm/examples/lora_config.yaml
+++ b/llms/mlx_lm/examples/lora_config.yaml
@@ -1,6 +1,5 @@
 # The path to the local model directory or Hugging Face repo.
 model: "mlx_model"
-
 # Whether or not to train (boolean)
 train: true
 
@@ -49,6 +48,9 @@ test_batches: 500
 # Maximum sequence length.
 max_seq_length: 2048
 
+# Use gradient checkpointing to reduce memory use.
+grad_checkpoint: false
+
 # LoRA parameters can only be specified in a config file
 lora_parameters:
   # The layer keys to apply LoRA to.
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index e11fed84..36307802 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -145,7 +145,12 @@ def build_parser():
         default=None,
         help="A YAML configuration file with the training options",
     )
-    parser.add_argument("--seed", type=int, help="The PRNG seed")
+    parser.add_argument(
+        "--grad-checkpoint",
+        action="store_true",
+        help="Use gradient checkpointing to reduce memory use.",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
     return parser
 
 
@@ -222,6 +227,7 @@ def run(args, training_callback: TrainingCallback = None):
         steps_per_save=args.save_every,
         adapter_file=args.adapter_file,
         max_seq_length=args.max_seq_length,
+        grad_checkpoint=args.grad_checkpoint,
     )
     if args.train:
         print("Training")
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index ec1f40a7..47cfe002 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -2,6 +2,7 @@
 
 import time
 from dataclasses import dataclass, field
+from functools import partial
 from pathlib import Path
 
 import mlx.core as mx
@@ -10,6 +11,22 @@ import numpy as np
 from mlx.utils import tree_flatten
 
 
+def grad_checkpoint(layer):
+    """
+    Update all instances of type(layer) to use gradient checkpointing.
+ """ + fn = type(layer).__call__ + + def checkpointed_fn(model, *args, **kwargs): + def inner_fn(params, *args, **kwargs): + model.update(params) + return fn(model, *args, **kwargs) + + return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) + + type(layer).__call__ = checkpointed_fn + + @dataclass class TrainingArgs: lora_layers: int = field( @@ -40,6 +57,10 @@ class TrainingArgs: default="adapter.npz", metadata={"help": "Save/load path for the trained adapter weights."}, ) + grad_checkpoint: bool = field( + default=False, + metadata={"help": "Use gradient checkpointing to reduce memory use."}, + ) def default_loss(model, inputs, targets, lengths): @@ -56,16 +77,19 @@ def default_loss(model, inputs, targets, lengths): def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): + # Sort by length: + idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) + + # Make the batches: + batch_idx = [ + idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size) + ] + while True: - # Shuffle indices - indices = np.arange(len(dataset)) - indices = np.random.permutation(indices) - # Collect batches from dataset - for i in range(0, len(indices) - batch_size + 1, batch_size): + indices = np.random.permutation(len(batch_idx)) + for i in indices: # Encode batch - batch = [ - tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size) - ] + batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]] lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: @@ -75,8 +99,11 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) "Consider pre-splitting your data to save memory." ) - # Pad to the max length - max_length_in_batch = min(max(lengths), max_seq_length) + # Pad to the nearest multiple of 8 or the maximum length + pad_to = 8 + max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) + max_length_in_batch = min(max_length_in_batch, max_seq_length) + batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32) for j in range(batch_size): @@ -157,7 +184,21 @@ def train( # Create checkpoints directory if it does not exist adapter_path = checkpoints_path(args.adapter_file) - # Create value and grad function for loss + if args.grad_checkpoint: + grad_checkpoint(model.layers[0]) + + state = [model.state, optimizer.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(batch): + # Forward and backward pass + (lvalue, toks), grad = loss_value_and_grad(model, *batch) + + # Model update + optimizer.update(model, grad) + + return lvalue, toks + loss_value_and_grad = nn.value_and_grad(model, loss) losses = [] @@ -175,13 +216,8 @@ def train( train=True, ), ): - # Forward and backward pass - (lvalue, toks), grad = loss_value_and_grad(model, *batch) - - # Model update - optimizer.update(model, grad) - - mx.eval(model.parameters(), optimizer.state, lvalue) + lvalue, toks = step(batch) + mx.eval(state, lvalue, toks) # Record loss losses.append(lvalue.item()) @@ -196,12 +232,14 @@ def train( it_sec = args.steps_per_report / (stop - start) tokens_sec = float(n_tokens) / (stop - start) trained_tokens += n_tokens + peak_mem = mx.metal.get_peak_memory() / 2**30 print( f"Iter {it + 1}: Train loss {train_loss:.3f}, " f"Learning Rate {learning_rate:.3e}, " f"It/sec {it_sec:.3f}, " f"Tokens/sec {tokens_sec:.3f}, " - f"Trained Tokens {trained_tokens}" + f"Trained Tokens {trained_tokens}, " + f"Peak mem {peak_mem:.3f} GB" ) if training_callback is not None: diff 
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 355e1699..f4f94fc7 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -35,9 +35,6 @@ def linear_to_lora_layers(
             lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"]
         )
 
-    # If the lora_parameters are set, we assume the keys
-    # are correct for the given model
-    keys = config.get("keys", None)
     if keys is not None:
         keys = set(keys)
@@ -53,7 +50,7 @@
     ]:
        keys = set(["self_attn.q_proj", "self_attn.v_proj"])
        if model.model_type == "mixtral":
-            keys.add(["block_sparse_moe.gate"])
+            keys.add("block_sparse_moe.gate")
     elif model.model_type == "olmo":
         keys = set(["att_proj"])
     elif model.model_type == "phi-msft":
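For readers unfamiliar with the monkey-patching trick used by `grad_checkpoint`, the following standalone sketch shows the same pattern outside the diff. It is illustrative only: `Block`, `loss_fn`, and the sizes are made up for the example, while the body of `grad_checkpoint` mirrors the function added to `tuner/trainer.py`.

```python
import mlx.core as mx
import mlx.nn as nn


def grad_checkpoint(layer):
    """Patch type(layer).__call__ so every instance of that class recomputes
    its forward pass during backprop instead of storing intermediate
    activations (same pattern as the function added in trainer.py)."""
    fn = type(layer).__call__

    def checkpointed_fn(module, *args, **kwargs):
        def inner_fn(params, *args, **kwargs):
            module.update(params)
            return fn(module, *args, **kwargs)

        # mx.checkpoint keeps only the inputs; everything computed inside the
        # layer is recomputed when gradients are evaluated.
        return mx.checkpoint(inner_fn)(module.trainable_parameters(), *args, **kwargs)

    type(layer).__call__ = checkpointed_fn


class Block(nn.Module):
    """Hypothetical residual MLP block standing in for a transformer layer."""

    def __init__(self, dims: int):
        super().__init__()
        self.fc1 = nn.Linear(dims, 4 * dims)
        self.fc2 = nn.Linear(4 * dims, dims)

    def __call__(self, x):
        return x + self.fc2(nn.gelu(self.fc1(x)))


model = nn.Sequential(*(Block(32) for _ in range(4)))

# Patching any one instance rewrites Block.__call__, so all four blocks are
# checkpointed; this is why the trainer only passes model.layers[0].
grad_checkpoint(model.layers[0])


def loss_fn(model, x):
    return model(x).sum()


x = mx.random.normal((8, 32))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x)
mx.eval(loss, grads)
```

The trade-off is memory for compute: activations are no longer held for the backward pass, at the cost of a second forward computation per layer. In the diff this behavior is opt-in via `--grad-checkpoint` on the command line or `grad_checkpoint: true` in the YAML config.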