diff --git a/cifar/README.md b/cifar/README.md
index 763e641d..2016200d 100644
--- a/cifar/README.md
+++ b/cifar/README.md
@@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
 
 At the time of writing, `mlx` doesn't have built-in learning rate schedules.
 We intend to update this example once these features are added.
+
+## Distributed training
+
+The example also supports distributed data parallel training. You can launch
+distributed training as follows:
+
+```shell
+$ cat >hostfile.json
+[
+    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
+    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
+]
+$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
+```
diff --git a/cifar/main.py b/cifar/main.py
index 7eb6efdf..3fe5d2e0 100644
@@ -122,7 +122,7 @@ def main(args):
 
     # Initialize the distributed group and report the nodes that showed up
     world = mx.distributed.init()
     if world.size() > 1:
-        print(f"{world.rank()} of {world.size()}", flush=True)
+        print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
 
     model = getattr(resnet, args.arch)()
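
For readers unfamiliar with what "distributed data parallel training" involves on each rank, below is a minimal sketch of the gradient-averaging step such training relies on, using the same `mx.distributed` group initialized in `main.py` above. The helper `all_reduce_grads` is illustrative and not part of this PR; it assumes gradients come as the pytree returned by `mlx.nn.value_and_grad`.

```python
# Sketch only: average gradients across ranks so every node applies the
# same update. `all_reduce_grads` is a hypothetical helper for illustration.
import mlx.core as mx
from mlx.utils import tree_map

world = mx.distributed.init()

def all_reduce_grads(grads):
    # With a single process there is nothing to synchronize.
    if world.size() == 1:
        return grads
    # Sum each gradient array across all ranks, then divide by the
    # number of ranks to get the mean gradient.
    return tree_map(lambda g: mx.distributed.all_sum(g) / world.size(), grads)
```

In a training step this would be called between computing the gradients and `optimizer.update(model, grads)`, which is why each rank can keep its own shard of the data while the model weights stay identical everywhere.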