diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 31f845b2..397dd07f 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -10,6 +10,7 @@ from typing import Union
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten, tree_map
 
 
@@ -29,17 +30,6 @@ def grad_checkpoint(layer):
     type(layer).__call__ = checkpointed_fn
 
 
-def average_gradients(gradients):
-    world_size = mx.distributed.init().size()
-    if world_size == 1:
-        return gradients
-
-    def _all_average(x):
-        return mx.distributed.all_sum(x) / world_size
-
-    return tree_map(_all_average, gradients)
-
-
 @dataclass
 class TrainingArgs:
     batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
@@ -204,6 +194,11 @@ def train(
     training_callback: TrainingCallback = None,
 ):
     print(f"Starting training..., iters: {args.iters}")
+    world = mx.distributed.init()
+    world_size = world.size()
+    rank = world.rank()
+    if world_size > 1:
+        print(f"Node {rank} of {world_size}")
 
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])
@@ -224,8 +219,9 @@ def train(
 
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
-    losses = []
+    losses = 0
     n_tokens = 0
+    steps = 0
     trained_tokens = 0
     # Main training loop
     start = time.perf_counter()
@@ -254,9 +250,12 @@ def train(
                 iterate_batches=iterate_batches,
             )
             val_time = time.perf_counter() - stop
-            print(
-                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: "
+                    f"Val loss {val_loss:.3f}, "
+                    f"Val took {val_time:.3f}s"
+                )
 
             if training_callback is not None:
                 val_info = {
@@ -269,30 +268,32 @@ def train(
             start = time.perf_counter()
 
         lvalue, toks = step(batch)
-        mx.eval(state, lvalue, toks)
-
-        # Record loss
-        losses.append(lvalue.item())
-        n_tokens += toks.item()
+        losses += lvalue
+        n_tokens += toks
+        steps += 1
+        mx.eval(state, losses, n_tokens)
 
         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
 
-            train_loss = np.mean(losses)
+            train_loss = mx.distributed.all_sum(losses).item()
+            train_loss /= steps * mx.distributed.init().size()
+            n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 2**30
-            print(
-                f"Iter {it}: Train loss {train_loss:.3f}, "
-                f"Learning Rate {learning_rate:.3e}, "
-                f"It/sec {it_sec:.3f}, "
-                f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}, "
-                f"Peak mem {peak_mem:.3f} GB"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: Train loss {train_loss:.3f}, "
+                    f"Learning Rate {learning_rate:.3e}, "
+                    f"It/sec {it_sec:.3f}, "
+                    f"Tokens/sec {tokens_sec:.3f}, "
+                    f"Trained Tokens {trained_tokens}, "
+                    f"Peak mem {peak_mem:.3f} GB"
+                )
 
             if training_callback is not None:
                 train_info = {
@@ -306,8 +307,9 @@ def train(
                 }
                 training_callback.on_train_loss_report(train_info)
 
-            losses = []
+            losses = 0
             n_tokens = 0
+            steps = 0
             start = time.perf_counter()
 
         # Save adapter weights
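
For context on the new loss reporting: each rank now accumulates its local loss sum in `losses` (as an `mx.array`, avoiding a host sync via `.item()` every iteration) and counts optimizer steps in `steps`; at report time, `mx.distributed.all_sum` adds the local sums across ranks, and dividing by `steps * world_size` gives the global mean loss per step. Below is a minimal, runnable sketch of that aggregation; the stand-in loss values are hypothetical, and on a single process `all_sum` reduces over one rank, so it degenerates to a plain mean:

```python
# Sketch of the distributed loss aggregation used in the diff.
# On one rank this is just a mean; on N ranks it yields the
# global mean loss per optimizer step.
import mlx.core as mx

world = mx.distributed.init()

# Accumulate per-step losses as an mx.array (as the training loop
# does), rather than calling .item() on every iteration.
losses = mx.array(0.0)
steps = 0
for step_loss in [2.0, 1.5, 1.0]:  # hypothetical stand-ins for lvalue
    losses += step_loss
    steps += 1

# All-reduce the local sums, then normalize by total steps across ranks.
train_loss = mx.distributed.all_sum(losses).item() / (steps * world.size())
print(f"Train loss {train_loss:.3f}")  # 1.500 on a single rank
```

With a distributed backend available, a multi-rank run would typically be launched via MPI (e.g. `mpirun -np 2 python script.py`), which `mx.distributed.init()` picks up; in a single-process run it returns a group of size 1 and the reductions are no-ops.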