mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
2f939acefa
...
ibv-backen
| Author | SHA1 | Date | |
|---|---|---|---|
| | d2bc340df4 | | |
| | fabc947df4 | | |
| | 5523087cfb | | |
@@ -117,6 +117,8 @@ The following examples aim to clarify the backend initialization logic in MLX:
|
||||
world_ring = mx.distributed.init(backend="ring")
|
||||
world_any = mx.distributed.init() # same as MPI because it was initialized first!
|
||||
|
||||
.. _training_example:
|
||||
|
||||
Training Example
|
||||
----------------
|
||||
|
||||
@@ -289,7 +291,7 @@ Enabling RDMA
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Until the feature matures, enabling RDMA over thunderbolt is slightly more
|
||||
involved and **cannot** be done remotely even with sudo. In fact it has to be
|
||||
involved and **cannot** be done remotely even with sudo. In fact, it has to be
|
||||
done in macOS recovery:
|
||||
|
||||
1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.
|
||||
@@ -316,8 +318,8 @@ Defining a Mesh
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
The JACCL backend supports only fully connected topologies. Namely, there needs
|
||||
to be a thunderbolt cable connecting all pairs of Macs directly. For example in
|
||||
the following topology visualizations the left one is valid because there is a
|
||||
to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
|
||||
the following topology visualizations, the left one is valid because there is a
|
||||
connection from any node to any other node, while for the one on the right M3
|
||||
Ultra 1 is not connected to M3 Ultra 2.
|
||||
|
||||
@@ -372,7 +374,7 @@ Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
|
||||
disabling the thunderbolt bridge is still required as well as setting up
|
||||
isolated local networks for each thunderbolt connection.
|
||||
|
||||
All of the above can be done instead via ``mlx.distributed_config``. The helper
|
||||
All of the above can be done instead via ``mlx.distributed_config``. This helper
|
||||
script will
|
||||
|
||||
- ssh into each node
|
||||
@@ -382,15 +384,15 @@ script will
|
||||
- generate the hostfile to be used with ``mlx.launch``
|
||||
|
||||
Putting it All Together
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
For example to launch a distributed MLX script that uses JACCL is fairly simple
|
||||
For example, launching a distributed MLX script that uses JACCL is fairly simple
|
||||
if the nodes are reachable via ssh and have password-less sudo.
|
||||
|
||||
First, connect all the thunderbolt cables. Then we can verify the connections
|
||||
by using the ``mlx.distributed_config`` script to visualize them.
|
||||
|
||||
.. code-block:: bash
|
||||
.. code-block::
|
||||
|
||||
mlx.distributed_config --verbose \
|
||||
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
|
||||
@@ -399,7 +401,7 @@ by using the ``mlx.distributed_config`` script to visualize them.
|
||||
After making sure that everything looks right we can auto-configure the nodes
|
||||
and save the hostfile to ``m3-ultra-jaccl.json`` by running:
|
||||
|
||||
.. code-block:: bash
|
||||
.. code-block::
|
||||
|
||||
mlx.distributed_config --verbose \
|
||||
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
|
||||
@@ -409,7 +411,7 @@ and save the hostfile to ``m3-ultra-jaccl.json`` by running:
|
||||
And now we are ready to run a distributed MLX script such as distributed inference
|
||||
of a gigantic model using MLX-LM.
|
||||
|
||||
.. code-block:: bash
|
||||
.. code-block::
|
||||
|
||||
mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
|
||||
--env MLX_METAL_FAST_SYNCH=1 -- \ # <--- important
|
||||
@@ -428,6 +430,32 @@ of a gigantic model using MLX-LM.
|
||||
Getting Started with NCCL
|
||||
-------------------------
|
||||
|
||||
MLX on CUDA environments ships with the ability to talk to `NCCL
|
||||
<https://developer.nvidia.com/nccl>`_ which is a high-performance collective
|
||||
communication library that supports both multi-GPU and multi-node setups.
|
||||
|
||||
For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all
|
||||
it takes to run a distributed job is
|
||||
|
||||
.. code-block::
|
||||
|
||||
mlx.launch -n 8 test.py
|
||||
|
||||
# perfect for interactive scripts
|
||||
mlx.launch -n 8 python -m mlx_lm chat --model my-model --shard
|
||||
|
||||
You can also use ``mlx.launch`` to ssh to a remote node and launch a script
|
||||
with the same ease
|
||||
|
||||
.. code-block::
|
||||
|
||||
mlx.launch --hosts my-cuda-node -n 8 test.py
|
||||
|
||||
In many cases you may not want to use ``mlx.launch`` with the NCCL backend
|
||||
because the cluster scheduler will be the one launching the processes. You can
|
||||
:ref:`see which environment variables need to be defined <no_mlx_launch>` in
|
||||
order for the MLX NCCL backend to be initialized correctly.
|
||||
|
||||
.. _mpi_section:
|
||||
|
||||
Getting Started with MPI
|
||||
@@ -507,9 +535,116 @@ Force MPI to use the most performant network interface by setting ``--mca
|
||||
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
|
||||
to use.
|
||||
|
||||
.. _no_mlx_launch:
|
||||
|
||||
Distributed Without ``mlx.launch``
|
||||
----------------------------------
|
||||
|
||||
None of the implementations of the distributed backends require launching with
|
||||
``mlx.launch``. The script simply connects to each host, starts a process per
|
||||
rank and sets up the necessary environment variables before delegating to your
|
||||
MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
|
||||
for more details.
|
||||
|
||||
Using the helper scripts
|
||||
-------------------------
|
||||
For many use-cases this will be the easiest way to perform distributed
|
||||
computations in MLX. However, there may be reasons that you cannot or should
|
||||
not use ``mlx.launch``. A common such case is the use of a scheduler that
|
||||
starts all the processes for you on machines undetermined at the time of
|
||||
scheduling the job.
|
||||
|
||||
Below we list the environment variables required to use each backend.
|
||||
|
||||
Ring
|
||||
^^^^^^
|
||||
|
||||
**MLX_RANK** should contain a single 0-based integer that defines the rank of
|
||||
the process.
|
||||
|
||||
**MLX_HOSTFILE** should contain the path to a json file that contains IPs and
|
||||
ports for each rank to listen on, something like the following:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
[
|
||||
["123.123.1.1:5000", "123.123.1.2:5000"],
|
||||
["123.123.2.1:5000", "123.123.2.2:5000"],
|
||||
["123.123.3.1:5000", "123.123.3.2:5000"],
|
||||
["123.123.4.1:5000", "123.123.4.2:5000"]
|
||||
]
|
||||
|
||||
**MLX_RING_VERBOSE** is optional and if set to 1 it enables some more logging
|
||||
from the distributed backend.
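When a scheduler rather than ``mlx.launch`` starts the processes, a small wrapper can prepare the variables above. The following is a minimal sketch and not part of MLX: the helper names ``write_ring_hostfile`` and ``ring_env`` are hypothetical, and the addresses are the placeholder values from the example hostfile.

```python
import json
import os
import tempfile


def write_ring_hostfile(hosts, path):
    """Write one list of "ip:port" strings per rank, as in the example above."""
    with open(path, "w") as f:
        json.dump(hosts, f, indent=2)
    return path


def ring_env(rank, hostfile, verbose=False):
    """Return the variables described above; environment values are strings."""
    env = {"MLX_RANK": str(rank), "MLX_HOSTFILE": hostfile}
    if verbose:
        env["MLX_RING_VERBOSE"] = "1"
    return env


# Placeholder addresses for a two-rank ring (assumption, not a real cluster).
hosts = [
    ["123.123.1.1:5000", "123.123.1.2:5000"],
    ["123.123.2.1:5000", "123.123.2.2:5000"],
]
hostfile = write_ring_hostfile(hosts, os.path.join(tempfile.mkdtemp(), "ring.json"))
env = ring_env(rank=1, hostfile=hostfile, verbose=True)
# A wrapper would now start the MLX script with os.environ updated by `env`.
```

Each process needs its own ``MLX_RANK``, while the hostfile is shared by all ranks.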
|
||||
|
||||
JACCL
|
||||
^^^^^
|
||||
|
||||
**MLX_RANK** should contain a single 0-based integer that defines the rank of
|
||||
the process.
|
||||
|
||||
**MLX_JACCL_COORDINATOR** should contain the IP and port on which rank 0 listens
and to which all the other ranks connect in order to establish the RDMA connections.
|
||||
|
||||
**MLX_IBV_DEVICES** should contain the path to a json file that contains the
|
||||
ibverbs device names that connect each node to each other node, something like
|
||||
the following:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
[
|
||||
[null, "rdma_en5", "rdma_en4", "rdma_en3"],
|
||||
["rdma_en5", null, "rdma_en3", "rdma_en4"],
|
||||
["rdma_en4", "rdma_en3", null, "rdma_en5"],
|
||||
["rdma_en3", "rdma_en4", "rdma_en5", null]
|
||||
]
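Before launching, it can be worth checking that a hand-written device file has the expected shape. The sketch below is not part of MLX. ``check_ibv_devices`` is a hypothetical helper, and the rule it enforces (a square table with ``null`` on the diagonal and a device name everywhere else) is an assumption inferred from the example above.

```python
import json


def check_ibv_devices(devices):
    """Raise ValueError unless `devices` is an N x N table with an empty diagonal."""
    n = len(devices)
    for i, row in enumerate(devices):
        if len(row) != n:
            raise ValueError(f"row {i} has {len(row)} entries, expected {n}")
        for j, name in enumerate(row):
            if i == j and name is not None:
                raise ValueError(f"entry [{i}][{i}] should be null")
            if i != j and not isinstance(name, str):
                raise ValueError(f"missing device from rank {i} to rank {j}")
    return n


devices = json.loads("""
[
    [null, "rdma_en5", "rdma_en4", "rdma_en3"],
    ["rdma_en5", null, "rdma_en3", "rdma_en4"],
    ["rdma_en4", "rdma_en3", null, "rdma_en5"],
    ["rdma_en3", "rdma_en4", "rdma_en5", null]
]
""")
world_size = check_ibv_devices(devices)  # number of ranks described by the file
```

A missing off-diagonal entry would mean two nodes have no direct link, which a fully connected mesh does not allow.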
|
||||
|
||||
|
||||
NCCL
|
||||
^^^^^
|
||||
|
||||
**MLX_RANK** should contain a single 0-based integer that defines the rank of
|
||||
the process.
|
||||
|
||||
**MLX_WORLD_SIZE** should contain the total number of processes that will be
|
||||
launched.
|
||||
|
||||
**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
|
||||
hosts can connect to in order to establish the NCCL communication.
|
||||
|
||||
**CUDA_VISIBLE_DEVICES** should contain the local index of the GPU that
|
||||
corresponds to this process.
|
||||
|
||||
Of course any `other environment variable
|
||||
<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is
|
||||
used by NCCL can be set.
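As an illustration of the scheduler case, the sketch below maps SLURM's per-task variables onto the variables listed above. It is an assumption-laden example, not something MLX provides. ``nccl_env_from_slurm`` is a hypothetical helper, it assumes one task per GPU, and the IP and port are placeholders.

```python
import os


def nccl_env_from_slurm(host_ip, port, environ=os.environ):
    """Translate SLURM task variables into the variables listed above."""
    return {
        "MLX_RANK": environ["SLURM_PROCID"],  # global task index
        "MLX_WORLD_SIZE": environ["SLURM_NTASKS"],  # total number of tasks
        "NCCL_HOST_IP": host_ip,
        "NCCL_PORT": str(port),
        "CUDA_VISIBLE_DEVICES": environ["SLURM_LOCALID"],  # task index on its node
    }


# Fake SLURM environment: task 9 of 16, second task on its node.
fake = {"SLURM_PROCID": "9", "SLURM_NTASKS": "16", "SLURM_LOCALID": "1"}
env = nccl_env_from_slurm("10.0.0.1", 12345, environ=fake)
```

In a real job the wrapper would call ``os.environ.update(env)`` before importing MLX. Other schedulers expose similar information under different names.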
|
||||
|
||||
.. _tips_and_tricks:
|
||||
|
||||
Tips and Tricks
|
||||
----------------
|
||||
|
||||
This is a small collection of tips to help you make better use of the distributed
|
||||
communication capabilities of MLX.
|
||||
|
||||
- *Test locally first.*
|
||||
|
||||
You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a small
|
||||
scale test on a single node first.
|
||||
|
||||
- *Batch your communication.*
|
||||
|
||||
As described in the :ref:`training example <training_example>`, performing a
|
||||
lot of small communications can hurt performance. Copy the approach of
|
||||
:func:`mlx.nn.average_gradients` to gather many small communications in a
|
||||
single large one.
|
||||
|
||||
- *Visualize the connectivity.*
|
||||
|
||||
Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
|
||||
visualize the connections and make sure that the cables are connected
|
||||
correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
|
||||
|
||||
- *Use the debugger.*
|
||||
|
||||
``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
|
||||
processes and gathers stdout from all processes. This makes using ``pdb`` a
|
||||
breeze.
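The "batch your communication" tip can be illustrated without MLX. The sketch below uses plain Python lists, and ``fake_all_sum`` is a stand-in for a collective that only counts how often it is called. All names are hypothetical. In real code, :func:`mlx.nn.average_gradients` is the reference to follow.

```python
def flatten(chunks):
    """Concatenate small buffers and remember their lengths."""
    sizes = [len(c) for c in chunks]
    flat = [x for c in chunks for x in c]
    return flat, sizes


def unflatten(flat, sizes):
    """Split a flat buffer back into pieces of the recorded lengths."""
    out, start = [], 0
    for n in sizes:
        out.append(flat[start:start + n])
        start += n
    return out


calls = 0


def fake_all_sum(buffer, world_size=2):
    """Stand-in for a collective: pretends every rank holds the same data."""
    global calls
    calls += 1
    return [world_size * x for x in buffer]


grads = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
flat, sizes = flatten(grads)
summed = unflatten(fake_all_sum(flat), sizes)  # one call instead of three
```

The three small buffers travel in a single call, and the original shapes are restored afterwards.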
|
||||
|
||||
@@ -7,13 +7,106 @@ Launching Distributed Programs
|
||||
|
||||
.. currentmodule:: mlx.core.distributed
|
||||
|
||||
Installing the MLX python package provides a helper script ``mlx.launch`` that
|
||||
can be used to run python scripts distributed on several nodes. It allows
|
||||
launching using either the MPI backend or the ring backend. See the
|
||||
:doc:`distributed docs <distributed>` for the different backends.
|
||||
Installing the MLX python package provides two utilities to help you configure
|
||||
your Macs for distributed computation and also launch distributed programs on
|
||||
multiple nodes or with many processes in a single node. These utilities are aptly named
|
||||
|
||||
Usage
|
||||
-----
|
||||
- ``mlx.launch``
|
||||
- ``mlx.distributed_config``
|
||||
|
||||
See the :doc:`distributed docs <distributed>` for an introduction and
|
||||
getting-started guides to the various backends.
|
||||
|
||||
``mlx.distributed_config``
|
||||
---------------------------
|
||||
|
||||
Unless you are launching distributed jobs locally for development or multi-gpu
|
||||
CUDA environments, you have several Macs that you need to configure for
|
||||
distributed communication with MLX.
|
||||
|
||||
``mlx.distributed_config`` aims to automate the process of configuring the
|
||||
network interfaces (especially for communication over thunderbolt) and also
|
||||
creating the hostfile to be used with ``mlx.launch``.
|
||||
|
||||
We will analyse three cases of using ``mlx.distributed_config``:
|
||||
|
||||
1. RDMA over thunderbolt using JACCL
|
||||
2. TCP/IP over thunderbolt using the ring backend
|
||||
3. TCP/IP over ethernet using the ring backend
|
||||
|
||||
JACCL
|
||||
^^^^^^^
|
||||
|
||||
After following :ref:`the steps to enable RDMA <jaccl_section>` you can run the
|
||||
following command to configure the nodes and create the hostfile.
|
||||
|
||||
.. code-block::
|
||||
|
||||
mlx.distributed_config --verbose --backend jaccl \
|
||||
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 --over thunderbolt \
|
||||
--auto-setup --output m3-ultra-jaccl.json
|
||||
|
||||
Let's walk through the steps that the script takes to configure the nodes.
|
||||
|
||||
1. Ssh to all nodes to verify that they are reachable
|
||||
2. Extract the thunderbolt connectivity. Namely, run commands on each node to
|
||||
calculate which node is connected to which other node.
|
||||
3. Verify that we have a valid fully connected mesh
|
||||
4. Check that RDMA is enabled
|
||||
5. Extract the ethernet IP from interface en0
|
||||
6. Disable the thunderbolt bridge and set up peer to peer networks for each
|
||||
thunderbolt cable
|
||||
7. Write the hostfile
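Step 3 above amounts to checking that every pair of nodes shares a cable. The following is a minimal sketch of that idea, not the script's actual implementation. ``missing_links`` is a hypothetical helper and the cable list is made up for illustration.

```python
from itertools import combinations


def missing_links(nodes, cables):
    """Return the pairs of nodes that have no direct cable between them."""
    connected = {frozenset(c) for c in cables}
    return [pair for pair in combinations(nodes, 2) if frozenset(pair) not in connected]


nodes = ["m3-ultra-1", "m3-ultra-2", "m3-ultra-3"]
# Hypothetical wiring: nodes 1 and 2 are each connected only to node 3.
cables = [("m3-ultra-1", "m3-ultra-3"), ("m3-ultra-2", "m3-ultra-3")]
gaps = missing_links(nodes, cables)
fully_connected = not gaps
```

Here ``gaps`` names the pair that still needs a cable, which is the information you need when fixing the wiring.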
|
||||
|
||||
Knowing the above steps allows you to manually configure the nodes but also
|
||||
debug any configuration issue. For instance, changing the Ethernet IP to a
|
||||
different interface directly in the config is possible (as long as it is
|
||||
reachable from all nodes).
|
||||
|
||||
The ``--auto-setup`` argument requires password-less sudo on each node. If it
|
||||
isn't available then the configuration script will print commands to be run on
|
||||
each node.
|
||||
|
||||
Ring over thunderbolt
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Setting up a ring backend over thunderbolt only requires changing the
|
||||
``--backend`` from ``jaccl`` to ``ring``.
|
||||
|
||||
The steps are very similar with the main difference being that instead of
|
||||
verifying that the nodes are fully connected, the script attempts to identify a
|
||||
ring topology (or multiple rings).
|
||||
|
||||
Ring over Ethernet
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Configuring the ring backend over ethernet doesn't require setting up network
|
||||
interfaces, so the script simply extracts the ``en0`` IP from each node and
|
||||
writes the hostfile.
|
||||
|
||||
Debugging cable connections
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
``mlx.distributed_config`` can help you debug the connectivity of your nodes
|
||||
over thunderbolt by exporting a graph of the connections.
|
||||
|
||||
Running
|
||||
|
||||
.. code-block::
|
||||
|
||||
mlx.distributed_config --verbose \
|
||||
--hosts host1,host2,host3,host4 \
|
||||
--over thunderbolt --dot
|
||||
|
||||
will export a `GraphViz <https://graphviz.org>`_ representation of the
|
||||
connections between the nodes which makes it very easy to figure out which
|
||||
cable is not connected correctly.
|
||||
|
||||
See :ref:`the JACCL section <jaccl_section>` for an example.
|
||||
|
||||
|
||||
``mlx.launch``
|
||||
--------------
|
||||
|
||||
The minimal usage example of ``mlx.launch`` is simply
|
||||
|
||||
@@ -33,6 +126,10 @@ the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
|
||||
It also takes care of forwarding the output of each remote process to stdout
|
||||
and stderr respectively.
|
||||
|
||||
Importantly, it also broadcasts stdin to each process which enables interactive
|
||||
programs to work in distributed mode as well as debugging using the interactive
|
||||
debugger.
|
||||
|
||||
Providing Hosts
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
@@ -63,10 +160,62 @@ host and on the same path. A good checklist to debug errors is the following:
|
||||
``mlx.launch --print-python`` to see what that path is.
|
||||
* the script you want to run is available on all hosts at the same path
|
||||
|
||||
If you are launching from a node with a completely different setup than the
|
||||
nodes that the program will run on, you can specify ``--no-verify-script`` so
|
||||
that ``mlx.launch`` does not attempt to verify that the executable and script
|
||||
exist locally before launching the distributed job.
|
||||
|
||||
.. _ring_specifics:
|
||||
|
||||
Ring Specifics
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
The :ref:`ring <ring_section>` backend, which is also the default
|
||||
backend, can be explicitly selected with the argument ``--backend ring``. The
|
||||
ring backend has some specific requirements and arguments that are different to
|
||||
other backends:
|
||||
|
||||
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
|
||||
ssh to a hostname that does not correspond to the IP we want to bind to we
|
||||
have to provide a hostfile.
|
||||
* ``--starting-port`` defines the port to bind to on the remote hosts.
|
||||
Specifically, rank 0 for the first IP will use this port and each subsequent
|
||||
IP or rank will add 1 to this port.
|
||||
* ``--connections-per-ip`` allows us to increase the number of connections
|
||||
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
|
||||
``mpirun``.
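One way to read the port rules above is that every address takes the next free port, counting up from ``--starting-port``. The sketch below implements that reading only. ``ring_hostfile`` is a hypothetical helper, the default port and IPs are placeholders, and the exact assignment made by ``mlx.launch`` may differ.

```python
def ring_hostfile(ips, starting_port=5000, connections_per_ip=1):
    """Build hostfile entries, taking the next free port for every address."""
    port = starting_port
    hosts = []
    for ip in ips:  # one IP per rank in this sketch
        node = []
        for _ in range(connections_per_ip):
            node.append(f"{ip}:{port}")
            port += 1
        hosts.append(node)
    return hosts


hosts = ring_hostfile(
    ["192.168.0.1", "192.168.0.2"], starting_port=5000, connections_per_ip=2
)
```

With two connections per IP, each rank gets two entries in its row of the hostfile.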
|
||||
|
||||
.. _jaccl_specifics:
|
||||
|
||||
JACCL Specifics
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
The :ref:`JACCL <jaccl_section>` backend can be selected with the argument
|
||||
``--backend jaccl``. A hostfile is necessary to launch with this backend
|
||||
because it needs to contain the RDMA devices connecting each node to each other
|
||||
node.
|
||||
|
||||
NCCL Specifics
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
The :ref:`NCCL <nccl_section>` backend is the default backend for CUDA
|
||||
environments. When launching from a Mac to a Linux machine with CUDA, the
|
||||
backend should be selected using ``--backend nccl``.
|
||||
|
||||
The ``--repeat-hosts, -n`` argument should be used to launch multi-node and
|
||||
multi-gpu jobs. For instance
|
||||
|
||||
.. code-block::
|
||||
|
||||
mlx.launch --backend nccl --hosts linux-1,linux-2 -n 8 --no-verify-script -- ./my-job.sh
|
||||
|
||||
will attempt to launch 16 processes, 8 on each node that will all run
|
||||
``my-job.sh``.
|
||||
|
||||
.. _mpi_specifics:
|
||||
|
||||
MPI Specifics
|
||||
-------------
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
|
||||
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
|
||||
@@ -83,23 +232,3 @@ to choose a specific interface for the byte-transfer-layer of MPI we can call
|
||||
.. code:: shell
|
||||
|
||||
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
|
||||
|
||||
|
||||
.. _ring_specifics:
|
||||
|
||||
Ring Specifics
|
||||
--------------
|
||||
|
||||
The ring backend, which is also the default backend, can be explicitly selected
|
||||
with the argument ``--backend ring``. The ring backend has some specific
|
||||
requirements and arguments that are different to MPI:
|
||||
|
||||
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
|
||||
ssh to a hostname that does not correspond to the IP we want to bind to we
|
||||
have to provide a hostfile.
|
||||
* ``--starting-port`` defines the port to bind to on the remote hosts.
|
||||
Specifically rank 0 for the first IP will use this port and each subsequent
|
||||
IP or rank will add 1 to this port.
|
||||
* ``--connections-per-ip`` allows us to increase the number of connections
|
||||
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
|
||||
``mpirun``.
|
||||
|
||||