Use concatenated all reduce and gather stats

commit e0f18d15aa
parent 4786b4e3eb
Author: Angelos Katharopoulos
Date:   2024-09-12 13:33:57 -07:00


@@ -10,6 +10,7 @@ from typing import Union
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten, tree_map
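This import replaces the trainer's local gradient-averaging helper (deleted in the next hunk) with mlx.nn.utils.average_gradients, which performs the concatenated all-reduce named in the commit title: many small gradient arrays share large communication buffers instead of paying one call each. A minimal sketch of how such a helper slots into a training step; the step signature and the (loss, num_tokens) return convention are assumptions modeled on this trainer, not part of the diff:

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.nn.utils import average_gradients

    def step(model, optimizer, loss_value_and_grad, batch):
        # loss_value_and_grad comes from nn.value_and_grad(model, loss) and
        # is assumed to return ((loss, num_tokens), gradients).
        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
        # Average the whole gradient tree across ranks at once; with a
        # single process this is a no-op, so the same code runs everywhere.
        grad = average_gradients(grad)
        optimizer.update(model, grad)
        return lvalue, toks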
@@ -29,17 +30,6 @@ def grad_checkpoint(layer):
     type(layer).__call__ = checkpointed_fn
 
 
-def average_gradients(gradients):
-    world_size = mx.distributed.init().size()
-    if world_size == 1:
-        return gradients
-
-    def _all_average(x):
-        return mx.distributed.all_sum(x) / world_size
-
-    return tree_map(_all_average, gradients)
-
-
 @dataclass
 class TrainingArgs:
     batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
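The deleted helper issued one all_sum per gradient array via tree_map, so every small tensor paid full communication latency. The concatenated variant flattens the tree into one buffer, reduces once, and splits the result back. A rough, self-contained illustration of that idea; concat_all_average is a hypothetical stand-in for roughly what mlx.nn.utils.average_gradients does internally, and it assumes all gradients share one dtype (the real implementation also chunks very large buffers):

    from itertools import accumulate

    import mlx.core as mx
    from mlx.utils import tree_flatten, tree_unflatten

    def concat_all_average(grads):
        world_size = mx.distributed.init().size()
        if world_size == 1:
            return grads
        flat = tree_flatten(grads)
        shapes = [v.shape for _, v in flat]
        sizes = [v.size for _, v in flat]
        # One big buffer -> one communication call instead of one per array.
        buf = mx.concatenate([v.reshape(-1) for _, v in flat])
        buf = mx.distributed.all_sum(buf) / world_size
        parts = mx.split(buf, list(accumulate(sizes))[:-1])
        return tree_unflatten(
            [(k, p.reshape(s)) for (k, _), p, s in zip(flat, parts, shapes)]
        )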
@@ -204,6 +194,11 @@ def train(
     training_callback: TrainingCallback = None,
 ):
     print(f"Starting training..., iters: {args.iters}")
+    world = mx.distributed.init()
+    world_size = world.size()
+    rank = world.rank()
+    if world_size > 1:
+        print(f"Node {rank} of {world_size}")
 
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])
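The distributed group is initialized once at the top of train() and its rank gates every print below. A minimal version of the pattern; rank0_print is a hypothetical convenience wrapper, whereas the diff simply inlines if rank == 0: at each call site:

    import mlx.core as mx

    world = mx.distributed.init()  # size-1 group when run without a launcher
    world_size, rank = world.size(), world.rank()

    if world_size > 1:
        print(f"Node {rank} of {world_size}")

    def rank0_print(*args, **kwargs):
        # Gate logging on rank 0 so multi-node output does not interleave.
        if rank == 0:
            print(*args, **kwargs)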
@@ -224,8 +219,9 @@ def train(
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
-    losses = []
+    losses = 0
     n_tokens = 0
+    steps = 0
     trained_tokens = 0
 
     # Main training loop
     start = time.perf_counter()
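Turning losses from a Python list into a running scalar is what lets the stats stay on device: losses.append(lvalue.item()) forced a device-to-host sync every iteration, while losses += lvalue stays lazy until the single mx.eval per step. A condensed sketch of the accumulation, with step, batch, and state assumed from the surrounding trainer:

    losses = 0   # becomes an mx.array after the first addition
    n_tokens = 0
    steps = 0

    lvalue, toks = step(batch)
    losses += lvalue                   # lazy accumulation, no per-step sync
    n_tokens += toks
    steps += 1
    mx.eval(state, losses, n_tokens)   # one evaluation point per iteration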
@@ -254,9 +250,12 @@ def train(
                 iterate_batches=iterate_batches,
             )
             val_time = time.perf_counter() - stop
-            print(
-                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: "
+                    f"Val loss {val_loss:.3f}, "
+                    f"Val took {val_time:.3f}s"
+                )
 
             if training_callback is not None:
                 val_info = {
@@ -269,30 +268,32 @@ def train(
             start = time.perf_counter()
 
         lvalue, toks = step(batch)
-        mx.eval(state, lvalue, toks)
-
-        # Record loss
-        losses.append(lvalue.item())
-        n_tokens += toks.item()
+        losses += lvalue
+        n_tokens += toks
+        steps += 1
+        mx.eval(state, losses, n_tokens)
 
         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
 
-            train_loss = np.mean(losses)
+            train_loss = mx.distributed.all_sum(losses).item()
+            train_loss /= steps * mx.distributed.init().size()
+            n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 2**30
-            print(
-                f"Iter {it}: Train loss {train_loss:.3f}, "
-                f"Learning Rate {learning_rate:.3e}, "
-                f"It/sec {it_sec:.3f}, "
-                f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}, "
-                f"Peak mem {peak_mem:.3f} GB"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: Train loss {train_loss:.3f}, "
+                    f"Learning Rate {learning_rate:.3e}, "
+                    f"It/sec {it_sec:.3f}, "
+                    f"Tokens/sec {tokens_sec:.3f}, "
+                    f"Trained Tokens {trained_tokens}, "
+                    f"Peak mem {peak_mem:.3f} GB"
+                )
 
             if training_callback is not None:
                 train_info = {
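At report time each rank holds a local sum of per-step losses and a local token count; all_sum turns both into global statistics. Summing the per-rank loss totals and dividing by steps * world_size recovers the mean per-step loss across all ranks, while token counts are extensive and simply sum. A standalone sketch of the same arithmetic, assuming losses and n_tokens are the MLX scalars accumulated above:

    import mlx.core as mx

    def gather_stats(losses, n_tokens, steps):
        world_size = mx.distributed.init().size()
        # Each rank contributes the sum of `steps` per-step mean losses.
        train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
        # Global token throughput is a plain sum over ranks.
        total_tokens = mx.distributed.all_sum(n_tokens).item()
        return train_loss, total_tokens

For example, with two ranks each reporting losses = 20.0 over steps = 10, the reported train loss is (20 + 20) / (10 * 2) = 2.0.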
@@ -306,8 +307,9 @@ def train(
             }
             training_callback.on_train_loss_report(train_info)
 
-            losses = []
+            losses = 0
             n_tokens = 0
+            steps = 0
             start = time.perf_counter()
 
         # Save adapter weights