Add it to the readme and fix the rank printing in main

Angelos Katharopoulos 2025-02-25 17:40:24 -08:00
parent 14faec4ca2
commit d20413a54d
2 changed files with 15 additions and 1 deletion

README.md

@@ -48,3 +48,17 @@ Note this was run on an M1 MacBook Pro with 16GB RAM.
 At the time of writing, `mlx` doesn't have built-in learning rate schedules.
 We intend to update this example once these features are added.
+
+## Distributed training
+
+The example also supports distributed data parallel training. You can launch
+distributed training as follows:
+
+```shell
+$ cat >hostfile.json
+[
+    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
+    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
+]
+$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
+```
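
The hostfile is a JSON list of hosts; `mlx.launch` starts the script on each entry and wires the processes into one group. As a quick sanity check of that launch path, here is a minimal hypothetical script (the name `probe.py` and the script itself are illustrative, not part of this commit; only `mx.distributed.init()`, `rank()`, and `size()` come from the example's own code):

```python
# probe.py -- hypothetical sanity-check script, launched like main.py:
#   mlx.launch --verbose --hostfile hostfile.json probe.py
import mlx.core as mx

# Join the group that mlx.launch set up (size 1 when run directly).
world = mx.distributed.init()
print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
```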

main.py

@@ -122,7 +122,7 @@ def main(args):
     # Initialize the distributed group and report the nodes that showed up
     world = mx.distributed.init()
     if world.size() > 1:
-        print(f"{world.rank()} of {world.size()}", flush=True)
+        print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
     model = getattr(resnet, args.arch)()
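
The changed line only improves logging, but the group object it prints from is what drives the data parallelism. A hedged sketch of how such a group is typically used to average per-rank gradients (the `all_average` helper and the stand-in tensor are assumptions for illustration, not code from this repo):

```python
import mlx.core as mx

world = mx.distributed.init()

def all_average(x):
    # Sum x across every rank in the group, then divide by the
    # group size; with a single process this reduces to a no-op.
    return mx.distributed.all_sum(x) / world.size()

# Stand-in for the gradients a rank would compute on its batch shard.
local_grad = mx.array([float(world.rank() + 1)])
print(f"rank {world.rank()}: averaged = {all_average(local_grad)}")
```

In data-parallel training, each rank computes gradients on its own slice of the batch and averages them this way before every optimizer update, which keeps the model replicas in sync.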