From 331148d8ec05ce2f1dd50444570c61805b700039 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Sat, 2 Nov 2024 18:02:31 -0700
Subject: [PATCH] Enable distributed LoRA training (#821)

---
 llms/mlx_lm/tuner/trainer.py | 81 ++++++++++++++++++++++++------------
 llms/tests/test_finetune.py  | 51 ++++++++++++++---------
 2 files changed, 86 insertions(+), 46 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 1d934a72..38619d95 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -10,6 +10,7 @@ from typing import Union
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
 
 
@@ -84,9 +85,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             f" examples but only has {len(dataset)}."
         )
 
+    # If running in distributed mode (N machines) then each one should skip N-1
+    # samples
+    step = mx.distributed.init().size()
+    if batch_size % step != 0:
+        raise ValueError("The batch size must be divisible by the number of workers")
+
     # Make the batches:
     batch_idx = [
-        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
+        idx[i : i + batch_size : step]
+        for i in range(0, len(idx) - batch_size + 1, batch_size)
     ]
 
     while True:
@@ -112,9 +120,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
                 max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
             max_length_in_batch = min(max_length_in_batch, max_seq_length)
 
-            batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
+            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
 
-            for j in range(batch_size):
+            for j in range(batch_size // step):
                 truncated_length = min(lengths[j], max_seq_length)
                 batch_arr[j, :truncated_length] = batch[j][:truncated_length]
                 lengths[j] = (
@@ -138,7 +146,7 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = []
+    all_losses = 0
     ntokens = 0
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -153,10 +161,14 @@ def evaluate(
         ),
     ):
         losses, toks = loss(model, *batch)
-        all_losses.append((losses * toks).item())
-        ntokens += toks.item()
+        all_losses += losses * toks
+        ntokens += toks
+        mx.eval(all_losses, ntokens)
 
-    return np.sum(all_losses) / ntokens
+    all_losses = mx.distributed.all_sum(all_losses)
+    ntokens = mx.distributed.all_sum(ntokens)
+
+    return (all_losses / ntokens).item()
 
 
 class TrainingCallback:
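A note on the new slicing in iterate_batches (illustrative, not part of the patch): each window of batch_size length-sorted indices is thinned to every step-th element, so a worker builds batches of batch_size // step samples, which is why the batch size must be divisible by the number of workers. A pure-Python sketch of the same slice, with made-up values:

    # Illustrative stand-ins: `idx` mimics length-sorted dataset indices.
    idx = list(range(8))
    batch_size, step = 4, 2  # e.g. two workers

    # The patch's expression: every `step`-th index from each batch window.
    batch_idx = [
        idx[i : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]
    print(batch_idx)  # [[0, 2], [4, 6]] -- batch_size // step samples per batch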
@@ -182,6 +194,11 @@ def train(
     training_callback: TrainingCallback = None,
 ):
     print(f"Starting training..., iters: {args.iters}")
+    world = mx.distributed.init()
+    world_size = world.size()
+    rank = world.rank()
+    if world_size > 1:
+        print(f"Node {rank} of {world_size}")
 
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])
@@ -192,6 +209,9 @@ def train(
         # Forward and backward pass
         (lvalue, toks), grad = loss_value_and_grad(model, *batch)
 
+        # All reduce the gradients if running in distributed mode
+        grad = average_gradients(grad)
+
         # Model update
         optimizer.update(model, grad)
 
@@ -199,8 +219,9 @@ def train(
 
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
-    losses = []
+    losses = 0
     n_tokens = 0
+    steps = 0
     trained_tokens = 0
     # Main training loop
     start = time.perf_counter()
@@ -229,9 +250,13 @@ def train(
                 iterate_batches=iterate_batches,
             )
             val_time = time.perf_counter() - stop
-            print(
-                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: "
+                    f"Val loss {val_loss:.3f}, "
+                    f"Val took {val_time:.3f}s",
+                    flush=True,
+                )
 
             if training_callback is not None:
                 val_info = {
@@ -244,30 +269,33 @@ def train(
         start = time.perf_counter()
 
         lvalue, toks = step(batch)
-        mx.eval(state, lvalue, toks)
-
-        # Record loss
-        losses.append(lvalue.item())
-        n_tokens += toks.item()
+        losses += lvalue
+        n_tokens += toks
+        steps += 1
+        mx.eval(state, losses, n_tokens)
 
         # Report training loss if needed
        if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
 
-            train_loss = np.mean(losses)
+            train_loss = mx.distributed.all_sum(losses).item()
+            train_loss /= steps * mx.distributed.init().size()
+            n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 2**30
-            print(
-                f"Iter {it}: Train loss {train_loss:.3f}, "
-                f"Learning Rate {learning_rate:.3e}, "
-                f"It/sec {it_sec:.3f}, "
-                f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}, "
-                f"Peak mem {peak_mem:.3f} GB"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: Train loss {train_loss:.3f}, "
+                    f"Learning Rate {learning_rate:.3e}, "
+                    f"It/sec {it_sec:.3f}, "
+                    f"Tokens/sec {tokens_sec:.3f}, "
+                    f"Trained Tokens {trained_tokens}, "
+                    f"Peak mem {peak_mem:.3f} GB",
+                    flush=True,
+                )
 
             if training_callback is not None:
                 train_info = {
@@ -281,8 +309,9 @@ def train(
                 }
                 training_callback.on_train_loss_report(train_info)
 
-            losses = []
+            losses = 0
             n_tokens = 0
+            steps = 0
             start = time.perf_counter()
 
         # Save adapter weights
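For intuition about the all-reduce in step() above (a sketch, not the patch's implementation): averaging gradients across the data-parallel group amounts to an all_sum over each gradient array divided by the group size; the imported mlx.nn.utils.average_gradients implements this with batched communication. A naive version with the same semantics, under a hypothetical helper name:

    import mlx.core as mx
    from mlx.utils import tree_map

    def naive_average_gradients(grads):  # hypothetical, for illustration only
        group = mx.distributed.init()
        if group.size() == 1:
            return grads  # single process: nothing to reduce
        # Sum each gradient array across workers, then divide by the world size.
        return tree_map(lambda g: mx.distributed.all_sum(g) / group.size(), grads)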
diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py
index 107be092..6ba81628 100644
--- a/llms/tests/test_finetune.py
+++ b/llms/tests/test_finetune.py
@@ -3,6 +3,7 @@
 import math
 import sys
 import unittest
+from contextlib import contextmanager
 from io import StringIO
 from unittest.mock import MagicMock
 
@@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate
 from mlx_lm.tuner.utils import build_schedule
 
 
+@contextmanager
+def swapped_with_identity(obj, func):
+    old_func = getattr(obj, func)
+    setattr(obj, func, lambda x: x)
+    yield
+    setattr(obj, func, old_func)
+
+
 class TestLora(unittest.TestCase):
     def setUp(self):
         self.capturedOutput = StringIO()
@@ -374,16 +383,17 @@ class TestScheduleConfig(unittest.TestCase):
             (MagicMock(return_value=0.4), MagicMock(return_value=180)),
             (MagicMock(return_value=0.6), MagicMock(return_value=120)),
         ]
-        evaluate(
-            model=mock_model,
-            dataset=mock_dataset,
-            tokenizer=mock_tokenizer,
-            batch_size=2,
-            num_batches=2,
-            max_seq_length=2048,
-            loss=mock_default_loss,
-            iterate_batches=mock_iterate_batches,
-        )
+        with swapped_with_identity(mx.distributed, "all_sum"):
+            evaluate(
+                model=mock_model,
+                dataset=mock_dataset,
+                tokenizer=mock_tokenizer,
+                batch_size=2,
+                num_batches=2,
+                max_seq_length=2048,
+                loss=mock_default_loss,
+                iterate_batches=mock_iterate_batches,
+            )
 
         mock_iterate_batches.assert_called_once_with(
             dataset=mock_dataset,
@@ -412,16 +422,17 @@ class TestScheduleConfig(unittest.TestCase):
             (MagicMock(return_value=0.2), MagicMock(return_value=150)),
         ]
 
-        evaluate(
-            model=mock_model,
-            dataset=mock_dataset,
-            tokenizer=mock_tokenizer,
-            batch_size=2,
-            num_batches=-1,
-            max_seq_length=2048,
-            loss=mock_default_loss,
-            iterate_batches=mock_iterate_batches,
-        )
+        with swapped_with_identity(mx.distributed, "all_sum"):
+            evaluate(
+                model=mock_model,
+                dataset=mock_dataset,
+                tokenizer=mock_tokenizer,
+                batch_size=2,
+                num_batches=-1,
+                max_seq_length=2048,
+                loss=mock_default_loss,
+                iterate_batches=mock_iterate_batches,
+            )
 
         mock_iterate_batches.assert_called_once_with(
             dataset=mock_dataset,
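The mocked (loss, tokens) pairs above also illustrate what the reworked evaluate computes: it accumulates a token-weighted loss sum and a token count, all_sum turns the per-rank sums into global sums, and their ratio is the token-weighted mean loss. The tests swap mx.distributed.all_sum for an identity function so this path runs unchanged in a single process. Worked out locally with the first test's values (an illustration, not the test's assertion):

    # (loss, tokens) as mocked in the first evaluate test.
    per_batch = [(0.4, 180), (0.6, 120)]

    loss_sum = sum(l * t for l, t in per_batch)  # 72.0 + 72.0 = 144.0
    tok_sum = sum(t for _, t in per_batch)       # 300
    print(loss_sum / tok_sum)                    # 0.48 -- token-weighted mean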