mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Enable distributed LoRA training (#821)
This commit is contained in:
parent
29c954f4cb
commit
331148d8ec
@ -10,6 +10,7 @@ from typing import Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
|
||||
@ -84,9 +85,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
f" examples but only has {len(dataset)}."
|
||||
)
|
||||
|
||||
# If running in distributed mode (N machines) then each one should skip N-1
|
||||
# samples
|
||||
step = mx.distributed.init().size()
|
||||
if batch_size % step != 0:
|
||||
raise ValueError("The batch size must be divisible by the number of workers")
|
||||
|
||||
# Make the batches:
|
||||
batch_idx = [
|
||||
idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
|
||||
idx[i : i + batch_size : step]
|
||||
for i in range(0, len(idx) - batch_size + 1, batch_size)
|
||||
]
|
||||
|
||||
while True:
|
||||
@ -112,9 +120,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
|
||||
max_length_in_batch = min(max_length_in_batch, max_seq_length)
|
||||
|
||||
batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
|
||||
batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
|
||||
|
||||
for j in range(batch_size):
|
||||
for j in range(batch_size // step):
|
||||
truncated_length = min(lengths[j], max_seq_length)
|
||||
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
|
||||
lengths[j] = (
|
||||
@ -138,7 +146,7 @@ def evaluate(
|
||||
loss: callable = default_loss,
|
||||
iterate_batches: callable = iterate_batches,
|
||||
):
|
||||
all_losses = []
|
||||
all_losses = 0
|
||||
ntokens = 0
|
||||
|
||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||
@ -153,10 +161,14 @@ def evaluate(
|
||||
),
|
||||
):
|
||||
losses, toks = loss(model, *batch)
|
||||
all_losses.append((losses * toks).item())
|
||||
ntokens += toks.item()
|
||||
all_losses += losses * toks
|
||||
ntokens += toks
|
||||
mx.eval(all_losses, ntokens)
|
||||
|
||||
return np.sum(all_losses) / ntokens
|
||||
all_losses = mx.distributed.all_sum(all_losses)
|
||||
ntokens = mx.distributed.all_sum(ntokens)
|
||||
|
||||
return (all_losses / ntokens).item()
|
||||
|
||||
|
||||
class TrainingCallback:
|
||||
@ -182,6 +194,11 @@ def train(
|
||||
training_callback: TrainingCallback = None,
|
||||
):
|
||||
print(f"Starting training..., iters: {args.iters}")
|
||||
world = mx.distributed.init()
|
||||
world_size = world.size()
|
||||
rank = world.rank()
|
||||
if world_size > 1:
|
||||
print(f"Node {rank} of {world_size}")
|
||||
|
||||
if args.grad_checkpoint:
|
||||
grad_checkpoint(model.layers[0])
|
||||
@ -192,6 +209,9 @@ def train(
|
||||
# Forward and backward pass
|
||||
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
||||
|
||||
# All reduce the gradients if running in distributed mode
|
||||
grad = average_gradients(grad)
|
||||
|
||||
# Model update
|
||||
optimizer.update(model, grad)
|
||||
|
||||
@ -199,8 +219,9 @@ def train(
|
||||
|
||||
loss_value_and_grad = nn.value_and_grad(model, loss)
|
||||
|
||||
losses = []
|
||||
losses = 0
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
trained_tokens = 0
|
||||
# Main training loop
|
||||
start = time.perf_counter()
|
||||
@ -229,8 +250,12 @@ def train(
|
||||
iterate_batches=iterate_batches,
|
||||
)
|
||||
val_time = time.perf_counter() - stop
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
|
||||
f"Iter {it}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val took {val_time:.3f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
@ -244,29 +269,32 @@ def train(
|
||||
start = time.perf_counter()
|
||||
|
||||
lvalue, toks = step(batch)
|
||||
mx.eval(state, lvalue, toks)
|
||||
|
||||
# Record loss
|
||||
losses.append(lvalue.item())
|
||||
n_tokens += toks.item()
|
||||
losses += lvalue
|
||||
n_tokens += toks
|
||||
steps += 1
|
||||
mx.eval(state, losses, n_tokens)
|
||||
|
||||
# Report training loss if needed
|
||||
if it % args.steps_per_report == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
|
||||
train_loss = np.mean(losses)
|
||||
train_loss = mx.distributed.all_sum(losses).item()
|
||||
train_loss /= steps * mx.distributed.init().size()
|
||||
n_tokens = mx.distributed.all_sum(n_tokens).item()
|
||||
learning_rate = optimizer.learning_rate.item()
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
trained_tokens += n_tokens
|
||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: Train loss {train_loss:.3f}, "
|
||||
f"Learning Rate {learning_rate:.3e}, "
|
||||
f"It/sec {it_sec:.3f}, "
|
||||
f"Tokens/sec {tokens_sec:.3f}, "
|
||||
f"Trained Tokens {trained_tokens}, "
|
||||
f"Peak mem {peak_mem:.3f} GB"
|
||||
f"Peak mem {peak_mem:.3f} GB",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
@ -281,8 +309,9 @@ def train(
|
||||
}
|
||||
training_callback.on_train_loss_report(train_info)
|
||||
|
||||
losses = []
|
||||
losses = 0
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights
|
||||
|
@ -3,6 +3,7 @@
|
||||
import math
|
||||
import sys
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate
|
||||
from mlx_lm.tuner.utils import build_schedule
|
||||
|
||||
|
||||
@contextmanager
|
||||
def swapped_with_identity(obj, func):
|
||||
old_func = getattr(obj, func)
|
||||
setattr(obj, func, lambda x: x)
|
||||
yield
|
||||
setattr(obj, func, old_func)
|
||||
|
||||
|
||||
class TestLora(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.capturedOutput = StringIO()
|
||||
@ -374,6 +383,7 @@ class TestScheduleConfig(unittest.TestCase):
|
||||
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
|
||||
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
|
||||
]
|
||||
with swapped_with_identity(mx.distributed, "all_sum"):
|
||||
evaluate(
|
||||
model=mock_model,
|
||||
dataset=mock_dataset,
|
||||
@ -412,6 +422,7 @@ class TestScheduleConfig(unittest.TestCase):
|
||||
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
|
||||
]
|
||||
|
||||
with swapped_with_identity(mx.distributed, "all_sum"):
|
||||
evaluate(
|
||||
model=mock_model,
|
||||
dataset=mock_dataset,
|
||||
|
Loading…
Reference in New Issue
Block a user