From 4786b4e3eb84a30428aed5beb4b3d50dfa4fc1d9 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 31 May 2024 18:42:22 -0700
Subject: [PATCH] Add distributed option for lora training

---
Notes:
    * average_gradients() all-reduces (sums) every gradient leaf across the
      workers and divides by the world size; it is a no-op when only one
      process is running.
    * iterate_batches() gives each worker every N-th sample of a global
      batch, starting at the worker's own rank, so that the N workers read
      disjoint samples. Each worker therefore yields batch_size // N rows.
    * evaluate() accumulates the token-weighted loss and the token count
      locally and sums both across workers before dividing.

 llms/mlx_lm/tuner/trainer.py | 42 +++++++++++++++++++++++++++++++------
 1 file changed, 34 insertions(+), 8 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 1d934a72..31f845b2 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -10,7 +10,7 @@ from typing import Union
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from mlx.utils import tree_flatten
+from mlx.utils import tree_flatten, tree_map
 
 
 def grad_checkpoint(layer):
@@ -29,6 +29,17 @@ def grad_checkpoint(layer):
     type(layer).__call__ = checkpointed_fn
 
 
+def average_gradients(gradients):
+    world_size = mx.distributed.init().size()
+    if world_size == 1:
+        return gradients
+
+    def _all_average(x):
+        return mx.distributed.all_sum(x) / world_size
+
+    return tree_map(_all_average, gradients)
+
+
 @dataclass
 class TrainingArgs:
     batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
@@ -84,9 +95,17 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             f" examples but only has {len(dataset)}."
         )
 
+    # If running in distributed mode (N machines) then each one should skip N-1
+    # samples, starting from its own rank so that workers read disjoint samples
+    offset = mx.distributed.init().rank()
+    step = mx.distributed.init().size()
+    if batch_size % step != 0:
+        raise ValueError("The batch size must be divisible by the number of workers")
+
     # Make the batches:
     batch_idx = [
-        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
+        idx[i + offset : i + offset + batch_size : step]
+        for i in range(0, len(idx) - batch_size + 1, batch_size)
     ]
 
     while True:
@@ -112,9 +131,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
             max_length_in_batch = min(max_length_in_batch, max_seq_length)
 
-            batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
+            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
 
-            for j in range(batch_size):
+            for j in range(batch_size // step):
                 truncated_length = min(lengths[j], max_seq_length)
                 batch_arr[j, :truncated_length] = batch[j][:truncated_length]
                 lengths[j] = (
@@ -138,7 +157,7 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = []
+    all_losses = 0
     ntokens = 0
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -153,10 +172,14 @@ def evaluate(
         ),
     ):
         losses, toks = loss(model, *batch)
-        all_losses.append((losses * toks).item())
-        ntokens += toks.item()
+        all_losses += losses * toks
+        ntokens += toks
+        mx.eval(all_losses, ntokens)
 
-    return np.sum(all_losses) / ntokens
+    all_losses = mx.distributed.all_sum(all_losses)
+    ntokens = mx.distributed.all_sum(ntokens)
+
+    return (all_losses / ntokens).item()
 
 
 class TrainingCallback:
@@ -192,6 +215,9 @@ def train(
         # Forward and backward pass
         (lvalue, toks), grad = loss_value_and_grad(model, *batch)
 
+        # All reduce the gradients if running in distributed mode
+        grad = average_gradients(grad)
+
         # Model update
         optimizer.update(model, grad)
 