mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
Ring docs (#1829)
This commit is contained in:
parent
607181644f
commit
5d68082881
@ -174,6 +174,7 @@ In detail:
|
|||||||
|
|
||||||
value_and_grad
|
value_and_grad
|
||||||
quantize
|
quantize
|
||||||
|
average_gradients
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
|
|
||||||
|
@ -5,21 +5,27 @@ Distributed Communication
|
|||||||
|
|
||||||
.. currentmodule:: mlx.core.distributed
|
.. currentmodule:: mlx.core.distributed
|
||||||
|
|
||||||
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
|
MLX supports distributed communication operations that allow the computational cost
|
||||||
provide distributed communication operations that allow the computational cost
|
of training or inference to be shared across many physical machines. At the
|
||||||
of training or inference to be shared across many physical machines. You can
|
moment we support two different communication backends:
|
||||||
see a list of the supported operations in the :ref:`API docs<distributed>`.
|
|
||||||
|
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
|
||||||
|
full-featured and mature distributed communications library
|
||||||
|
* A **ring** backend of our own that uses native TCP sockets and should be
|
||||||
|
faster for thunderbolt connections.
|
||||||
|
|
||||||
|
The list of all currently supported operations and their documentation can be
|
||||||
|
seen in the :ref:`API docs<distributed>`.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
A lot of operations may not be supported or not as fast as they should be.
|
Some operations may not be supported or not as fast as they should be.
|
||||||
We are adding more and tuning the ones we have as we are figuring out the
|
We are adding more and tuning the ones we have as we are figuring out the
|
||||||
best way to do distributed computing on Macs using MLX.
|
best way to do distributed computing on Macs using MLX.
|
||||||
|
|
||||||
Getting Started
|
Getting Started
|
||||||
---------------
|
---------------
|
||||||
|
|
||||||
MLX already comes with the ability to "talk" to MPI if it is installed on the
|
A distributed program in MLX is as simple as:
|
||||||
machine. The minimal distributed program in MLX is as simple as:
|
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
@ -30,74 +36,79 @@ machine. The minimal distributed program in MLX is as simple as:
|
|||||||
print(world.rank(), x)
|
print(world.rank(), x)
|
||||||
|
|
||||||
The program above sums the array ``mx.ones(10)`` across all
|
The program above sums the array ``mx.ones(10)`` across all
|
||||||
distributed processes. If simply run with ``python``, however, only one
|
distributed processes. However, when this script is run with ``python`` only
|
||||||
process is launched and no distributed communication takes place.
|
one process is launched and no distributed communication takes place. Namely,
|
||||||
|
all operations in ``mx.distributed`` are noops when the distributed group has a
|
||||||
|
size of one. This property allows us to avoid code that checks if we are in a
|
||||||
|
distributed setting similar to the one below:
|
||||||
|
|
||||||
To launch the program in distributed mode we need to use ``mpirun`` or
|
.. code:: python
|
||||||
``mpiexec`` depending on the MPI installation. The simplest possible way is the
|
|
||||||
following:
|
import mlx.core as mx
|
||||||
|
|
||||||
|
x = ...
|
||||||
|
world = mx.distributed.init()
|
||||||
|
# No need for the check we can simply do x = mx.distributed.all_sum(x)
|
||||||
|
if world.size() > 1:
|
||||||
|
x = mx.distributed.all_sum(x)
|
||||||
|
|
||||||
|
Running Distributed Programs
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
MLX provides ``mlx.launch`` a helper script to launch distributed programs.
|
||||||
|
Continuing with our initial example we can run it on localhost with 4 processes using
|
||||||
|
|
||||||
.. code:: shell
|
.. code:: shell
|
||||||
|
|
||||||
$ mpirun -np 2 python test.py
|
$ mlx.launch -n 4 my_script.py
|
||||||
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
|
||||||
The above launches two processes on the same (local) machine and we can see
|
We can also run it on some remote hosts by providing their IPs (provided that
|
||||||
both standard output streams. The processes send the array of 1s to each other
|
the script exists on all hosts and they are reachable by ssh)
|
||||||
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
|
|
||||||
print 4 etc.
|
|
||||||
|
|
||||||
Installing MPI
|
|
||||||
---------------
|
|
||||||
|
|
||||||
MPI can be installed with Homebrew, using the Anaconda package manager or
|
|
||||||
compiled from source. Most of our testing is done using ``openmpi`` installed
|
|
||||||
with the Anaconda package manager as follows:
|
|
||||||
|
|
||||||
.. code:: shell
|
.. code:: shell
|
||||||
|
|
||||||
$ conda install conda-forge::openmpi
|
$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
|
||||||
|
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
|
||||||
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
|
Consult the dedicated :doc:`usage guide<launching_distributed>` for more
|
||||||
so that MLX can find it and load it at runtime. This can simply be achieved by
|
information on using ``mlx.launch``.
|
||||||
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
|
|
||||||
|
|
||||||
.. code:: shell
|
Selecting Backend
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
|
You can select the backend you want to use when calling :func:`init` by passing
|
||||||
|
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
|
||||||
Setting up Remote Hosts
|
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
|
||||||
-----------------------
|
both fail then a singleton group is created.
|
||||||
|
|
||||||
MPI can automatically connect to remote hosts and set up the communication over
|
|
||||||
the network if the remote hosts can be accessed via ssh. A good checklist to
|
|
||||||
debug connectivity issues is the following:
|
|
||||||
|
|
||||||
* ``ssh hostname`` works from all machines to all machines without asking for
|
|
||||||
password or host confirmation
|
|
||||||
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
|
|
||||||
full path to force all machines to use a specific path.
|
|
||||||
* Ensure that the ``hostname`` used by MPI is the one that you have configured
|
|
||||||
in the ``.ssh/config`` files on all machines.
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
|
After a distributed backend is successfully initialized :func:`init` will
|
||||||
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
|
return **the same backend** if called without arguments or with backend set to
|
||||||
|
``any``.
|
||||||
|
|
||||||
An easy way to pass the host names to MPI is using a host file. A host file
|
The following examples aim to clarify the backend initialization logic in MLX:
|
||||||
looks like the following, where ``host1`` and ``host2`` should be the fully
|
|
||||||
qualified domain names or IPs for these hosts.
|
|
||||||
|
|
||||||
.. code::
|
.. code:: python
|
||||||
|
|
||||||
host1 slots=1
|
# Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
|
||||||
host2 slots=1
|
world = mx.distributed.init(backend="mpi")
|
||||||
|
world2 = mx.distributed.init() # subsequent calls return the MPI backend!
|
||||||
|
|
||||||
When using MLX, it is very likely that you want to use 1 slot per host, ie one
|
# Case 2: Initialize any backend
|
||||||
process per host. The hostfile also needs to contain the current
|
world = mx.distributed.init(backend="any") # equivalent to no arguments
|
||||||
host if you want to run on the local host. Passing the host file to
|
world2 = mx.distributed.init() # same as above
|
||||||
``mpirun`` is simply done using the ``--hostfile`` command line argument.
|
|
||||||
|
# Case 3: Initialize both backends at the same time
|
||||||
|
world_mpi = mx.distributed.init(backend="mpi")
|
||||||
|
world_ring = mx.distributed.init(backend="ring")
|
||||||
|
world_any = mx.distributed.init() # same as MPI because it was initialized first!
|
||||||
|
|
||||||
Training Example
|
Training Example
|
||||||
----------------
|
----------------
|
||||||
@ -155,13 +166,179 @@ everything else remaining the same.
|
|||||||
optimizer.update(model, grads)
|
optimizer.update(model, grads)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
Tuning All Reduce
|
Utilizing ``nn.average_gradients``
|
||||||
-----------------
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
We are working on improving the performance of all reduce on MLX but for now
|
Although the code example above works correctly; it performs one communication
|
||||||
the two main things one can do to extract the most out of distributed training with MLX are:
|
per gradient. It is significantly more efficient to aggregate several gradients
|
||||||
|
together and perform fewer communication steps.
|
||||||
|
|
||||||
1. Perform a few large reductions instead of many small ones to improve
|
This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
|
||||||
bandwidth and latency
|
almost identical to the example above:
|
||||||
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
|
|
||||||
connections between each host to improve bandwidth
|
.. code:: python
|
||||||
|
|
||||||
|
model = ...
|
||||||
|
optimizer = ...
|
||||||
|
dataset = ...
|
||||||
|
|
||||||
|
def step(model, x, y):
|
||||||
|
loss, grads = loss_grad_fn(model, x, y)
|
||||||
|
grads = mlx.nn.average_gradients(grads) # <---- This line was added
|
||||||
|
optimizer.update(model, grads)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
for x, y in dataset:
|
||||||
|
loss = step(model, x, y)
|
||||||
|
mx.eval(loss, model.parameters())
|
||||||
|
|
||||||
|
|
||||||
|
Getting Started with MPI
|
||||||
|
------------------------
|
||||||
|
|
||||||
|
MLX already comes with the ability to "talk" to MPI if it is installed on the
|
||||||
|
machine. Launching distributed MLX programs that use MPI can be done with
|
||||||
|
``mpirun`` as expected. However, in the following examples we will be using
|
||||||
|
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
|
||||||
|
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
|
||||||
|
library.
|
||||||
|
|
||||||
|
The simplest possible usage is the following which, assuming the minimal
|
||||||
|
example in the beginning of this page, should result in:
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
$ mlx.launch --backend mpi -n 2 test.py
|
||||||
|
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||||
|
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||||
|
|
||||||
|
The above launches two processes on the same (local) machine and we can see
|
||||||
|
both standard output streams. The processes send the array of 1s to each other
|
||||||
|
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
|
||||||
|
print 4 etc.
|
||||||
|
|
||||||
|
Installing MPI
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
MPI can be installed with Homebrew, using the Anaconda package manager or
|
||||||
|
compiled from source. Most of our testing is done using ``openmpi`` installed
|
||||||
|
with the Anaconda package manager as follows:
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
$ conda install conda-forge::openmpi
|
||||||
|
|
||||||
|
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
|
||||||
|
so that MLX can find it and load it at runtime. This can simply be achieved by
|
||||||
|
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
|
||||||
|
done automatically by ``mlx.launch``.
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
|
||||||
|
$ # or simply
|
||||||
|
$ mlx.launch -n 2 test.py
|
||||||
|
|
||||||
|
Setting up Remote Hosts
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
MPI can automatically connect to remote hosts and set up the communication over
|
||||||
|
the network if the remote hosts can be accessed via ssh. A good checklist to
|
||||||
|
debug connectivity issues is the following:
|
||||||
|
|
||||||
|
* ``ssh hostname`` works from all machines to all machines without asking for
|
||||||
|
password or host confirmation
|
||||||
|
* ``mpirun`` is accessible on all machines.
|
||||||
|
* Ensure that the ``hostname`` used by MPI is the one that you have configured
|
||||||
|
in the ``.ssh/config`` files on all machines.
|
||||||
|
|
||||||
|
Tuning MPI All Reduce
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
For faster all reduce consider using the ring backend either with Thunderbolt
|
||||||
|
connections or over Ethernet.
|
||||||
|
|
||||||
|
Configure MPI to use N tcp connections between each host to improve bandwidth
|
||||||
|
by passing ``--mca btl_tcp_links N``.
|
||||||
|
|
||||||
|
Force MPI to use the most performant network interface by setting ``--mca
|
||||||
|
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
|
||||||
|
to use.
|
||||||
|
|
||||||
|
Getting Started with Ring
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
The ring backend does not depend on any third party library so it is always
|
||||||
|
available. It uses TCP sockets so the nodes need to be reachable via a network.
|
||||||
|
As the name suggests the nodes are connected in a ring which means that rank 1
|
||||||
|
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
|
||||||
|
and so on and so forth. As a result :func:`send` and :func:`recv` with
|
||||||
|
arbitrary sender and receiver is not supported in the ring backend.
|
||||||
|
|
||||||
|
Defining a Ring
|
||||||
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
The easiest way to define and use a ring is via a JSON hostfile and the
|
||||||
|
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
|
||||||
|
defines a hostname to ssh into to run commands on this node and one or more IPs
|
||||||
|
that this node will listen to for connections.
|
||||||
|
|
||||||
|
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
|
||||||
|
rank 0, ``hostname2`` rank 1 etc.
|
||||||
|
|
||||||
|
.. code:: json
|
||||||
|
|
||||||
|
[
|
||||||
|
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
|
||||||
|
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
|
||||||
|
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
|
||||||
|
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
|
||||||
|
]
|
||||||
|
|
||||||
|
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
|
||||||
|
node, run the script which will listen for connections in each of the provided
|
||||||
|
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
|
||||||
|
connection from ``123.123.123.4`` and so on and so forth.
|
||||||
|
|
||||||
|
Thunderbolt Ring
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Although the ring backend can have benefits over MPI even for Ethernet, its
|
||||||
|
main purpose is to use Thunderbolt rings for higher bandwidth communication.
|
||||||
|
Setting up such thunderbolt rings can be done manually, but is a relatively
|
||||||
|
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
|
||||||
|
|
||||||
|
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
|
||||||
|
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
|
||||||
|
utility as follows:
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
|
||||||
|
|
||||||
|
By default the script will attempt to discover the thunderbolt ring and provide
|
||||||
|
you with the commands to configure each node as well as the ``hostfile.json``
|
||||||
|
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
|
||||||
|
then ``--auto-setup`` can be used to configure them automatically.
|
||||||
|
|
||||||
|
To validate your connection without configuring anything
|
||||||
|
``mlx.distributed_config`` can also plot the ring using DOT format.
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
|
||||||
|
dot -Tpng ring.dot >ring.png
|
||||||
|
open ring.png
|
||||||
|
|
||||||
|
If you want to go through the process manually, the steps are as follows:
|
||||||
|
|
||||||
|
* Disable the thunderbolt bridge interface
|
||||||
|
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
|
||||||
|
corresponding to that cable in nodes ``i`` and ``i + 1``.
|
||||||
|
* Set up a unique subnetwork connecting the two nodes for the corresponding
|
||||||
|
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
|
||||||
|
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
|
||||||
|
``192.168.0.2`` respectively to the two nodes. For more details you can see
|
||||||
|
the commands prepared by the utility script.
|
||||||
|
105
docs/src/usage/launching_distributed.rst
Normal file
105
docs/src/usage/launching_distributed.rst
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
:orphan:
|
||||||
|
|
||||||
|
.. _usage_launch_distributed:
|
||||||
|
|
||||||
|
Launching Distributed Programs
|
||||||
|
==============================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.distributed
|
||||||
|
|
||||||
|
Installing the MLX python package provides a helper script ``mlx.launch`` that
|
||||||
|
can be used to run python scripts distributed on several nodes. It allows
|
||||||
|
launching using either the MPI backend or the ring backend. See the
|
||||||
|
:doc:`distributed docs <distributed>` for the different backends.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
|
||||||
|
The minimal usage example of ``mlx.launch`` is simply
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
mlx.launch --hosts ip1,ip2 my_script.py
|
||||||
|
|
||||||
|
or for testing on localhost
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
mlx.launch -n 2 my_script.py
|
||||||
|
|
||||||
|
The ``mlx.launch`` command connects to the provided host and launches the input
|
||||||
|
script on each host. It monitors each of the launched processes and terminates
|
||||||
|
the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
|
||||||
|
It also takes care of forwarding the output of each remote process to stdout
|
||||||
|
and stderr respectively.
|
||||||
|
|
||||||
|
Providing Hosts
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Hosts can be provided as command line arguments, like above, but the way that
|
||||||
|
allows to fully define a list of hosts is via a JSON hostfile. The hostfile has
|
||||||
|
a very simple schema. It is simply a list of objects that define each host via
|
||||||
|
a hostname to ssh to and a list of IPs to utilize for the communication.
|
||||||
|
|
||||||
|
.. code:: json
|
||||||
|
|
||||||
|
[
|
||||||
|
{"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
|
||||||
|
{"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
|
||||||
|
]
|
||||||
|
|
||||||
|
You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
|
||||||
|
with IPs corresponding to the ``en0`` interface.
|
||||||
|
|
||||||
|
Setting up Remote Hosts
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
In order to be able to launch the script on each host we need to be able to
|
||||||
|
connect via ssh. Moreover the input script and python binary need to be on each
|
||||||
|
host and on the same path. A good checklist to debug errors is the following:
|
||||||
|
|
||||||
|
* ``ssh hostname`` works without asking for password or host confirmation
|
||||||
|
* the python binary is available on all hosts at the same path. You can use
|
||||||
|
``mlx.launch --print-python`` to see what that path is.
|
||||||
|
* the script you want to run is available on all hosts at the same path
|
||||||
|
|
||||||
|
.. _mpi_specifics:
|
||||||
|
|
||||||
|
MPI Specifics
|
||||||
|
-------------
|
||||||
|
|
||||||
|
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
|
||||||
|
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
|
||||||
|
|
||||||
|
* The IPs in the hostfile are ignored
|
||||||
|
* The ssh connectivity requirement is stronger as every node needs to be able
|
||||||
|
to connect to every other node
|
||||||
|
* ``mpirun`` needs to be available on every node at the same path
|
||||||
|
|
||||||
|
Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance
|
||||||
|
to choose a specific interface for the byte-transfer-layer of MPI we can call
|
||||||
|
``mlx.launch`` as follows:
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
|
||||||
|
|
||||||
|
|
||||||
|
.. _ring_specifics:
|
||||||
|
|
||||||
|
Ring Specifics
|
||||||
|
--------------
|
||||||
|
|
||||||
|
The ring backend, which is also the default backend, can be explicitly selected
|
||||||
|
with the argument ``--backend ring``. The ring backend has some specific
|
||||||
|
requirements and arguments that are different to MPI:
|
||||||
|
|
||||||
|
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
|
||||||
|
ssh to a hostname that does not correspond to the IP we want to bind to we
|
||||||
|
have to provide a hostfile.
|
||||||
|
* ``--starting-port`` defines the port to bind to on the remote hosts.
|
||||||
|
Specifically rank 0 for the first IP will use this port and each subsequent
|
||||||
|
IP or rank will add 1 to this port.
|
||||||
|
* ``--connections-per-ip`` allows us to increase the number of connections
|
||||||
|
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
|
||||||
|
``mpirun``.
|
@ -297,7 +297,7 @@ def launch_ring(parser, hosts, args, command):
|
|||||||
"The ring backend requires IPs to be provided instead of hostnames"
|
"The ring backend requires IPs to be provided instead of hostnames"
|
||||||
)
|
)
|
||||||
|
|
||||||
port = 5000
|
port = args.starting_port
|
||||||
ring_hosts = []
|
ring_hosts = []
|
||||||
for h in hosts:
|
for h in hosts:
|
||||||
node = []
|
node = []
|
||||||
@ -669,6 +669,11 @@ def distributed_config():
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
|
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
|
||||||
|
parser.add_argument(
|
||||||
|
"--print-python",
|
||||||
|
action="store_true",
|
||||||
|
help="Print the path to the current python executable and exit",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose", action="store_true", help="Print debug messages in stdout"
|
"--verbose", action="store_true", help="Print debug messages in stdout"
|
||||||
)
|
)
|
||||||
@ -707,11 +712,25 @@ def main():
|
|||||||
type=int,
|
type=int,
|
||||||
help="How many connections per ip to use for the ring backend",
|
help="How many connections per ip to use for the ring backend",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--starting-port",
|
||||||
|
"-p",
|
||||||
|
type=int,
|
||||||
|
default=5000,
|
||||||
|
help="For the ring backend listen on this port increasing by 1 per rank and IP",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cwd", help="Set the working directory on each node to the provided one"
|
"--cwd", help="Set the working directory on each node to the provided one"
|
||||||
)
|
)
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
|
|
||||||
|
if args.print_python:
|
||||||
|
print(sys.executable)
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(rest) == 0:
|
||||||
|
parser.error("No script is provided")
|
||||||
|
|
||||||
# Try to extract a list of hosts and corresponding ips
|
# Try to extract a list of hosts and corresponding ips
|
||||||
if args.hostfile is not None:
|
if args.hostfile is not None:
|
||||||
hosts = parse_hostfile(parser, args.hostfile)
|
hosts = parse_hostfile(parser, args.hostfile)
|
||||||
|
@ -2,4 +2,4 @@
|
|||||||
|
|
||||||
from mlx.nn import init, losses
|
from mlx.nn import init, losses
|
||||||
from mlx.nn.layers import *
|
from mlx.nn.layers import *
|
||||||
from mlx.nn.utils import value_and_grad
|
from mlx.nn.utils import average_gradients, value_and_grad
|
||||||
|
@ -68,19 +68,21 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
group = mx.distributed.init(backend="ring")
|
group = mx.distributed.init(backend="ring")
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
strict (bool, optional): If set to False it returns a singleton group
|
strict (bool, optional): If set to False it returns a singleton group
|
||||||
in case ``mx.distributed.is_available()`` returns False otherwise
|
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||||
it throws a runtime error. Default: ``False``
|
it throws a runtime error. Default: ``False``
|
||||||
backend (str, optional): Select a specific distributed backend to
|
backend (str, optional): Which distributed backend to initialize.
|
||||||
initialize. If set to ``any`` then try all available backends and
|
Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all
|
||||||
return the first one that succeeds. Subsequent calls will return
|
available backends are tried and the first one that succeeds
|
||||||
the first backend that was initialized. Default: ``any``
|
becomes the global group which will be returned in subsequent
|
||||||
|
calls. Default: ``any``
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Group: The group representing all the launched processes.
|
Group: The group representing all the launched processes.
|
||||||
|
Loading…
Reference in New Issue
Block a user