Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-30 02:53:41 +08:00)
Add distributed option for lora training
This commit is contained in:
parent 9f34fdbda4
commit 4786b4e3eb
@@ -10,7 +10,7 @@ from typing import Union
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from mlx.utils import tree_flatten
+from mlx.utils import tree_flatten, tree_map
 
 
 def grad_checkpoint(layer):
@@ -29,6 +29,17 @@ def grad_checkpoint(layer):
     type(layer).__call__ = checkpointed_fn
 
 
+def average_gradients(gradients):
+    world_size = mx.distributed.init().size()
+    if world_size == 1:
+        return gradients
+
+    def _all_average(x):
+        return mx.distributed.all_sum(x) / world_size
+
+    return tree_map(_all_average, gradients)
+
+
 @dataclass
 class TrainingArgs:
     batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
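For context, a minimal standalone sketch of what the new average_gradients helper does: every leaf of the gradient tree is all-summed across the distributed group and divided by the group size, so each rank ends up with the same averaged gradients. The gradient tree below is hypothetical, and the snippet reduces to a no-op when only one process is running.

    import mlx.core as mx
    from mlx.utils import tree_map

    # Hypothetical gradient tree, shaped like the output of nn.value_and_grad.
    grads = {"lora_a": mx.ones((4, 8)), "lora_b": mx.zeros((8, 4))}

    world_size = mx.distributed.init().size()  # 1 unless a distributed backend is set up
    if world_size > 1:
        # Sum each leaf across ranks, then divide by the number of ranks.
        grads = tree_map(lambda g: mx.distributed.all_sum(g) / world_size, grads)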
@@ -84,9 +95,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             f" examples but only has {len(dataset)}."
         )
 
+    # If running in distributed mode (N machines) then each one should skip N-1
+    # samples
+    step = mx.distributed.init().size()
+    if batch_size % step != 0:
+        raise ValueError("The batch size must be divisible by the number of workers")
+
     # Make the batches:
     batch_idx = [
-        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
+        idx[i : i + batch_size : step]
+        for i in range(0, len(idx) - batch_size + 1, batch_size)
     ]
 
     while True:
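To see what the new strided slice produces, here is a small illustration with made-up numbers (not part of the commit): each window of batch_size sorted indices is thinned with stride step, so every batch now holds batch_size // step samples per worker, which also explains the batch_size // step array sizes in the next hunk.

    # Illustration only: batch_size = 4, step = 2 workers, 8 sorted sample indices.
    idx = list(range(8))
    batch_size, step = 4, 2

    batch_idx = [
        idx[i : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]
    # batch_idx == [[0, 2], [4, 6]]  -> batch_size // step = 2 samples per batch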
@@ -112,9 +130,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
             max_length_in_batch = min(max_length_in_batch, max_seq_length)
 
-            batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
+            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
 
-            for j in range(batch_size):
+            for j in range(batch_size // step):
                 truncated_length = min(lengths[j], max_seq_length)
                 batch_arr[j, :truncated_length] = batch[j][:truncated_length]
                 lengths[j] = (
@@ -138,7 +156,7 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = []
+    all_losses = 0
     ntokens = 0
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -153,10 +171,14 @@ def evaluate(
         ),
     ):
         losses, toks = loss(model, *batch)
-        all_losses.append((losses * toks).item())
-        ntokens += toks.item()
+        all_losses += losses * toks
+        ntokens += toks
+        mx.eval(all_losses, ntokens)
 
-    return np.sum(all_losses) / ntokens
+    all_losses = mx.distributed.all_sum(all_losses)
+    ntokens = mx.distributed.all_sum(ntokens)
+
+    return (all_losses / ntokens).item()
 
 
 class TrainingCallback:
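The reworked evaluate keeps losses * toks and the token count as arrays, all-sums both across ranks, and divides only at the end, which yields a token-weighted mean rather than an average of per-rank means. A two-rank example with hypothetical numbers:

    # Hypothetical: rank 0 has mean loss 2.0 over 100 tokens, rank 1 has 4.0 over 300.
    per_rank = [(2.0, 100), (4.0, 300)]
    weighted = sum(l * t for l, t in per_rank) / sum(t for _, t in per_rank)  # 3.5
    naive = sum(l for l, _ in per_rank) / len(per_rank)                       # 3.0, over-weights rank 0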
@@ -192,6 +214,9 @@ def train(
         # Forward and backward pass
         (lvalue, toks), grad = loss_value_and_grad(model, *batch)
 
+        # All reduce the gradients if running in distributed mode
+        grad = average_gradients(grad)
+
         # Model update
         optimizer.update(model, grad)
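Taken together, a distributed optimization step in the trainer now looks roughly like the sketch below; model, optimizer, batch, and loss_value_and_grad stand in for the objects the trainer already builds, and the final mx.eval call is just one way to force evaluation of the updated state.

    # Sketch only, not the verbatim trainer code.
    (lvalue, toks), grad = loss_value_and_grad(model, *batch)

    # All-reduce (average) the gradients across ranks; a no-op with a single process.
    grad = average_gradients(grad)

    optimizer.update(model, grad)
    mx.eval(model.parameters(), optimizer.state, lvalue)

In practice the process would be launched once per machine with an MPI-capable runner (for example mpirun) so that mx.distributed.init() reports a world size greater than one; the batch size passed to the trainer must then be divisible by that world size, as enforced in iterate_batches above.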