mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 23:49:43 +08:00
Add it to the readme and fix the rank printing in main
This commit is contained in:
parent
14faec4ca2
commit
d20413a54d
@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
|
||||
|
||||
At the time of writing, `mlx` doesn't have built-in learning rate schedules.
|
||||
We intend to update this example once these features are added.
|
||||
|
||||
## Distributed training
|
||||
|
||||
The example also supports distributed data parallel training. You can launch a
|
||||
distributed training as follows:
|
||||
|
||||
```shell
|
||||
$ cat >hostfile.json
|
||||
[
|
||||
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
|
||||
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
|
||||
]
|
||||
$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
|
||||
```
|
||||
|
@ -122,7 +122,7 @@ def main(args):
|
||||
# Initialize the distributed group and report the nodes that showed up
|
||||
world = mx.distributed.init()
|
||||
if world.size() > 1:
|
||||
print(f"{world.rank()} of {world.size()}", flush=True)
|
||||
print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
|
||||
|
||||
model = getattr(resnet, args.arch)()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user