diff --git a/docs/src/usage/distributed.rst b/docs/src/usage/distributed.rst
index 73c40f8b9..df0d7b246 100644
--- a/docs/src/usage/distributed.rst
+++ b/docs/src/usage/distributed.rst
@@ -117,6 +117,8 @@ The following examples aim to clarify the backend initialization logic in MLX:
     world_ring = mx.distributed.init(backend="ring")
     world_any = mx.distributed.init()  # same as MPI because it was initialized first!
 
+.. _training_example:
+
 Training Example
 ----------------
 
@@ -289,7 +291,7 @@ Enabling RDMA
 ^^^^^^^^^^^^^
 
 Until the feature matures, enabling RDMA over thunderbolt is slightly more
-involved and **cannot** be done remotely even with sudo. In fact it has to be
+involved and **cannot** be done remotely even with sudo. In fact, it has to be
 done in macOS recovery:
 
 1. `Start your computer in recovery `_.
@@ -316,8 +318,8 @@ Defining a Mesh
 ^^^^^^^^^^^^^^^
 
 The JACCL backend supports only fully connected topologies. Namely, there needs
-to be a thunderbolt cable connecting all pairs of Macs directly. For example in
-the following topology visualizations the left one is valid because there is a
+to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
+the following topology visualizations, the left one is valid because there is a
 connection from any node to any other node, while for the one on the right M3
 Ultra 1 is not connected to M3 Ultra 2.
 
@@ -372,7 +374,7 @@ Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
 disabling the thunderbolt bridge is still required as well as setting up
 isolated local networks for each thunderbolt connection.
 
-All of the above can be done instead via ``mlx.distributed_config``. The helper
+All of the above can be done instead via ``mlx.distributed_config``. This helper
 script will
 
 - ssh into each node
@@ -384,13 +386,13 @@ script will
 Putting it All Together
 ^^^^^^^^^^^^^^^^^^^^^^^
 
-For example to launch a distributed MLX script that uses JACCL is fairly simple
+For example, launching a distributed MLX script that uses JACCL is fairly simple
 if the nodes are reachable via ssh and have password-less sudo.
 
 First, connect all the thunderbolt cables. Then we can verify the connections
 by using the ``mlx.distributed_config`` script to visualize them.
 
-.. code-block:: bash
+.. code-block:: shell
 
     mlx.distributed_config --verbose \
         --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
@@ -399,7 +401,7 @@
 After making sure that everything looks right we can auto-configure the nodes
 and save the hostfile to ``m3-ultra-jaccl.json`` by running:
 
-.. code-block:: bash
+.. code-block:: shell
 
     mlx.distributed_config --verbose \
         --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
@@ -409,7 +411,7 @@
 And now we are ready to run a distributed MLX script such as distributed
 inference of a gigantic model using MLX-LM.
 
-.. code-block:: bash
+.. code-block:: shell
 
     mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
         --env MLX_METAL_FAST_SYNCH=1 -- \  # <--- important
@@ -428,6 +430,32 @@ Getting Started with NCCL
 -------------------------
 
+MLX on CUDA environments ships with the ability to talk to `NCCL
+<https://developer.nvidia.com/nccl>`_, which is a high-performance collective
+communication library that supports both multi-GPU and multi-node setups.
+
+For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all
+it takes to run a distributed job is:
+
+.. code-block:: shell
+
+    mlx.launch -n 8 test.py
+
+    # perfect for interactive scripts
+    mlx.launch -n 8 python -m mlx_lm chat --model my-model --shard
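+
+For reference, ``test.py`` can be any MLX script that initializes the
+distributed group; a minimal sketch could look like the following:
+
+.. code-block:: python
+
+    import mlx.core as mx
+
+    # Initialize the distributed group (NCCL on CUDA environments).
+    world = mx.distributed.init()
+
+    # Sum one value per rank to verify that communication works.
+    total = mx.distributed.all_sum(mx.array(1.0))
+    print(world.rank(), world.size(), total.item())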
+
+You can also use ``mlx.launch`` to ssh to a remote node and launch a script
+with the same ease:
+
+.. code-block:: shell
+
+    mlx.launch --hosts my-cuda-node -n 8 test.py
+
+In many cases you may not want to use ``mlx.launch`` with the NCCL backend
+because the cluster scheduler will be the one launching the processes. You can
+:ref:`see which environment variables need to be defined <no_mlx_launch>` in
+order for the MLX NCCL backend to be initialized correctly.
+
 .. _mpi_section:
 
 Getting Started with MPI
 ------------------------
@@ -507,9 +535,116 @@ Force MPI to use the most performant network interface by setting
 ``--mca btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface
 you want to use.
 
+.. _no_mlx_launch:
+
 Distributed Without ``mlx.launch``
 ----------------------------------
 
+None of the implementations of the distributed backends require launching with
+``mlx.launch``. The script simply connects to each host, starts a process per
+rank, and sets up the necessary environment variables before delegating to your
+MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
+for more details.
-Using the helper scripts
--------------------------
 
+For many use cases this will be the easiest way to perform distributed
+computations in MLX. However, there may be reasons that you cannot or should
+not use ``mlx.launch``. A common such case is the use of a scheduler that
+starts all the processes for you, on machines that are not known until the job
+is scheduled.
+
+Below we list the environment variables required to use each backend.
+
+Ring
+^^^^
+
+**MLX_RANK** should contain a single 0-based integer that defines the rank of
+the process.
+
+**MLX_HOSTFILE** should contain the path to a JSON file that contains the IPs
+and ports for each rank to listen on, something like the following:
+
+.. code-block:: json
+
+    [
+        ["123.123.1.1:5000", "123.123.1.2:5000"],
+        ["123.123.2.1:5000", "123.123.2.2:5000"],
+        ["123.123.3.1:5000", "123.123.3.2:5000"],
+        ["123.123.4.1:5000", "123.123.4.2:5000"]
+    ]
+
+**MLX_RING_VERBOSE** is optional; if set to 1, it enables more verbose logging
+from the distributed backend.
+
+JACCL
+^^^^^
+
+**MLX_RANK** should contain a single 0-based integer that defines the rank of
+the process.
+
+**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 listens
+on and that all the other ranks connect to in order to establish the RDMA
+connections.
+
+**MLX_IBV_DEVICES** should contain the path to a JSON file that contains the
+ibverbs device names that connect each node to each other node, something like
+the following:
+
+.. code-block:: json
+
+    [
+        [null, "rdma_en5", "rdma_en4", "rdma_en3"],
+        ["rdma_en5", null, "rdma_en3", "rdma_en4"],
+        ["rdma_en4", "rdma_en3", null, "rdma_en5"],
+        ["rdma_en3", "rdma_en4", "rdma_en5", null]
+    ]
+
+NCCL
+^^^^
+
+**MLX_RANK** should contain a single 0-based integer that defines the rank of
+the process.
+
+**MLX_WORLD_SIZE** should contain the total number of processes that will be
+launched.
+
+**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
+hosts can connect to in order to establish the NCCL communication.
+
+**CUDA_VISIBLE_DEVICES** should contain the local index of the GPU that
+corresponds to this process.
+
+Of course any `other environment variable
+<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is
+used by NCCL can be set.
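+
+For instance, when the scheduler is SLURM, a rank could translate the
+scheduler's variables into the ones above before initializing MLX. The
+following is only a sketch; the SLURM variable names and the coordinator
+address are illustrative:
+
+.. code-block:: python
+
+    import os
+
+    # Map scheduler-provided values to the variables the NCCL backend
+    # expects. This must happen before mx.distributed.init is called.
+    os.environ["MLX_RANK"] = os.environ["SLURM_PROCID"]
+    os.environ["MLX_WORLD_SIZE"] = os.environ["SLURM_NTASKS"]
+    os.environ["NCCL_HOST_IP"] = "10.0.0.1"  # an address every rank can reach
+    os.environ["NCCL_PORT"] = "5000"
+    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"]
+
+    import mlx.core as mx
+
+    world = mx.distributed.init(backend="nccl")
+    print(world.rank(), world.size())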
+
+.. _tips_and_tricks:
+
+Tips and Tricks
+---------------
+
+This is a small collection of tips to help you better utilize the distributed
+communication capabilities of MLX.
+
+- *Test locally first.*
+
+  You can use the pattern ``mlx.launch -n 2 -- my_script.py`` to run a small
+  scale test on a single node first.
+
+- *Batch your communication.*
+
+  As described in the :ref:`training example <training_example>`, performing
+  many small communications can hurt performance. Copy the approach of
+  :func:`mlx.nn.average_gradients` to aggregate many small communications into
+  a single large one (see the sketch at the end of this section).
+
+- *Visualize the connectivity.*
+
+  Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
+  visualize the connections and make sure that the cables are connected
+  correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
+
+- *Use the debugger.*
+
+  ``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
+  processes and gathers stdout from all processes. This makes using ``pdb`` a
+  breeze.
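+
+As a sketch of the batching tip above, here is roughly how
+:func:`mlx.nn.average_gradients` fits into a training step. The ``model`` and
+``loss_fn`` below are placeholders for your own model and task-specific loss:
+
+.. code-block:: python
+
+    import mlx.nn as nn
+
+    def step(model, loss_fn, inputs, targets):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, inputs, targets)
+
+        # One batched communication for all gradients instead of one
+        # all_sum per gradient array.
+        grads = nn.average_gradients(grads)
+        return loss, grads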