From e7267d30f83bc3f22ff6f0f8132ca0bcd9c38115 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 5 Mar 2025 13:33:15 -0800 Subject: [PATCH] Distributed support cifar (#1301) --- cifar/README.md | 14 ++++++++ cifar/dataset.py | 11 +++++- cifar/main.py | 91 +++++++++++++++++++++++++++++++----------------- 3 files changed, 84 insertions(+), 32 deletions(-) diff --git a/cifar/README.md b/cifar/README.md index 763e641d..2016200d 100644 --- a/cifar/README.md +++ b/cifar/README.md @@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM. At the time of writing, `mlx` doesn't have built-in learning rate schedules. We intend to update this example once these features are added. + +## Distributed training + +The example also supports distributed data parallel training. You can launch a +distributed training as follows: + +```shell +$ cat >hostfile.json +[ + {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}, + {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]} +] +$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20 +``` diff --git a/cifar/dataset.py b/cifar/dataset.py index 22b229f8..8967591e 100644 --- a/cifar/dataset.py +++ b/cifar/dataset.py @@ -1,3 +1,4 @@ +import mlx.core as mx import numpy as np from mlx.data.datasets import load_cifar10 @@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None): x = x.astype("float32") / 255.0 return (x - mean) / std + group = mx.distributed.init() + tr_iter = ( tr.shuffle() + .partition_if(group.size() > 1, group.size(), group.rank()) .to_stream() .image_random_h_flip("image", prob=0.5) .pad("image", 0, 4, 4, 0.0) @@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None): ) test = load_cifar10(root=root, train=False) - test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size) + test_iter = ( + test.to_stream() + .partition_if(group.size() > 1, group.size(), group.rank()) + .key_transform("image", normalize) + .batch(batch_size) + ) return tr_iter, test_iter diff --git a/cifar/main.py b/cifar/main.py index 378bc424..ac010636 100644 --- a/cifar/main.py +++ b/cifar/main.py @@ -23,6 +23,13 @@ parser.add_argument("--seed", type=int, default=0, help="random seed") parser.add_argument("--cpu", action="store_true", help="use cpu only") +def print_zero(group, *args, **kwargs): + if group.rank() != 0: + return + flush = kwargs.pop("flush", True) + print(*args, **kwargs, flush=flush) + + def eval_fn(model, inp, tgt): return mx.mean(mx.argmax(model(inp), axis=1) == tgt) @@ -34,9 +41,20 @@ def train_epoch(model, train_iter, optimizer, epoch): acc = mx.mean(mx.argmax(output, axis=1) == tgt) return loss, acc - losses = [] - accs = [] - samples_per_sec = [] + world = mx.distributed.init() + losses = 0 + accuracies = 0 + samples_per_sec = 0 + count = 0 + + def average_stats(stats, count): + if world.size() == 1: + return [s / count for s in stats] + + with mx.stream(mx.cpu): + stats = mx.distributed.all_sum(mx.array(stats)) + count = mx.distributed.all_sum(count) + return (stats / count).tolist() state = [model.state, optimizer.state] @@ -44,6 +62,7 @@ def train_epoch(model, train_iter, optimizer, epoch): def step(inp, tgt): train_step_fn = nn.value_and_grad(model, train_step) (loss, acc), grads = train_step_fn(model, inp, tgt) + grads = nn.utils.average_gradients(grads) optimizer.update(model, grads) return loss, acc @@ -52,69 +71,79 @@ def train_epoch(model, train_iter, optimizer, epoch): y = mx.array(batch["label"]) tic = time.perf_counter() loss, acc = step(x, y) - mx.eval(state) + mx.eval(loss, acc, state) toc = time.perf_counter() - loss = loss.item() - acc = acc.item() - losses.append(loss) - accs.append(acc) - throughput = x.shape[0] / (toc - tic) - samples_per_sec.append(throughput) + losses += loss.item() + accuracies += acc.item() + samples_per_sec += x.shape[0] / (toc - tic) + count += 1 if batch_counter % 10 == 0: - print( + l, a, s = average_stats( + [losses, accuracies, world.size() * samples_per_sec], + count, + ) + print_zero( + world, " | ".join( ( f"Epoch {epoch:02d} [{batch_counter:03d}]", - f"Train loss {loss:.3f}", - f"Train acc {acc:.3f}", - f"Throughput: {throughput:.2f} images/second", + f"Train loss {l:.3f}", + f"Train acc {a:.3f}", + f"Throughput: {s:.2f} images/second", ) - ) + ), ) - mean_tr_loss = mx.mean(mx.array(losses)) - mean_tr_acc = mx.mean(mx.array(accs)) - samples_per_sec = mx.mean(mx.array(samples_per_sec)) - return mean_tr_loss, mean_tr_acc, samples_per_sec + return average_stats([losses, accuracies, world.size() * samples_per_sec], count) def test_epoch(model, test_iter, epoch): - accs = [] + accuracies = 0 + count = 0 for batch_counter, batch in enumerate(test_iter): x = mx.array(batch["image"]) y = mx.array(batch["label"]) acc = eval_fn(model, x, y) - acc_value = acc.item() - accs.append(acc_value) - mean_acc = mx.mean(mx.array(accs)) - return mean_acc + accuracies += acc.item() + count += 1 + + with mx.stream(mx.cpu): + accuracies = mx.distributed.all_sum(accuracies) + count = mx.distributed.all_sum(count) + return (accuracies / count).item() def main(args): mx.random.seed(args.seed) + # Initialize the distributed group and report the nodes that showed up + world = mx.distributed.init() + if world.size() > 1: + print(f"Starting rank {world.rank()} of {world.size()}", flush=True) + model = getattr(resnet, args.arch)() - print("Number of params: {:0.04f} M".format(model.num_params() / 1e6)) + print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M") optimizer = optim.Adam(learning_rate=args.lr) train_data, test_data = get_cifar10(args.batch_size) for epoch in range(args.epochs): tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch) - print( + print_zero( + world, " | ".join( ( f"Epoch: {epoch}", - f"avg. Train loss {tr_loss.item():.3f}", - f"avg. Train acc {tr_acc.item():.3f}", - f"Throughput: {throughput.item():.2f} images/sec", + f"avg. Train loss {tr_loss:.3f}", + f"avg. Train acc {tr_acc:.3f}", + f"Throughput: {throughput:.2f} images/sec", ) - ) + ), ) test_acc = test_epoch(model, test_data, epoch) - print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}") + print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}") train_data.reset() test_data.reset()