Add distributed option for LoRA training

This commit is contained in:
Angelos Katharopoulos 2024-05-31 18:42:22 -07:00
parent 9f34fdbda4
commit 4786b4e3eb

View File

@ -10,7 +10,7 @@ from typing import Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from mlx.utils import tree_flatten from mlx.utils import tree_flatten, tree_map
def grad_checkpoint(layer): def grad_checkpoint(layer):
@ -29,6 +29,17 @@ def grad_checkpoint(layer):
type(layer).__call__ = checkpointed_fn type(layer).__call__ = checkpointed_fn
def average_gradients(gradients):
    """Average a tree of gradients across all distributed processes.

    The group size is read from ``mx.distributed.init().size()``. When it is
    1 the input tree is returned as is. Otherwise every leaf is passed
    through ``mx.distributed.all_sum`` and divided by the group size.

    Args:
        gradients: The gradient tree to average.

    Returns:
        The input tree when the group size is 1, otherwise a new tree built
        by ``tree_map`` holding the averaged leaves.
    """
    num_workers = mx.distributed.init().size()

    # Nothing to reduce when only one process is participating.
    if num_workers == 1:
        return gradients

    return tree_map(
        lambda leaf: mx.distributed.all_sum(leaf) / num_workers,
        gradients,
    )
@dataclass @dataclass
class TrainingArgs: class TrainingArgs:
batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
@ -84,9 +95,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
f" examples but only has {len(dataset)}." f" examples but only has {len(dataset)}."
) )
# If running in distributed mode (N machines) then each one should skip N-1
# samples
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Make the batches: # Make the batches:
batch_idx = [ batch_idx = [
idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size) idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
] ]
while True: while True:
@ -112,9 +130,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length) max_length_in_batch = min(max_length_in_batch, max_seq_length)
batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32) batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
for j in range(batch_size): for j in range(batch_size // step):
truncated_length = min(lengths[j], max_seq_length) truncated_length = min(lengths[j], max_seq_length)
batch_arr[j, :truncated_length] = batch[j][:truncated_length] batch_arr[j, :truncated_length] = batch[j][:truncated_length]
lengths[j] = ( lengths[j] = (
@ -138,7 +156,7 @@ def evaluate(
loss: callable = default_loss, loss: callable = default_loss,
iterate_batches: callable = iterate_batches, iterate_batches: callable = iterate_batches,
): ):
all_losses = [] all_losses = 0
ntokens = 0 ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@ -153,10 +171,14 @@ def evaluate(
), ),
): ):
losses, toks = loss(model, *batch) losses, toks = loss(model, *batch)
all_losses.append((losses * toks).item()) all_losses += losses * toks
ntokens += toks.item() ntokens += toks
mx.eval(all_losses, ntokens)
return np.sum(all_losses) / ntokens all_losses = mx.distributed.all_sum(all_losses)
ntokens = mx.distributed.all_sum(ntokens)
return (all_losses / ntokens).item()
class TrainingCallback: class TrainingCallback:
@ -192,6 +214,9 @@ def train(
# Forward and backward pass # Forward and backward pass
(lvalue, toks), grad = loss_value_and_grad(model, *batch) (lvalue, toks), grad = loss_value_and_grad(model, *batch)
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
# Model update # Model update
optimizer.update(model, grad) optimizer.update(model, grad)