Finish the distributed docs

This commit is contained in:
Angelos Katharopoulos
2025-12-12 14:16:16 -08:00
parent 2f939acefa
commit 5523087cfb


@@ -117,6 +117,8 @@ The following examples aim to clarify the backend initialization logic in MLX:
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init()  # same as MPI because it was initialized first!
.. _training_example:
Training Example
----------------
@@ -289,7 +291,7 @@ Enabling RDMA
^^^^^^^^^^^^^
Until the feature matures, enabling RDMA over thunderbolt is slightly more
involved and **cannot** be done remotely even with sudo. In fact, it has to be
done in macOS recovery:

1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.
@@ -316,8 +318,8 @@ Defining a Mesh
^^^^^^^^^^^^^^^
The JACCL backend supports only fully connected topologies. Namely, there needs
to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
the following topology visualizations, the left one is valid because there is a
connection from any node to any other node, while for the one on the right, M3
Ultra 1 is not connected to M3 Ultra 2.
@@ -372,7 +374,7 @@ Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
disabling the thunderbolt bridge is still required, as is setting up
isolated local networks for each thunderbolt connection.
All of the above can be done instead via ``mlx.distributed_config``. This helper
script will

- ssh into each node
@@ -384,13 +386,13 @@ script will
Putting it All Together
^^^^^^^^^^^^^^^^^^^^^^^
For example, launching a distributed MLX script that uses JACCL is fairly
simple if the nodes are reachable via ssh and have password-less sudo.
First, connect all the thunderbolt cables. Then we can verify the connections
by using the ``mlx.distributed_config`` script to visualize them.
.. code-block::

  mlx.distributed_config --verbose \
    --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
@@ -399,7 +401,7 @@ by using the ``mlx.distributed_config`` script to visualize them.
After making sure that everything looks right, we can auto-configure the nodes
and save the hostfile to ``m3-ultra-jaccl.json`` by running:
.. code-block::

  mlx.distributed_config --verbose \
    --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
@@ -409,7 +411,7 @@ and save the hostfile to ``m3-ultra-jaccl.json`` by running:
And now we are ready to run a distributed MLX script such as distributed inference
of a gigantic model using MLX-LM.
.. code-block::

  mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
    --env MLX_METAL_FAST_SYNCH=1 -- \  # <--- important
@@ -428,6 +430,32 @@ of a gigantic model using MLX-LM.
Getting Started with NCCL
-------------------------
MLX on CUDA environments ships with the ability to talk to `NCCL
<https://developer.nvidia.com/nccl>`_, a high-performance collective
communication library that supports both multi-GPU and multi-node setups.

In these environments, NCCL is the default backend for ``mlx.launch``, and all
it takes to run a distributed job is
.. code-block::

  mlx.launch -n 8 test.py

  # perfect for interactive scripts
  mlx.launch -n 8 python -m mlx_lm chat --model my-model --shard
You can also use ``mlx.launch`` to ssh to a remote node and launch a script
with the same ease
.. code-block::

  mlx.launch --hosts my-cuda-node -n 8 test.py
In many cases you may not want to use ``mlx.launch`` with the NCCL backend
because the cluster scheduler will be the one launching the processes. You can
:ref:`see which environment variables need to be defined <no_mlx_launch>` in
order for the MLX NCCL backend to be initialized correctly.
.. _mpi_section:
Getting Started with MPI
@@ -507,9 +535,116 @@ Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
.. _no_mlx_launch:
Distributed Without ``mlx.launch``
----------------------------------
None of the distributed backend implementations require launching with
``mlx.launch``. The launcher simply connects to each host, starts a process per
rank, and sets up the necessary environment variables before delegating to your
MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
for more details.
For many use cases, this will be the easiest way to perform distributed
computations in MLX. However, there may be reasons that you cannot or should
not use ``mlx.launch``. A common such case is a cluster scheduler that starts
all the processes for you on machines that are not known when the job is
submitted.
Below we list the environment variables required to use each backend.
Ring
^^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_HOSTFILE** should contain the path to a JSON file with the IPs and
ports that each rank listens on, something like the following:
.. code-block:: json

  [
    ["123.123.1.1:5000", "123.123.1.2:5000"],
    ["123.123.2.1:5000", "123.123.2.2:5000"],
    ["123.123.3.1:5000", "123.123.3.2:5000"],
    ["123.123.4.1:5000", "123.123.4.2:5000"]
  ]
**MLX_RING_VERBOSE** is optional; if set to 1, it enables additional logging
from the distributed backend.
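As a minimal sketch of the above, assume the four-rank hostfile is saved as
``hostfile.json`` and the program is ``test.py`` (both names are hypothetical).
Each rank could then be started by hand on its respective machine:

.. code-block::

  # on the first machine (rank 0)
  MLX_RANK=0 MLX_HOSTFILE=hostfile.json python test.py

  # on the second machine (rank 1), and similarly for ranks 2 and 3
  MLX_RANK=1 MLX_HOSTFILE=hostfile.json python test.py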
JACCL
^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 listens
on and that all the other ranks connect to in order to establish the RDMA
connections.
**MLX_IBV_DEVICES** should contain the path to a JSON file that contains the
ibverbs device names that connect each node to each other node, something like
the following:
.. code-block:: json

  [
    [null, "rdma_en5", "rdma_en4", "rdma_en3"],
    ["rdma_en5", null, "rdma_en3", "rdma_en4"],
    ["rdma_en4", "rdma_en3", null, "rdma_en5"],
    ["rdma_en3", "rdma_en4", "rdma_en5", null]
  ]
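Putting these together, a minimal sketch for launching rank 0 of the four-node
mesh above by hand might look as follows; the coordinator address, device file
name, and script name are hypothetical:

.. code-block::

  # rank 0 listens on the coordinator address; ranks 1-3 run the same command
  # with their own MLX_RANK and connect to it
  MLX_RANK=0 \
  MLX_JACCL_COORDINATOR=192.168.1.1:7777 \
  MLX_IBV_DEVICES=ibv_devices.json \
  python test.py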
NCCL
^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_WORLD_SIZE** should contain the total number of processes that will be
launched.
**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
hosts connect to in order to establish the NCCL communication.
**CUDA_VISIBLE_DEVICES** should contain the local index of the GPU that
corresponds to this process.
Of course, any `other environment variable
<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ used by
NCCL can also be set.
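For instance, a scheduler or a small wrapper script could export these
variables for each process before invoking the MLX program. A minimal sketch
for local rank 3 of an 8-process job on a single 8-GPU node; the IP, port, and
script name are hypothetical:

.. code-block::

  export MLX_RANK=3
  export MLX_WORLD_SIZE=8
  export NCCL_HOST_IP=192.168.1.10
  export NCCL_PORT=12345
  export CUDA_VISIBLE_DEVICES=3   # this process uses the fourth GPU
  python test.py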
.. _tips_and_tricks:
Tips and Tricks
----------------
This is a small collection of tips to help you make better use of the
distributed communication capabilities of MLX.
- *Test locally first.*
You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a
small-scale test on a single node first.
- *Batch your communication.*
As described in the :ref:`training example <training_example>`, performing
many small communications can hurt performance. Copy the approach of
:func:`mlx.nn.average_gradients` to gather many small communications into a
single large one.
- *Visualize the connectivity.*
Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
visualize the connections and make sure that the cables are connected
correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
- *Use the debugger.*
``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
processes and gathers stdout from all of them, which makes using ``pdb`` a
breeze; a small sketch combining this with a local test run follows this list.
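As a small illustration of the first and last tips, the following command
(with a hypothetical script name) runs two processes locally; if
``my_script.py`` contains a ``breakpoint()`` call, the ``pdb`` prompt can be
driven from the shared stdin:

.. code-block::

  # two local processes; a breakpoint() in the script drops into pdb
  mlx.launch -n 2 -- my_script.py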