Mirror of https://github.com/ml-explore/mlx-examples.git
Commit e0f18d15aa: Use concatenated all reduce and gather stats
Parent: 4786b4e3eb
@@ -10,6 +10,7 @@ from typing import Union
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten, tree_map
 
 
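The hunk above imports the library helper mlx.nn.utils.average_gradients, which replaces the local implementation removed below. The hunk that actually calls it inside the optimizer step is not part of this excerpt; the sketch below shows the usual wiring, with the step function body assumed rather than taken from this diff:

    def step(batch):
        # Compute the loss and per-parameter gradients for this rank's batch.
        (lvalue, toks), grad = loss_value_and_grad(model, *batch)

        # Average gradients across workers (a no-op for a single process).
        grad = average_gradients(grad)

        optimizer.update(model, grad)
        return lvalue, toks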
@@ -29,17 +30,6 @@ def grad_checkpoint(layer):
     type(layer).__call__ = checkpointed_fn
 
 
-def average_gradients(gradients):
-    world_size = mx.distributed.init().size()
-    if world_size == 1:
-        return gradients
-
-    def _all_average(x):
-        return mx.distributed.all_sum(x) / world_size
-
-    return tree_map(_all_average, gradients)
-
-
 @dataclass
 class TrainingArgs:
     batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
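For contrast with the per-tensor helper removed above, here is a minimal sketch of the "concatenated" all reduce named in the commit title: flatten every gradient into one buffer, reduce it with a single all_sum, and split it back. The helper name and its details are illustrative assumptions, not necessarily how mlx.nn.utils.average_gradients is implemented:

    import mlx.core as mx
    from mlx.utils import tree_flatten, tree_unflatten

    def concatenated_average(gradients):
        world_size = mx.distributed.init().size()
        if world_size == 1:
            return gradients

        # Assumes all gradients share one dtype.
        flat = tree_flatten(gradients)  # list of (path, array) pairs
        shapes = [g.shape for _, g in flat]
        sizes = [g.size for _, g in flat]

        # One communication call for the whole gradient tree instead of
        # one all_sum per tensor.
        buf = mx.concatenate([mx.flatten(g) for _, g in flat])
        buf = mx.distributed.all_sum(buf) / world_size

        # Split the reduced buffer back into the original tensor shapes.
        offsets = [sum(sizes[: i + 1]) for i in range(len(sizes) - 1)]
        pieces = mx.split(buf, offsets)
        return tree_unflatten(
            [(path, mx.reshape(p, shape))
             for (path, _), p, shape in zip(flat, pieces, shapes)]
        )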
@@ -204,6 +194,11 @@ def train(
     training_callback: TrainingCallback = None,
 ):
     print(f"Starting training..., iters: {args.iters}")
+    world = mx.distributed.init()
+    world_size = world.size()
+    rank = world.rank()
+    if world_size > 1:
+        print(f"Node {rank} of {world_size}")
 
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])
@@ -224,8 +219,9 @@ def train(
 
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
-    losses = []
+    losses = 0
     n_tokens = 0
+    steps = 0
     trained_tokens = 0
     # Main training loop
     start = time.perf_counter()
@@ -254,9 +250,12 @@ def train(
                 iterate_batches=iterate_batches,
             )
             val_time = time.perf_counter() - stop
-            print(
-                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: "
+                    f"Val loss {val_loss:.3f}, "
+                    f"Val took {val_time:.3f}s"
+                )
 
             if training_callback is not None:
                 val_info = {
@@ -269,30 +268,32 @@ def train(
             start = time.perf_counter()
 
         lvalue, toks = step(batch)
-        mx.eval(state, lvalue, toks)
-
-        # Record loss
-        losses.append(lvalue.item())
-        n_tokens += toks.item()
+        losses += lvalue
+        n_tokens += toks
+        steps += 1
+        mx.eval(state, losses, n_tokens)
 
         # Report training loss if needed
        if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
 
-            train_loss = np.mean(losses)
+            train_loss = mx.distributed.all_sum(losses).item()
+            train_loss /= steps * mx.distributed.init().size()
+            n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 2**30
-            print(
-                f"Iter {it}: Train loss {train_loss:.3f}, "
-                f"Learning Rate {learning_rate:.3e}, "
-                f"It/sec {it_sec:.3f}, "
-                f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}, "
-                f"Peak mem {peak_mem:.3f} GB"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: Train loss {train_loss:.3f}, "
+                    f"Learning Rate {learning_rate:.3e}, "
+                    f"It/sec {it_sec:.3f}, "
+                    f"Tokens/sec {tokens_sec:.3f}, "
+                    f"Trained Tokens {trained_tokens}, "
+                    f"Peak mem {peak_mem:.3f} GB"
+                )
 
             if training_callback is not None:
                 train_info = {
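The report block above is the "gather stats" half of the commit: each rank keeps lazy mx.array accumulators and the loss and token counts are only reduced across ranks when a report is due. A minimal standalone sketch of that arithmetic, assuming losses and n_tokens are mx.array scalars accumulated over steps optimizer steps:

    import mlx.core as mx

    def gather_report_stats(losses, n_tokens, steps):
        world_size = mx.distributed.init().size()
        # Sum the per-rank loss accumulators, then divide by the total
        # number of steps taken across all ranks for the mean step loss.
        train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
        # Token counts are simply summed over ranks.
        total_tokens = mx.distributed.all_sum(n_tokens).item()
        return train_loss, total_tokens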
@@ -306,8 +307,9 @@ def train(
                 }
                 training_callback.on_train_loss_report(train_info)
 
-            losses = []
+            losses = 0
             n_tokens = 0
+            steps = 0
             start = time.perf_counter()
 
         # Save adapter weights