mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 15:41:01 +08:00
Add it to the readme and fix the rank printing in main
This commit is contained in:
parent
14faec4ca2
commit
d20413a54d
@@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
|
|||||||
|
|
||||||
At the time of writing, `mlx` doesn't have built-in learning rate schedules.
We intend to update this example once these features are added.

## Distributed training

The example also supports distributed data parallel training. You can launch a
distributed training as follows:

```shell
$ cat >hostfile.json
[
    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
]
$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
```
||||||
|
@@ -122,7 +122,7 @@ def main(args):
     # Initialize the distributed group and report the nodes that showed up
     world = mx.distributed.init()
     if world.size() > 1:
-        print(f"{world.rank()} of {world.size()}", flush=True)
+        print(f"Starting rank {world.rank()} of {world.size()}", flush=True)

     model = getattr(resnet, args.arch)()
|
Loading…
Reference in New Issue
Block a user