.. _usage_distributed:

Distributed Communication
=========================

.. currentmodule:: mlx.core.distributed

MLX supports distributed communication operations that allow the computational
cost of training or inference to be shared across many physical machines. At
the moment we support two different communication backends:

* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_, a
  full-featured and mature distributed communications library.
* A **ring** backend of our own that uses native TCP sockets and should be
  faster for Thunderbolt connections.

The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.

.. note::
   Some operations may not be supported or may not be as fast as they should
   be. We are adding more operations and tuning the existing ones as we figure
   out the best way to do distributed computing on Macs using MLX.

Getting Started
---------------

A distributed program in MLX is as simple as:

.. code:: python

    import mlx.core as mx

    world = mx.distributed.init()
    x = mx.distributed.all_sum(mx.ones(10))
    print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all distributed
processes. However, when this script is run with ``python``, only one process
is launched and no distributed communication takes place. Namely, all
operations in ``mx.distributed`` are no-ops when the distributed group has a
size of one. This property lets us avoid code that checks whether we are in a
distributed setting, such as the following:

.. code:: python

    import mlx.core as mx

    x = ...
    world = mx.distributed.init()
    # No need for this check; we can simply do x = mx.distributed.all_sum(x)
    if world.size() > 1:
        x = mx.distributed.all_sum(x)

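To make these semantics concrete, here is a minimal pure-Python sketch of what
an all-sum computes. ``simulated_all_sum`` is a hypothetical helper written for
this illustration, not part of MLX, and no communication takes place: each
"rank" is simply one entry of a list.

```python
def simulated_all_sum(per_rank_values):
    """Return what every rank would hold after an all-sum.

    `per_rank_values` holds one equally sized list per simulated rank.
    """
    total = [sum(column) for column in zip(*per_rank_values)]
    # Every rank ends up with its own copy of the same elementwise sum.
    return [list(total) for _ in per_rank_values]


# Four ranks, each contributing the equivalent of mx.ones(3).
print(simulated_all_sum([[1.0] * 3 for _ in range(4)])[0])  # [4.0, 4.0, 4.0]

# A group of size one gets its own input back, so the call is a no-op.
print(simulated_all_sum([[1.0, 2.0, 3.0]]))  # [[1.0, 2.0, 3.0]]
```

The second call shows why the ``world.size() > 1`` guard is unnecessary: with a
single participant the sum is the input itself.
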
Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

MLX provides ``mlx.launch``, a helper script to launch distributed programs.
Continuing with our initial example, we can run it on localhost with 4
processes using:

.. code:: shell

    $ mlx.launch -n 4 my_script.py
    3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
    2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
    1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
    0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh):

.. code:: shell

    $ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
    3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
    2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
    1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
    0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

Consult the dedicated :doc:`usage guide<launching_distributed>` for more
information on using ``mlx.launch``.

Selecting Backend
^^^^^^^^^^^^^^^^^

You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and, if that fails, the ``mpi`` backend. If
both fail, a singleton group is created.

.. note::
   After a distributed backend is successfully initialized, :func:`init` will
   return **the same backend** if called without arguments or with backend set
   to ``any``.

The following examples aim to clarify the backend initialization logic in MLX:

.. code:: python

    import mlx.core as mx

    # Case 1: Initialize MPI regardless of whether the ring backend could be initialized
    world = mx.distributed.init(backend="mpi")
    world2 = mx.distributed.init()  # subsequent calls return the MPI backend!

    # Case 2: Initialize any backend
    world = mx.distributed.init(backend="any")  # equivalent to no arguments
    world2 = mx.distributed.init()  # same as above

    # Case 3: Initialize both backends at the same time
    world_mpi = mx.distributed.init(backend="mpi")
    world_ring = mx.distributed.init(backend="ring")
    world_any = mx.distributed.init()  # same as MPI because it was initialized first!

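The selection rules described in this section can be modelled with a small toy
class. This is only a sketch of the documented behaviour, not MLX code:
``BackendSelector`` and its attributes are hypothetical names, and raising an
error when an explicitly requested backend is unavailable is a choice made for
this sketch, since that case is not described here.

```python
class BackendSelector:
    """Toy model of the documented init() fallback rules (not MLX code)."""

    def __init__(self, available):
        self.available = set(available)  # backends that could be initialized
        self.initialized = []            # initialized backends, in order

    def init(self, backend="any"):
        if backend == "any":
            # An already initialized backend is reused, first one wins.
            if self.initialized:
                return self.initialized[0]
            # Otherwise try ring first, then mpi.
            for candidate in ("ring", "mpi"):
                if candidate in self.available:
                    self.initialized.append(candidate)
                    return candidate
            # Both failed: fall back to a group of size one.
            return "singleton"
        if backend not in self.available:
            # Sketch-only assumption: fail loudly for an explicit request.
            raise ValueError(f"cannot initialize backend {backend!r}")
        if backend not in self.initialized:
            self.initialized.append(backend)
        return backend


# Mirrors "Case 3": both backends are initialized, MPI first.
selector = BackendSelector(available={"ring", "mpi"})
print(selector.init("mpi"))   # mpi
print(selector.init("ring"))  # ring
print(selector.init())        # mpi
```
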
Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.

.. code:: python

    model = ...
    optimizer = ...
    dataset = ...

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    for x, y in dataset:
        loss = step(model, x, y)
        mx.eval(loss, model.parameters())

All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we
have to :func:`mlx.utils.tree_map` the gradients with the following function.

.. code:: python

    def all_avg(x):
        return mx.distributed.all_sum(x) / mx.distributed.init().size()

Putting everything together, our training loop step looks as follows, with
everything else remaining the same.

.. code:: python

    from mlx.utils import tree_map

    def all_reduce_grads(grads):
        N = mx.distributed.init().size()
        if N == 1:
            return grads
        return tree_map(
            lambda x: mx.distributed.all_sum(x) / N,
            grads
        )

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        grads = all_reduce_grads(grads)  # <--- This line was added
        optimizer.update(model, grads)
        return loss

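As a rough illustration of what the averaging step does to a nested gradient
structure, the following pure-Python sketch averages matching leaves across
simulated ranks. The ``tree_map`` defined here is a minimal stand-in written
for this example (it only handles dicts and lists), and
``average_across_ranks`` is a hypothetical helper; neither is MLX code.

```python
def tree_map(fn, tree, *rest):
    """Minimal stand-in for a tree map over dicts and lists."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    if isinstance(tree, list):
        return [tree_map(fn, *items) for items in zip(tree, *rest)]
    # Anything else is treated as a leaf.
    return fn(tree, *rest)


def average_across_ranks(per_rank_grads):
    """Average matching leaves of every simulated rank's gradient tree."""
    n = len(per_rank_grads)
    return tree_map(lambda *leaves: sum(leaves) / n, *per_rank_grads)


grads_rank0 = {"linear": {"weight": [1.0, 2.0], "bias": 0.5}}
grads_rank1 = {"linear": {"weight": [3.0, 4.0], "bias": 1.5}}
print(average_across_ranks([grads_rank0, grads_rank1]))
# {'linear': {'weight': [2.0, 3.0], 'bias': 1.0}}
```

Every rank would end up with the same averaged tree, which is what keeps the
model replicas in sync after the optimizer update.
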
Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.

This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:

.. code:: python

    import mlx.core as mx
    import mlx.nn as nn

    model = ...
    optimizer = ...
    dataset = ...

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        grads = nn.average_gradients(grads)  # <---- This line was added
        optimizer.update(model, grads)
        return loss

    for x, y in dataset:
        loss = step(model, x, y)
        mx.eval(loss, model.parameters())

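The idea of aggregating gradients before communicating can be sketched in pure
Python as follows. This is only an illustration of the batching idea under
simplifying assumptions (flat lists, a single bucket, simulated ranks); it is
not the implementation of :func:`mlx.nn.average_gradients`, and ``all_sum`` and
``average_batched`` below are hypothetical helpers.

```python
def all_sum(per_rank_buffers, counter):
    """Simulated all-sum over flat lists that also counts invocations."""
    counter["calls"] += 1
    return [sum(column) for column in zip(*per_rank_buffers)]


def average_batched(per_rank_grads, counter):
    """Flatten each rank's gradients, reduce once, then split back."""
    n = len(per_rank_grads)
    sizes = [len(g) for g in per_rank_grads[0]]
    # One flat buffer per rank instead of one buffer per gradient.
    flat = [[v for g in grads for v in g] for grads in per_rank_grads]
    summed = all_sum(flat, counter)
    out, start = [], 0
    for size in sizes:
        out.append([v / n for v in summed[start:start + size]])
        start += size
    return out


counter = {"calls": 0}
rank0 = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
rank1 = [[3.0, 4.0], [5.0], [6.0, 7.0, 8.0]]
print(average_batched([rank0, rank1], counter), counter["calls"])
# [[2.0, 3.0], [4.0], [5.0, 6.0, 7.0]] 1
```

Three gradients are averaged with a single simulated communication step,
whereas the per-gradient version would have needed three.
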
Getting Started with MPI
------------------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi``, which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dylib`` shared
library.

The simplest possible usage is the following which, assuming the minimal
example from the beginning of this page is saved as ``test.py``, should result
in:

.. code:: shell

    $ mlx.launch --backend mpi -n 2 test.py
    1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
    0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum, which is printed. Launching with ``mlx.launch -n 4 ...``
would print 4, and so on.

Installing MPI
^^^^^^^^^^^^^^

MPI can be installed with Homebrew, with the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

    $ conda install conda-forge::openmpi

Installing with Homebrew may require specifying the location of
``libmpi.dylib`` so that MLX can find it and load it at runtime. This can be
achieved by passing the ``DYLD_LIBRARY_PATH`` environment variable to
``mpirun``, and it is done automatically by ``mlx.launch``.

.. code:: shell

    $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
    $ # or simply
    $ mlx.launch -n 2 test.py

Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist for
debugging connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  a password or host confirmation.
* ``mpirun`` is accessible on all machines.
* The ``hostname`` used by MPI is the one that you have configured in the
  ``.ssh/config`` files on all machines.

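The first item of the checklist can be partly automated. The sketch below is
one possible way to do it, assuming an OpenSSH client is installed:
``ssh_check_command`` and ``can_ssh`` are hypothetical helpers, and the host
names passed to them are placeholders. ``BatchMode=yes`` makes ``ssh`` fail
instead of prompting, so a password or host confirmation request shows up as a
failed check.

```python
import subprocess


def ssh_check_command(host, timeout=5):
    """Build a non-interactive ssh command that runs `true` on `host`."""
    return [
        "ssh",
        "-o", "BatchMode=yes",               # never prompt, fail instead
        "-o", f"ConnectTimeout={timeout}",   # give up on unreachable hosts
        host,
        "true",
    ]


def can_ssh(host, timeout=5):
    """Return True if `host` can be reached without any interaction."""
    try:
        result = subprocess.run(
            ssh_check_command(host, timeout),
            capture_output=True,
            timeout=timeout + 5,
        )
    except (OSError, subprocess.TimeoutExpired):
        # ssh is missing or the connection hung.
        return False
    return result.returncode == 0


def unreachable_hosts(hosts):
    """Return the subset of `hosts` that fail the non-interactive check."""
    return [host for host in hosts if not can_ssh(host)]
```

Running ``unreachable_hosts`` on every machine with the full list of host names
covers the "from all machines to all machines" requirement.
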
Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^

.. note::

   For faster all reduce, consider using the ring backend either with
   Thunderbolt connections or over Ethernet.

Configure MPI to use N TCP connections between each pair of hosts to improve
bandwidth by passing ``--mca btl_tcp_links N``.

Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>``, where ``<iface>`` should be the interface you want
to use.

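As a small illustration of how these two options combine on a command line,
the following sketch assembles an ``mpirun`` invocation. ``mpirun_command`` is
a hypothetical helper, and the values ``4`` and ``en0`` are placeholders chosen
for the example, not recommendations.

```python
import shlex


def mpirun_command(script, num_procs, tcp_links=None, interface=None):
    """Assemble an mpirun invocation with the optional TCP tuning flags."""
    cmd = ["mpirun", "-np", str(num_procs)]
    if tcp_links is not None:
        # Number of TCP connections between each pair of hosts.
        cmd += ["--mca", "btl_tcp_links", str(tcp_links)]
    if interface is not None:
        # Restrict MPI traffic to the chosen network interface.
        cmd += ["--mca", "btl_tcp_if_include", interface]
    return cmd + ["python", script]


print(shlex.join(mpirun_command("test.py", 2, tcp_links=4, interface="en0")))
# mpirun -np 2 --mca btl_tcp_links 4 --mca btl_tcp_if_include en0 python test.py
```
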
Getting Started with Ring
-------------------------

The ring backend does not depend on any third-party library, so it is always
available. It uses TCP sockets, so the nodes need to be reachable via a network.
As the name suggests, the nodes are connected in a ring, which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3,
and so on. As a result, :func:`send` and :func:`recv` with an arbitrary sender
and receiver are not supported in the ring backend.

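The neighbour structure of a ring can be written down in a few lines. The
sketch below also shows a deliberately naive way to compute a sum using only
neighbour-to-neighbour transfers; it is meant to illustrate the topology and is
not the algorithm used by the ring backend. ``ring_neighbors`` and
``ring_all_sum`` are hypothetical helpers operating on simulated ranks.

```python
def ring_neighbors(rank, size):
    """Return the (previous, next) ranks of `rank` in a ring of `size` nodes."""
    return (rank - 1) % size, (rank + 1) % size


def ring_all_sum(values):
    """Sum one number per rank using only neighbour-to-neighbour transfers."""
    size = len(values)
    totals = list(values)
    in_flight = list(values)  # the value each rank forwards next
    for _ in range(size - 1):
        # Every rank receives what its previous neighbour forwarded.
        in_flight = [in_flight[ring_neighbors(r, size)[0]] for r in range(size)]
        totals = [t + v for t, v in zip(totals, in_flight)]
    return totals


print(ring_neighbors(1, 4))        # (0, 2)
print(ring_neighbors(0, 4))        # (3, 1)
print(ring_all_sum([1, 2, 3, 4]))  # [10, 10, 10, 10]
```

After ``size - 1`` forwarding steps every value has visited every rank, so all
ranks hold the full sum without any rank talking to a non-neighbour.
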
Defining a Ring
^^^^^^^^^^^^^^^

The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node, one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen on for connections.

For example, the hostfile below defines a 4-node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1, etc.

.. code:: json

    [
        {"ssh": "hostname1", "ips": ["123.123.123.1"]},
        {"ssh": "hostname2", "ips": ["123.123.123.2"]},
        {"ssh": "hostname3", "ips": ["123.123.123.3"]},
        {"ssh": "hostname4", "ips": ["123.123.123.4"]}
    ]

If the hostfile is saved as ``ring-4.json``, running
``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each node and
run the script, which will listen for connections on each of the provided IPs.
Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4``, and so on.

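A hostfile like the one above can also be generated programmatically. The
following sketch builds the same structure and derives, for each rank, the
address it connects to and the address it accepts a connection from.
``make_hostfile`` and ``ring_connections`` are hypothetical helpers, and using
only the first IP of each node is a simplification made for this example.

```python
import json


def make_hostfile(nodes):
    """Build the hostfile structure from (ssh_name, [ips]) pairs."""
    return [{"ssh": ssh, "ips": list(ips)} for ssh, ips in nodes]


def ring_connections(hostfile):
    """For each rank, list the IP it connects to and the one it accepts from."""
    size = len(hostfile)
    return [
        {
            "rank": rank,
            "connects_to": hostfile[(rank + 1) % size]["ips"][0],
            "accepts_from": hostfile[(rank - 1) % size]["ips"][0],
        }
        for rank in range(size)
    ]


hostfile = make_hostfile(
    (f"hostname{i}", [f"123.123.123.{i}"]) for i in range(1, 5)
)
print(json.dumps(hostfile, indent=2))
print(ring_connections(hostfile)[0])
# {'rank': 0, 'connects_to': '123.123.123.2', 'accepts_from': '123.123.123.4'}
```

Writing ``json.dumps(hostfile)`` to a file gives a hostfile that can be passed
to ``mlx.launch --hostfile``.
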
Thunderbolt Ring
^^^^^^^^^^^^^^^^

Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such Thunderbolt rings can be done manually, but it is a relatively
tedious process. To simplify this, we provide the utility
``mlx.distributed_config``.

To use ``mlx.distributed_config``, your computers need to be accessible by ssh
via Ethernet or Wi-Fi. Then connect them with Thunderbolt cables and call the
utility as follows:

.. code:: shell

    mlx.distributed_config --verbose --hosts host1,host2,host3,host4

By default the script will attempt to discover the Thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes,
then ``--auto-setup`` can be used to configure them automatically.

To validate your connection without configuring anything,
``mlx.distributed_config`` can also plot the ring using DOT format.

.. code:: shell

    mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
    dot -Tpng ring.dot >ring.png
    open ring.png

If you want to go through the process manually, the steps are as follows:

* Disable the Thunderbolt bridge interface.
* For the cable connecting rank ``i`` to rank ``i + 1``, find the interfaces
  corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
  interfaces. For instance, if the cable corresponds to ``en2`` on node ``i``
  and ``en2`` also on node ``i + 1``, then we may assign IPs ``192.168.0.1`` and
  ``192.168.0.2`` respectively to the two nodes. For more details you can see
  the commands prepared by the utility script.

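The address planning in the last step can be sketched with Python's
``ipaddress`` module. This only computes one possible assignment; the choice
of ``/30`` subnets carved out of ``192.168.0.0/24`` is an assumption made for
this example, ``link_addresses`` is a hypothetical helper, and the interface
names and the actual configuration commands still have to come from your
machines or from the utility script.

```python
import ipaddress


def link_addresses(num_nodes, base="192.168.0.0/24"):
    """Assign one /30 subnet per cable between rank i and rank i + 1."""
    subnets = ipaddress.ip_network(base).subnets(new_prefix=30)
    plan = []
    for rank, subnet in zip(range(num_nodes), subnets):
        # A /30 network has exactly two usable host addresses.
        first, second = list(subnet.hosts())
        plan.append({
            "link": (rank, (rank + 1) % num_nodes),
            "subnet": str(subnet),
            "ips": (str(first), str(second)),
        })
    return plan


for entry in link_addresses(4):
    print(entry)
# First entry:
# {'link': (0, 1), 'subnet': '192.168.0.0/30', 'ips': ('192.168.0.1', '192.168.0.2')}
```

Each cable gets its own subnet, so the two ends of a link can reach each other
without clashing with the addresses used on any other link.
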