diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst index 229d295cb..5a7344d9f 100644 --- a/docs/src/python/nn.rst +++ b/docs/src/python/nn.rst @@ -174,6 +174,7 @@ In detail: value_and_grad quantize + average_gradients .. toctree:: diff --git a/docs/src/usage/distributed.rst b/docs/src/usage/distributed.rst index ee0952194..493819c04 100644 --- a/docs/src/usage/distributed.rst +++ b/docs/src/usage/distributed.rst @@ -5,21 +5,27 @@ Distributed Communication .. currentmodule:: mlx.core.distributed -MLX utilizes `MPI `_ to -provide distributed communication operations that allow the computational cost -of training or inference to be shared across many physical machines. You can -see a list of the supported operations in the :ref:`API docs`. +MLX supports distributed communication operations that allow the computational cost +of training or inference to be shared across many physical machines. At the +moment we support two different communication backends: + +* `MPI `_ a + full-featured and mature distributed communications library +* A **ring** backend of our own that uses native TCP sockets and should be + faster for thunderbolt connections. + +The list of all currently supported operations and their documentation can be +seen in the :ref:`API docs`. .. note:: - A lot of operations may not be supported or not as fast as they should be. + Some operations may not be supported or not as fast as they should be. We are adding more and tuning the ones we have as we are figuring out the best way to do distributed computing on Macs using MLX. Getting Started --------------- -MLX already comes with the ability to "talk" to MPI if it is installed on the -machine. The minimal distributed program in MLX is as simple as: +A distributed program in MLX is as simple as: .. code:: python @@ -30,74 +36,79 @@ machine. The minimal distributed program in MLX is as simple as: print(world.rank(), x) The program above sums the array ``mx.ones(10)`` across all -distributed processes. If simply run with ``python``, however, only one -process is launched and no distributed communication takes place. +distributed processes. However, when this script is run with ``python`` only +one process is launched and no distributed communication takes place. Namely, +all operations in ``mx.distributed`` are noops when the distributed group has a +size of one. This property allows us to avoid code that checks if we are in a +distributed setting similar to the one below: -To launch the program in distributed mode we need to use ``mpirun`` or -``mpiexec`` depending on the MPI installation. The simplest possible way is the -following: +.. code:: python + + import mlx.core as mx + + x = ... + world = mx.distributed.init() + # No need for the check we can simply do x = mx.distributed.all_sum(x) + if world.size() > 1: + x = mx.distributed.all_sum(x) + +Running Distributed Programs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +MLX provides ``mlx.launch`` a helper script to launch distributed programs. +Continuing with our initial example we can run it on localhost with 4 processes using .. code:: shell - $ mpirun -np 2 python test.py - 1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32) - 0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32) + $ mlx.launch -n 4 my_script.py + 3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) + 2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) + 1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) + 0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) -The above launches two processes on the same (local) machine and we can see -both standard output streams. The processes send the array of 1s to each other -and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would -print 4 etc. - -Installing MPI ---------------- - -MPI can be installed with Homebrew, using the Anaconda package manager or -compiled from source. Most of our testing is done using ``openmpi`` installed -with the Anaconda package manager as follows: +We can also run it on some remote hosts by providing their IPs (provided that +the script exists on all hosts and they are reachable by ssh) .. code:: shell - $ conda install conda-forge::openmpi + $ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py + 3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) + 2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) + 1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) + 0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) -Installing with Homebrew may require specifying the location of ``libmpi.dyld`` -so that MLX can find it and load it at runtime. This can simply be achieved by -passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``. +Consult the dedicated :doc:`usage guide` for more +information on using ``mlx.launch``. -.. code:: shell +Selecting Backend +^^^^^^^^^^^^^^^^^ - $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py - -Setting up Remote Hosts ------------------------ - -MPI can automatically connect to remote hosts and set up the communication over -the network if the remote hosts can be accessed via ssh. A good checklist to -debug connectivity issues is the following: - -* ``ssh hostname`` works from all machines to all machines without asking for - password or host confirmation -* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its - full path to force all machines to use a specific path. -* Ensure that the ``hostname`` used by MPI is the one that you have configured - in the ``.ssh/config`` files on all machines. +You can select the backend you want to use when calling :func:`init` by passing +one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to +initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they +both fail then a singleton group is created. .. note:: - For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as - the hostname passed to ssh if the current hostname matches ``*.bar.com``. + After a distributed backend is successfully initialized :func:`init` will + return **the same backend** if called without arguments or with backend set to + ``any``. -An easy way to pass the host names to MPI is using a host file. A host file -looks like the following, where ``host1`` and ``host2`` should be the fully -qualified domain names or IPs for these hosts. +The following examples aim to clarify the backend initialization logic in MLX: -.. code:: +.. code:: python - host1 slots=1 - host2 slots=1 + # Case 1: Initialize MPI regardless if it was possible to initialize the ring backend + world = mx.distributed.init(backend="mpi") + world2 = mx.distributed.init() # subsequent calls return the MPI backend! -When using MLX, it is very likely that you want to use 1 slot per host, ie one -process per host. The hostfile also needs to contain the current -host if you want to run on the local host. Passing the host file to -``mpirun`` is simply done using the ``--hostfile`` command line argument. + # Case 2: Initialize any backend + world = mx.distributed.init(backend="any") # equivalent to no arguments + world2 = mx.distributed.init() # same as above + + # Case 3: Initialize both backends at the same time + world_mpi = mx.distributed.init(backend="mpi") + world_ring = mx.distributed.init(backend="ring") + world_any = mx.distributed.init() # same as MPI because it was initialized first! Training Example ---------------- @@ -155,13 +166,179 @@ everything else remaining the same. optimizer.update(model, grads) return loss -Tuning All Reduce ------------------ +Utilizing ``nn.average_gradients`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -We are working on improving the performance of all reduce on MLX but for now -the two main things one can do to extract the most out of distributed training with MLX are: +Although the code example above works correctly; it performs one communication +per gradient. It is significantly more efficient to aggregate several gradients +together and perform fewer communication steps. -1. Perform a few large reductions instead of many small ones to improve - bandwidth and latency -2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp - connections between each host to improve bandwidth +This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks +almost identical to the example above: + +.. code:: python + + model = ... + optimizer = ... + dataset = ... + + def step(model, x, y): + loss, grads = loss_grad_fn(model, x, y) + grads = mlx.nn.average_gradients(grads) # <---- This line was added + optimizer.update(model, grads) + return loss + + for x, y in dataset: + loss = step(model, x, y) + mx.eval(loss, model.parameters()) + + +Getting Started with MPI +------------------------ + +MLX already comes with the ability to "talk" to MPI if it is installed on the +machine. Launching distributed MLX programs that use MPI can be done with +``mpirun`` as expected. However, in the following examples we will be using +``mlx.launch --backend mpi`` which takes care of some nuisances such as setting +absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared +library. + +The simplest possible usage is the following which, assuming the minimal +example in the beginning of this page, should result in: + +.. code:: shell + + $ mlx.launch --backend mpi -n 2 test.py + 1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32) + 0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32) + +The above launches two processes on the same (local) machine and we can see +both standard output streams. The processes send the array of 1s to each other +and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would +print 4 etc. + +Installing MPI +^^^^^^^^^^^^^^ + +MPI can be installed with Homebrew, using the Anaconda package manager or +compiled from source. Most of our testing is done using ``openmpi`` installed +with the Anaconda package manager as follows: + +.. code:: shell + + $ conda install conda-forge::openmpi + +Installing with Homebrew may require specifying the location of ``libmpi.dyld`` +so that MLX can find it and load it at runtime. This can simply be achieved by +passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is +done automatically by ``mlx.launch``. + +.. code:: shell + + $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py + $ # or simply + $ mlx.launch -n 2 test.py + +Setting up Remote Hosts +^^^^^^^^^^^^^^^^^^^^^^^ + +MPI can automatically connect to remote hosts and set up the communication over +the network if the remote hosts can be accessed via ssh. A good checklist to +debug connectivity issues is the following: + +* ``ssh hostname`` works from all machines to all machines without asking for + password or host confirmation +* ``mpirun`` is accessible on all machines. +* Ensure that the ``hostname`` used by MPI is the one that you have configured + in the ``.ssh/config`` files on all machines. + +Tuning MPI All Reduce +^^^^^^^^^^^^^^^^^^^^^ + +.. note:: + + For faster all reduce consider using the ring backend either with Thunderbolt + connections or over Ethernet. + +Configure MPI to use N tcp connections between each host to improve bandwidth +by passing ``--mca btl_tcp_links N``. + +Force MPI to use the most performant network interface by setting ``--mca +btl_tcp_if_include `` where ```` should be the interface you want +to use. + +Getting Started with Ring +------------------------- + +The ring backend does not depend on any third party library so it is always +available. It uses TCP sockets so the nodes need to be reachable via a network. +As the name suggests the nodes are connected in a ring which means that rank 1 +can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3 +and so on and so forth. As a result :func:`send` and :func:`recv` with +arbitrary sender and receiver is not supported in the ring backend. + +Defining a Ring +^^^^^^^^^^^^^^^ + +The easiest way to define and use a ring is via a JSON hostfile and the +``mlx.launch`` :doc:`helper script `. For each node one +defines a hostname to ssh into to run commands on this node and one or more IPs +that this node will listen to for connections. + +For example the hostfile below defines a 4 node ring. ``hostname1`` will be +rank 0, ``hostname2`` rank 1 etc. + +.. code:: json + + [ + {"ssh": "hostname1", "ips": ["123.123.123.1"]}, + {"ssh": "hostname2", "ips": ["123.123.123.2"]}, + {"ssh": "hostname3", "ips": ["123.123.123.3"]}, + {"ssh": "hostname4", "ips": ["123.123.123.4"]} + ] + +Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each +node, run the script which will listen for connections in each of the provided +IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a +connection from ``123.123.123.4`` and so on and so forth. + +Thunderbolt Ring +^^^^^^^^^^^^^^^^ + +Although the ring backend can have benefits over MPI even for Ethernet, its +main purpose is to use Thunderbolt rings for higher bandwidth communication. +Setting up such thunderbolt rings can be done manually, but is a relatively +tedious process. To simplify this, we provide the utility ``mlx.distributed_config``. + +To use ``mlx.distributed_config`` your computers need to be accessible by ssh via +Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the +utility as follows: + +.. code:: shell + + mlx.distributed_config --verbose --hosts host1,host2,host3,host4 + +By default the script will attempt to discover the thunderbolt ring and provide +you with the commands to configure each node as well as the ``hostfile.json`` +to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes +then ``--auto-setup`` can be used to configure them automatically. + +To validate your connection without configuring anything +``mlx.distributed_config`` can also plot the ring using DOT format. + +.. code:: shell + + mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot + dot -Tpng ring.dot >ring.png + open ring.png + +If you want to go through the process manually, the steps are as follows: + +* Disable the thunderbolt bridge interface +* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces + corresponding to that cable in nodes ``i`` and ``i + 1``. +* Set up a unique subnetwork connecting the two nodes for the corresponding + interfaces. For instance if the cable corresponds to ``en2`` on node ``i`` + and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and + ``192.168.0.2`` respectively to the two nodes. For more details you can see + the commands prepared by the utility script. diff --git a/docs/src/usage/launching_distributed.rst b/docs/src/usage/launching_distributed.rst new file mode 100644 index 000000000..2e956a486 --- /dev/null +++ b/docs/src/usage/launching_distributed.rst @@ -0,0 +1,105 @@ +:orphan: + +.. _usage_launch_distributed: + +Launching Distributed Programs +============================== + +.. currentmodule:: mlx.core.distributed + +Installing the MLX python package provides a helper script ``mlx.launch`` that +can be used to run python scripts distributed on several nodes. It allows +launching using either the MPI backend or the ring backend. See the +:doc:`distributed docs ` for the different backends. + +Usage +----- + +The minimal usage example of ``mlx.launch`` is simply + +.. code:: shell + + mlx.launch --hosts ip1,ip2 my_script.py + +or for testing on localhost + +.. code:: shell + + mlx.launch -n 2 my_script.py + +The ``mlx.launch`` command connects to the provided host and launches the input +script on each host. It monitors each of the launched processes and terminates +the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated. +It also takes care of forwarding the output of each remote process to stdout +and stderr respectively. + +Providing Hosts +^^^^^^^^^^^^^^^^ + +Hosts can be provided as command line arguments, like above, but the way that +allows to fully define a list of hosts is via a JSON hostfile. The hostfile has +a very simple schema. It is simply a list of objects that define each host via +a hostname to ssh to and a list of IPs to utilize for the communication. + +.. code:: json + + [ + {"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]}, + {"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]} + ] + +You can use ``mlx.distributed_config --over ethernet`` to create a hostfile +with IPs corresponding to the ``en0`` interface. + +Setting up Remote Hosts +^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to be able to launch the script on each host we need to be able to +connect via ssh. Moreover the input script and python binary need to be on each +host and on the same path. A good checklist to debug errors is the following: + +* ``ssh hostname`` works without asking for password or host confirmation +* the python binary is available on all hosts at the same path. You can use + ``mlx.launch --print-python`` to see what that path is. +* the script you want to run is available on all hosts at the same path + +.. _mpi_specifics: + +MPI Specifics +------------- + +One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case, +``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover, + +* The IPs in the hostfile are ignored +* The ssh connectivity requirement is stronger as every node needs to be able + to connect to every other node +* ``mpirun`` needs to be available on every node at the same path + +Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance +to choose a specific interface for the byte-transfer-layer of MPI we can call +``mlx.launch`` as follows: + +.. code:: shell + + mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py + + +.. _ring_specifics: + +Ring Specifics +-------------- + +The ring backend, which is also the default backend, can be explicitly selected +with the argument ``--backend ring``. The ring backend has some specific +requirements and arguments that are different to MPI: + +* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to + ssh to a hostname that does not correspond to the IP we want to bind to we + have to provide a hostfile. +* ``--starting-port`` defines the port to bind to on the remote hosts. + Specifically rank 0 for the first IP will use this port and each subsequent + IP or rank will add 1 to this port. +* ``--connections-per-ip`` allows us to increase the number of connections + between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for + ``mpirun``. diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 7f45faf69..2c1d27c6e 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -297,7 +297,7 @@ def launch_ring(parser, hosts, args, command): "The ring backend requires IPs to be provided instead of hostnames" ) - port = 5000 + port = args.starting_port ring_hosts = [] for h in hosts: node = [] @@ -669,6 +669,11 @@ def distributed_config(): def main(): parser = argparse.ArgumentParser(description="Launch an MLX distributed program") + parser.add_argument( + "--print-python", + action="store_true", + help="Print the path to the current python executable and exit", + ) parser.add_argument( "--verbose", action="store_true", help="Print debug messages in stdout" ) @@ -707,11 +712,25 @@ def main(): type=int, help="How many connections per ip to use for the ring backend", ) + parser.add_argument( + "--starting-port", + "-p", + type=int, + default=5000, + help="For the ring backend listen on this port increasing by 1 per rank and IP", + ) parser.add_argument( "--cwd", help="Set the working directory on each node to the provided one" ) args, rest = parser.parse_known_args() + if args.print_python: + print(sys.executable) + return + + if len(rest) == 0: + parser.error("No script is provided") + # Try to extract a list of hosts and corresponding ips if args.hostfile is not None: hosts = parse_hostfile(parser, args.hostfile) diff --git a/python/mlx/nn/__init__.py b/python/mlx/nn/__init__.py index b2cb9e0f4..df72e2a94 100644 --- a/python/mlx/nn/__init__.py +++ b/python/mlx/nn/__init__.py @@ -2,4 +2,4 @@ from mlx.nn import init, losses from mlx.nn.layers import * -from mlx.nn.utils import value_and_grad +from mlx.nn.utils import average_gradients, value_and_grad diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index b4b6550af..ff24b8a95 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -68,19 +68,21 @@ void init_distributed(nb::module_& parent_module) { Example: - import mlx.core as mx + .. code:: python - group = mx.distributed.init(backend="ring") + import mlx.core as mx + group = mx.distributed.init(backend="ring") Args: strict (bool, optional): If set to False it returns a singleton group in case ``mx.distributed.is_available()`` returns False otherwise it throws a runtime error. Default: ``False`` - backend (str, optional): Select a specific distributed backend to - initialize. If set to ``any`` then try all available backends and - return the first one that succeeds. Subsequent calls will return - the first backend that was initialized. Default: ``any`` + backend (str, optional): Which distributed backend to initialize. + Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all + available backends are tried and the first one that succeeds + becomes the global group which will be returned in subsequent + calls. Default: ``any`` Returns: Group: The group representing all the launched processes.