Add distributed option for lora training

This commit is contained in:
Angelos Katharopoulos 2024-05-31 18:42:22 -07:00
parent 9f34fdbda4
commit 4786b4e3eb

View File

@ -10,7 +10,7 @@ from typing import Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
from mlx.utils import tree_flatten, tree_map
def grad_checkpoint(layer):
@ -29,6 +29,17 @@ def grad_checkpoint(layer):
type(layer).__call__ = checkpointed_fn
def average_gradients(gradients):
world_size = mx.distributed.init().size()
if world_size == 1:
return gradients
def _all_average(x):
return mx.distributed.all_sum(x) / world_size
return tree_map(_all_average, gradients)
@dataclass
class TrainingArgs:
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
@ -84,9 +95,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
f" examples but only has {len(dataset)}."
)
# If running in distributed mode (N machines) then each one should skip N-1
# samples
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Make the batches:
batch_idx = [
idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
while True:
@ -112,9 +130,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
for j in range(batch_size):
for j in range(batch_size // step):
truncated_length = min(lengths[j], max_seq_length)
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
lengths[j] = (
@ -138,7 +156,7 @@ def evaluate(
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
):
all_losses = []
all_losses = 0
ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@ -153,10 +171,14 @@ def evaluate(
),
):
losses, toks = loss(model, *batch)
all_losses.append((losses * toks).item())
ntokens += toks.item()
all_losses += losses * toks
ntokens += toks
mx.eval(all_losses, ntokens)
return np.sum(all_losses) / ntokens
all_losses = mx.distributed.all_sum(all_losses)
ntokens = mx.distributed.all_sum(ntokens)
return (all_losses / ntokens).item()
class TrainingCallback:
@ -192,6 +214,9 @@ def train(
# Forward and backward pass
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
# Model update
optimizer.update(model, grad)