mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Use concatenated all reduce and gather stats
This commit is contained in:
parent
4786b4e3eb
commit
e0f18d15aa
@ -10,6 +10,7 @@ from typing import Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten, tree_map
|
||||
|
||||
|
||||
@ -29,17 +30,6 @@ def grad_checkpoint(layer):
|
||||
type(layer).__call__ = checkpointed_fn
|
||||
|
||||
|
||||
def average_gradients(gradients):
|
||||
world_size = mx.distributed.init().size()
|
||||
if world_size == 1:
|
||||
return gradients
|
||||
|
||||
def _all_average(x):
|
||||
return mx.distributed.all_sum(x) / world_size
|
||||
|
||||
return tree_map(_all_average, gradients)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArgs:
|
||||
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
|
||||
@ -204,6 +194,11 @@ def train(
|
||||
training_callback: TrainingCallback = None,
|
||||
):
|
||||
print(f"Starting training..., iters: {args.iters}")
|
||||
world = mx.distributed.init()
|
||||
world_size = world.size()
|
||||
rank = world.rank()
|
||||
if world_size > 1:
|
||||
print(f"Node {rank} of {world_size}")
|
||||
|
||||
if args.grad_checkpoint:
|
||||
grad_checkpoint(model.layers[0])
|
||||
@ -224,8 +219,9 @@ def train(
|
||||
|
||||
loss_value_and_grad = nn.value_and_grad(model, loss)
|
||||
|
||||
losses = []
|
||||
losses = 0
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
trained_tokens = 0
|
||||
# Main training loop
|
||||
start = time.perf_counter()
|
||||
@ -254,8 +250,11 @@ def train(
|
||||
iterate_batches=iterate_batches,
|
||||
)
|
||||
val_time = time.perf_counter() - stop
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
|
||||
f"Iter {it}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val took {val_time:.3f}s"
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
@ -269,22 +268,24 @@ def train(
|
||||
start = time.perf_counter()
|
||||
|
||||
lvalue, toks = step(batch)
|
||||
mx.eval(state, lvalue, toks)
|
||||
|
||||
# Record loss
|
||||
losses.append(lvalue.item())
|
||||
n_tokens += toks.item()
|
||||
losses += lvalue
|
||||
n_tokens += toks
|
||||
steps += 1
|
||||
mx.eval(state, losses, n_tokens)
|
||||
|
||||
# Report training loss if needed
|
||||
if it % args.steps_per_report == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
|
||||
train_loss = np.mean(losses)
|
||||
train_loss = mx.distributed.all_sum(losses).item()
|
||||
train_loss /= steps * mx.distributed.init().size()
|
||||
n_tokens = mx.distributed.all_sum(n_tokens).item()
|
||||
learning_rate = optimizer.learning_rate.item()
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
trained_tokens += n_tokens
|
||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: Train loss {train_loss:.3f}, "
|
||||
f"Learning Rate {learning_rate:.3e}, "
|
||||
@ -306,8 +307,9 @@ def train(
|
||||
}
|
||||
training_callback.on_train_loss_report(train_info)
|
||||
|
||||
losses = []
|
||||
losses = 0
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights
|
||||
|
Loading…
Reference in New Issue
Block a user