mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Add distributed option for lora training
This commit is contained in:
parent
9f34fdbda4
commit
4786b4e3eb
@ -10,7 +10,7 @@ from typing import Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten, tree_map
|
||||||
|
|
||||||
|
|
||||||
def grad_checkpoint(layer):
|
def grad_checkpoint(layer):
|
||||||
@ -29,6 +29,17 @@ def grad_checkpoint(layer):
|
|||||||
type(layer).__call__ = checkpointed_fn
|
type(layer).__call__ = checkpointed_fn
|
||||||
|
|
||||||
|
|
||||||
|
def average_gradients(gradients):
|
||||||
|
world_size = mx.distributed.init().size()
|
||||||
|
if world_size == 1:
|
||||||
|
return gradients
|
||||||
|
|
||||||
|
def _all_average(x):
|
||||||
|
return mx.distributed.all_sum(x) / world_size
|
||||||
|
|
||||||
|
return tree_map(_all_average, gradients)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArgs:
|
class TrainingArgs:
|
||||||
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
|
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
|
||||||
@ -84,9 +95,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
f" examples but only has {len(dataset)}."
|
f" examples but only has {len(dataset)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If running in distributed mode (N machines) then each one should skip N-1
|
||||||
|
# samples
|
||||||
|
step = mx.distributed.init().size()
|
||||||
|
if batch_size % step != 0:
|
||||||
|
raise ValueError("The batch size must be divisible by the number of workers")
|
||||||
|
|
||||||
# Make the batches:
|
# Make the batches:
|
||||||
batch_idx = [
|
batch_idx = [
|
||||||
idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
|
idx[i : i + batch_size : step]
|
||||||
|
for i in range(0, len(idx) - batch_size + 1, batch_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -112,9 +130,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
|
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
|
||||||
max_length_in_batch = min(max_length_in_batch, max_seq_length)
|
max_length_in_batch = min(max_length_in_batch, max_seq_length)
|
||||||
|
|
||||||
batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
|
batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
|
||||||
|
|
||||||
for j in range(batch_size):
|
for j in range(batch_size // step):
|
||||||
truncated_length = min(lengths[j], max_seq_length)
|
truncated_length = min(lengths[j], max_seq_length)
|
||||||
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
|
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
|
||||||
lengths[j] = (
|
lengths[j] = (
|
||||||
@ -138,7 +156,7 @@ def evaluate(
|
|||||||
loss: callable = default_loss,
|
loss: callable = default_loss,
|
||||||
iterate_batches: callable = iterate_batches,
|
iterate_batches: callable = iterate_batches,
|
||||||
):
|
):
|
||||||
all_losses = []
|
all_losses = 0
|
||||||
ntokens = 0
|
ntokens = 0
|
||||||
|
|
||||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||||
@ -153,10 +171,14 @@ def evaluate(
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
losses, toks = loss(model, *batch)
|
losses, toks = loss(model, *batch)
|
||||||
all_losses.append((losses * toks).item())
|
all_losses += losses * toks
|
||||||
ntokens += toks.item()
|
ntokens += toks
|
||||||
|
mx.eval(all_losses, ntokens)
|
||||||
|
|
||||||
return np.sum(all_losses) / ntokens
|
all_losses = mx.distributed.all_sum(all_losses)
|
||||||
|
ntokens = mx.distributed.all_sum(ntokens)
|
||||||
|
|
||||||
|
return (all_losses / ntokens).item()
|
||||||
|
|
||||||
|
|
||||||
class TrainingCallback:
|
class TrainingCallback:
|
||||||
@ -192,6 +214,9 @@ def train(
|
|||||||
# Forward and backward pass
|
# Forward and backward pass
|
||||||
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
||||||
|
|
||||||
|
# All reduce the gradients if running in distributed mode
|
||||||
|
grad = average_gradients(grad)
|
||||||
|
|
||||||
# Model update
|
# Model update
|
||||||
optimizer.update(model, grad)
|
optimizer.update(model, grad)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user