Enable distributed LoRA training (#821)

Angelos Katharopoulos authored on 2024-11-02 18:02:31 -07:00 (committed by GitHub)
parent 29c954f4cb
commit 331148d8ec
2 changed files with 86 additions and 46 deletions

View File

@@ -10,6 +10,7 @@ from typing import Union
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
@@ -84,9 +85,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             f" examples but only has {len(dataset)}."
         )
 
+    # If running in distributed mode (N machines) then each one should skip N-1
+    # samples
+    step = mx.distributed.init().size()
+    if batch_size % step != 0:
+        raise ValueError("The batch size must be divisible by the number of workers")
+
     # Make the batches:
     batch_idx = [
-        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
+        idx[i : i + batch_size : step]
+        for i in range(0, len(idx) - batch_size + 1, batch_size)
     ]
 
     while True:
@@ -112,9 +120,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
            max_length_in_batch = min(max_length_in_batch, max_seq_length)
 
-            batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
-            for j in range(batch_size):
+            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
+            for j in range(batch_size // step):
                 truncated_length = min(lengths[j], max_seq_length)
                 batch_arr[j, :truncated_length] = batch[j][:truncated_length]
                 lengths[j] = (
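
Note: the stride in the batch_idx slice is what shards each logical batch across nodes; every worker materializes batch_size // step rows, which is why batch_arr shrinks accordingly. A minimal standalone sketch of that arithmetic (toy values, not the trainer's real data; step stands in for the world size as in the code above):

idx = list(range(12))      # stand-in for the sorted sample indices
batch_size, step = 4, 2    # step = mx.distributed.init().size() in the trainer

batch_idx = [
    idx[i : i + batch_size : step]
    for i in range(0, len(idx) - batch_size + 1, batch_size)
]
print(batch_idx)           # [[0, 2], [4, 6], [8, 10]]
# Each entry holds batch_size // step = 2 indices, matching the
# (batch_size // step, max_length_in_batch) array allocated above.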
@@ -138,7 +146,7 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = []
+    all_losses = 0
     ntokens = 0
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -153,10 +161,14 @@ def evaluate(
         ),
     ):
         losses, toks = loss(model, *batch)
-        all_losses.append((losses * toks).item())
-        ntokens += toks.item()
+        all_losses += losses * toks
+        ntokens += toks
+        mx.eval(all_losses, ntokens)
 
-    return np.sum(all_losses) / ntokens
+    all_losses = mx.distributed.all_sum(all_losses)
+    ntokens = mx.distributed.all_sum(ntokens)
+
+    return (all_losses / ntokens).item()
 
 
 class TrainingCallback:
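
Note: evaluate() now accumulates a token-weighted loss sum per node and reduces it with mx.distributed.all_sum before dividing, so every node returns the same global average. A minimal sketch of that reduction (toy numbers; with a single process all_sum returns its input and this degenerates to a plain division):

import mlx.core as mx

local_loss_sum = mx.array(12.5)   # sum of loss * tokens on this node
local_tokens = mx.array(10.0)     # tokens this node processed

global_loss = mx.distributed.all_sum(local_loss_sum)
global_tokens = mx.distributed.all_sum(local_tokens)
print((global_loss / global_tokens).item())  # token-weighted mean across nodes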
@@ -182,6 +194,11 @@ def train(
     training_callback: TrainingCallback = None,
 ):
     print(f"Starting training..., iters: {args.iters}")
+    world = mx.distributed.init()
+    world_size = world.size()
+    rank = world.rank()
+    if world_size > 1:
+        print(f"Node {rank} of {world_size}")
 
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])
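
Note: mx.distributed.init() returns the global communication group; when no distributed backend is available it falls back to a single-process group, so the same script runs unchanged on one machine. A minimal sketch:

import mlx.core as mx

world = mx.distributed.init()    # safe to call more than once
print(f"Node {world.rank()} of {world.size()}")
# A single process prints "Node 0 of 1"; under an MPI launcher each
# process reports its own rank, matching the log line added above.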
@@ -192,6 +209,9 @@ def train(
 
         # Forward and backward pass
         (lvalue, toks), grad = loss_value_and_grad(model, *batch)
 
+        # All reduce the gradients if running in distributed mode
+        grad = average_gradients(grad)
+
         # Model update
         optimizer.update(model, grad)
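
Note: average_gradients (imported from mlx.nn.utils above) all-reduces every array in the gradient tree and divides by the number of workers, so each node applies the same averaged update. A rough functional sketch of the idea, ignoring any communication batching the real helper may do:

import mlx.core as mx
from mlx.utils import tree_map

def naive_average_gradients(grads):
    # Sum each gradient array across workers, then divide by the world size.
    world_size = mx.distributed.init().size()
    if world_size == 1:
        return grads
    return tree_map(lambda g: mx.distributed.all_sum(g) / world_size, grads)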
@@ -199,8 +219,9 @@ def train(
 
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
-    losses = []
+    losses = 0
     n_tokens = 0
+    steps = 0
     trained_tokens = 0
     # Main training loop
     start = time.perf_counter()
@@ -229,8 +250,12 @@ def train(
                 iterate_batches=iterate_batches,
             )
             val_time = time.perf_counter() - stop
-            print(
-                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: "
+                    f"Val loss {val_loss:.3f}, "
+                    f"Val took {val_time:.3f}s",
+                    flush=True,
+                )
 
             if training_callback is not None:
@@ -244,29 +269,32 @@ def train(
            start = time.perf_counter()
 
         lvalue, toks = step(batch)
-        mx.eval(state, lvalue, toks)
-
-        # Record loss
-        losses.append(lvalue.item())
-        n_tokens += toks.item()
+        losses += lvalue
+        n_tokens += toks
+        steps += 1
+        mx.eval(state, losses, n_tokens)
 
         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
 
-            train_loss = np.mean(losses)
+            train_loss = mx.distributed.all_sum(losses).item()
+            train_loss /= steps * mx.distributed.init().size()
+            n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 2**30
-            print(
-                f"Iter {it}: Train loss {train_loss:.3f}, "
-                f"Learning Rate {learning_rate:.3e}, "
-                f"It/sec {it_sec:.3f}, "
-                f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}, "
-                f"Peak mem {peak_mem:.3f} GB"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: Train loss {train_loss:.3f}, "
+                    f"Learning Rate {learning_rate:.3e}, "
+                    f"It/sec {it_sec:.3f}, "
+                    f"Tokens/sec {tokens_sec:.3f}, "
+                    f"Trained Tokens {trained_tokens}, "
+                    f"Peak mem {peak_mem:.3f} GB",
+                    flush=True,
+                )
 
             if training_callback is not None:
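
Note: the reported train loss is now the all-node mean of the per-batch losses: each node sums its batch losses into losses over steps batches, and all_sum(losses) / (steps * world_size) gives the global average. A toy check of that arithmetic:

import mlx.core as mx

losses = mx.array(3.0)   # e.g. 1.4 + 1.6 accumulated over steps = 2 batches
steps = 2
world_size = mx.distributed.init().size()
train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
print(train_loss)        # 1.5 when run as a single process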
@@ -281,8 +309,9 @@ def train(
                 }
                 training_callback.on_train_loss_report(train_info)
 
-            losses = []
+            losses = 0
             n_tokens = 0
+            steps = 0
             start = time.perf_counter()
 
         # Save adapter weights

View File

@@ -3,6 +3,7 @@
 import math
 import sys
 import unittest
+from contextlib import contextmanager
 from io import StringIO
 from unittest.mock import MagicMock
@@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate
 from mlx_lm.tuner.utils import build_schedule
 
 
+@contextmanager
+def swapped_with_identity(obj, func):
+    old_func = getattr(obj, func)
+    setattr(obj, func, lambda x: x)
+    yield
+    setattr(obj, func, old_func)
+
+
 class TestLora(unittest.TestCase):
     def setUp(self):
         self.capturedOutput = StringIO()
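
Note: swapped_with_identity temporarily replaces an attribute with an identity function and restores it on exit; the tests use it so evaluate() can call mx.distributed.all_sum on MagicMock losses that are not real arrays. A usage sketch (relies on the helper defined above):

with swapped_with_identity(mx.distributed, "all_sum"):
    assert mx.distributed.all_sum("anything") == "anything"
# The original all_sum is restored once the block exits.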
@@ -374,6 +383,7 @@ class TestScheduleConfig(unittest.TestCase):
             (MagicMock(return_value=0.4), MagicMock(return_value=180)),
             (MagicMock(return_value=0.6), MagicMock(return_value=120)),
         ]
+        with swapped_with_identity(mx.distributed, "all_sum"):
             evaluate(
                 model=mock_model,
                 dataset=mock_dataset,
@@ -412,6 +422,7 @@ class TestScheduleConfig(unittest.TestCase):
             (MagicMock(return_value=0.2), MagicMock(return_value=150)),
         ]
+        with swapped_with_identity(mx.distributed, "all_sum"):
             evaluate(
                 model=mock_model,
                 dataset=mock_dataset,