From 0163a8e57ac01bea2c2cce7b4734975d74ae5c77 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 6 Jun 2024 11:37:00 -0700 Subject: [PATCH] Add docs for the distributed namespace (#1184) --- docs/src/index.rst | 2 + docs/src/python/distributed.rst | 19 +++ docs/src/usage/distributed.rst | 166 +++++++++++++++++++++++++++ examples/cpp/distributed.cpp | 2 +- mlx/distributed/distributed.h | 2 +- mlx/distributed/mpi/mpi.cpp | 2 +- mlx/distributed/no_distributed.cpp | 2 +- mlx/distributed/ops.cpp | 2 +- mlx/distributed/ops.h | 2 +- mlx/distributed/primitives.cpp | 6 +- python/src/distributed.cpp | 6 +- python/tests/mpi_test_distributed.py | 6 +- 12 files changed, 202 insertions(+), 15 deletions(-) create mode 100644 docs/src/python/distributed.rst create mode 100644 docs/src/usage/distributed.rst diff --git a/docs/src/index.rst b/docs/src/index.rst index 33d652f6d..fd5147ca6 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -43,6 +43,7 @@ are the CPU and GPU. usage/function_transforms usage/compile usage/numpy + usage/distributed usage/using_streams .. toctree:: @@ -69,6 +70,7 @@ are the CPU and GPU. python/metal python/nn python/optimizers + python/distributed python/tree_utils .. toctree:: diff --git a/docs/src/python/distributed.rst b/docs/src/python/distributed.rst new file mode 100644 index 000000000..cf9eae3f1 --- /dev/null +++ b/docs/src/python/distributed.rst @@ -0,0 +1,19 @@ +.. _distributed: + +.. currentmodule:: mlx.core.distributed + +Distributed Communication +========================== + +MLX provides a distributed communication package using MPI. The MPI library is +loaded at runtime; if MPI is available then distributed communication is also +made available. + +.. autosummary:: + :toctree: _autosummary + + Group + is_available + init + all_sum + all_gather diff --git a/docs/src/usage/distributed.rst b/docs/src/usage/distributed.rst new file mode 100644 index 000000000..702951a0c --- /dev/null +++ b/docs/src/usage/distributed.rst @@ -0,0 +1,166 @@ +.. _usage_distributed: + +Distributed Communication +========================= + +.. currentmodule:: mlx.core.distributed + +MLX utilizes `MPI `_ to +provide distributed communication operations that allow the computational cost +of training or inference to be shared across many physical machines. You can +see a list of the supported operations in the :ref:`API docs`. + +.. note:: + A lot of operations may not be supported or not as fast as they should be. + We are adding more and tuning the ones we have as we are figuring out the + best way to do distributed computing on Macs using MLX. + +Getting Started +--------------- + +MLX already comes with the ability to "talk" to MPI if it is installed on the +machine. The minimal distributed program in MLX is as simple as: + +.. code:: python + + import mlx.core as mx + + world = mx.distributed.init() + x = mx.distributed.all_sum(mx.ones(10)) + print(world.rank(), x) + +The program above sums the array ``mx.ones(10)`` across all +distributed processes. If simply run with ``python``, however, only one +process is launched and no distributed communication takes place. + +To launch the program in distributed mode we need to use ``mpirun`` or +``mpiexec`` depending on the MPI installation. The simplest possible way is the +following: + +.. 
code:: shell
+
+  $ mpirun -np 2 python test.py
+  1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
+  0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
+
+The above launches two processes on the same (local) machine and we can see
+both standard output streams. The processes send the array of 1s to each other
+and compute the sum, which is printed. Launching with ``mpirun -np 4 ...``
+would print arrays of 4s, and so on.
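In addition to :func:`all_sum`, the :class:`Group` returned by :func:`init` exposes the
calling process's ``rank()`` and the total ``size()``, and :func:`is_available` reports
whether an MPI library was found at runtime. The following sketch (an illustration built
only from the API listed above, not code from this patch) uses them to print the result
from rank 0 only:

.. code:: python

  import mlx.core as mx

  if mx.distributed.is_available():
      world = mx.distributed.init()
      # Every process participates in the reduction ...
      x = mx.distributed.all_sum(mx.ones(10))
      mx.eval(x)
      # ... but only rank 0 reports the result.
      if world.rank() == 0:
          print("world size:", world.size())
          print(x)
  else:
      print("MPI not available, running single-process")

When launched with ``mpirun -np 2 python test.py`` as before, only one of the two
processes prints.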
+
+Installing MPI
+---------------
+
+MPI can be installed with Homebrew, with the Anaconda package manager, or
+compiled from source. Most of our testing is done using ``openmpi`` installed
+with the Anaconda package manager as follows:
+
+.. code:: shell
+
+  $ conda install openmpi
+
+Installing with Homebrew may require specifying the location of ``libmpi.dylib``
+so that MLX can find it and load it at runtime. This can simply be achieved by
+passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
+
+.. code:: shell
+
+  $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
+
+Setting up Remote Hosts
+-----------------------
+
+MPI can automatically connect to remote hosts and set up the communication over
+the network if the remote hosts can be accessed via ssh. A good checklist to
+debug connectivity issues is the following:
+
+* ``ssh hostname`` works from all machines to all machines without asking for
+  a password or host confirmation
+* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
+  full path to force all machines to use a specific path.
+* Ensure that the ``hostname`` used by MPI is the one that you have configured
+  in the ``.ssh/config`` files on all machines.
+
+.. note::
+  For example, for the hostname ``foo.bar.com``, MPI can use only ``foo`` as
+  the hostname passed to ssh if the current hostname matches ``*.bar.com``.
+
+An easy way to pass the host names to MPI is using a host file. A host file
+looks like the following, where ``host1`` and ``host2`` should be the fully
+qualified domain names or IPs for these hosts.
+
+.. code::
+
+  host1 slots=1
+  host2 slots=1
+
+When using MLX, it is very likely that you want to use 1 slot per host, i.e.,
+one process per host. The host file also needs to contain the current host if
+you want to run on the local host. Passing the host file to ``mpirun`` is
+simply done using the ``--hostfile`` command line argument.
+
+Training Example
+----------------
+
+In this section we will adapt an MLX training loop to support data parallel
+distributed training. Namely, we will average the gradients across a set of
+hosts before applying them to the model.
+
+Our training loop looks like the following code snippet if we omit the model,
+dataset and optimizer initialization.
+
+.. code:: python
+
+  model = ...
+  optimizer = ...
+  dataset = ...
+
+  def step(model, x, y):
+      loss, grads = loss_grad_fn(model, x, y)
+      optimizer.update(model, grads)
+      return loss
+
+  for x, y in dataset:
+      loss = step(model, x, y)
+      mx.eval(loss, model.parameters())
+
+All we have to do to average the gradients across machines is perform an
+:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we
+have to :func:`mlx.utils.tree_map` the gradients with the following function.
+
+.. code:: python
+
+  def all_avg(x):
+      return mx.distributed.all_sum(x) / mx.distributed.init().size()
+
+Putting everything together, our training loop step looks as follows, with
+everything else remaining the same.
+
+.. code:: python
+
+  from mlx.utils import tree_map
+
+  def all_reduce_grads(grads):
+      N = mx.distributed.init().size()
+      if N == 1:
+          return grads
+      return tree_map(
+          lambda x: mx.distributed.all_sum(x) / N,
+          grads)
+
+  def step(model, x, y):
+      loss, grads = loss_grad_fn(model, x, y)
+      grads = all_reduce_grads(grads)  # <--- This line was added
+      optimizer.update(model, grads)
+      return loss
+
+Tuning All Reduce
+-----------------
+
+We are working on improving the performance of all reduce on MLX, but for now
+the two main things you can do to get the most out of distributed training
+with MLX are:
+
+1. Perform a few large reductions instead of many small ones to improve
+   bandwidth and latency.
+2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 TCP
+   connections between each host to improve bandwidth.
diff --git a/examples/cpp/distributed.cpp b/examples/cpp/distributed.cpp
index 283751a62..14229c1a2 100644
--- a/examples/cpp/distributed.cpp
+++ b/examples/cpp/distributed.cpp
@@ -16,7 +16,7 @@ int main() {
   std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
   array x = ones({10});
-  array out = distributed::all_reduce_sum(x, global_group);
+  array out = distributed::all_sum(x, global_group);
   std::cout << out << std::endl;
 }
diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h
index e2bd0771c..44d40bc73 100644
--- a/mlx/distributed/distributed.h
+++ b/mlx/distributed/distributed.h
@@ -56,7 +56,7 @@ namespace detail {
 Stream communication_stream();
 /* Perform an all reduce sum operation */
-void all_reduce_sum(Group group, const array& input, array& output);
+void all_sum(Group group, const array& input, array& output);
 /* Perform an all reduce sum operation */
 void all_gather(Group group, const array& input, array& output);
diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp
index 8e9b6caa3..4ea4d8573 100644
--- a/mlx/distributed/mpi/mpi.cpp
+++ b/mlx/distributed/mpi/mpi.cpp
@@ -260,7 +260,7 @@ Stream communication_stream() {
   return comm_stream;
 }
-void all_reduce_sum(Group group, const array& input_, array& output) {
+void all_sum(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_reduce(
       (input.data() == output.data()) ? 
MPI_IN_PLACE diff --git a/mlx/distributed/no_distributed.cpp b/mlx/distributed/no_distributed.cpp index df889d9df..fcf346ad8 100644 --- a/mlx/distributed/no_distributed.cpp +++ b/mlx/distributed/no_distributed.cpp @@ -31,7 +31,7 @@ Stream communication_stream() { return comm_stream; } -void all_reduce_sum(Group group, const array& input, array& output) {} +void all_sum(Group group, const array& input, array& output) {} void all_gather(Group group, const array& input, array& output) {} } // namespace detail diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp index 69cf196cb..67f03cd31 100644 --- a/mlx/distributed/ops.cpp +++ b/mlx/distributed/ops.cpp @@ -17,7 +17,7 @@ Group to_group(std::optional group) { } // namespace -array all_reduce_sum(const array& x, std::optional group_) { +array all_sum(const array& x, std::optional group_) { auto group = to_group(group_); if (group.size() == 1) { diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h index 1afe8dcc8..bc3fab08d 100644 --- a/mlx/distributed/ops.h +++ b/mlx/distributed/ops.h @@ -8,7 +8,7 @@ namespace mlx::core::distributed { -array all_reduce_sum(const array& x, std::optional group = std::nullopt); +array all_sum(const array& x, std::optional group = std::nullopt); array all_gather(const array& x, std::optional group = std::nullopt); } // namespace mlx::core::distributed diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp index 91e230b6c..c4b786787 100644 --- a/mlx/distributed/primitives.cpp +++ b/mlx/distributed/primitives.cpp @@ -24,7 +24,7 @@ void AllReduce::eval_cpu( switch (reduce_type_) { case Sum: - distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]); + distributed::detail::all_sum(group(), inputs[0], outputs[0]); break; default: throw std::runtime_error("Only all reduce sum is supported for now"); @@ -36,7 +36,7 @@ std::pair, std::vector> AllReduce::vmap( const std::vector& axes) { switch (reduce_type_) { case Sum: - return {{all_reduce_sum(inputs[0], group())}, axes}; + return {{all_sum(inputs[0], group())}, axes}; default: throw std::runtime_error("Only all reduce sum is supported for now"); } @@ -48,7 +48,7 @@ std::vector AllReduce::jvp( const std::vector& argnums) { switch (reduce_type_) { case Sum: - return {all_reduce_sum(tangents[0], group())}; + return {all_sum(tangents[0], group())}; default: throw std::runtime_error("Only all reduce sum is supported for now"); } diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index 5b01ffed3..e0a11a4fc 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -69,13 +69,13 @@ void init_distributed(nb::module_& parent_module) { )pbdoc"); m.def( - "all_reduce_sum", - &distributed::all_reduce_sum, + "all_sum", + &distributed::all_sum, "x"_a, nb::kw_only(), "group"_a = nb::none(), nb::sig( - "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"), + "def all_sum(x: array, *, group: Optional[Group] = None) -> array"), R"pbdoc( All reduce sum. 
diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py index 3e0504b76..6c6e96009 100644 --- a/python/tests/mpi_test_distributed.py +++ b/python/tests/mpi_test_distributed.py @@ -37,13 +37,13 @@ class TestDistributed(mlx_tests.MLXTestCase): ] for dt in dtypes: x = mx.ones((2, 2, 4), dtype=dt) - y = mx.distributed.all_reduce_sum(x) + y = mx.distributed.all_sum(x) self.assertTrue(mx.all(y == world.size())) sub = world.split(world.rank() % 2) for dt in dtypes: x = mx.ones((2, 2, 4), dtype=dt) - y = mx.distributed.all_reduce_sum(x, group=sub) + y = mx.distributed.all_sum(x, group=sub) self.assertTrue(mx.all(y == sub.size())) def test_all_gather(self): @@ -87,7 +87,7 @@ class TestDistributed(mlx_tests.MLXTestCase): sub_2 = world.split(world.rank() % 2) x = mx.ones((1, 8)) * world.rank() - y = mx.distributed.all_reduce_sum(x, group=sub_1) + y = mx.distributed.all_sum(x, group=sub_1) z = mx.distributed.all_gather(y, group=sub_2) z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)
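The updated tests exercise :func:`all_sum` both on the global group and on subgroups
created with ``Group.split``. As a rough sketch of that pattern (an illustration, not
part of this patch), splitting on ``rank() % 2`` produces two subgroups and the
reduction then only spans the processes in each subgroup:

.. code:: python

  import mlx.core as mx

  world = mx.distributed.init()

  # Processes that pass the same value to split() end up in the same subgroup,
  # so splitting on rank() % 2 places even and odd ranks in separate groups.
  sub = world.split(world.rank() % 2)

  # The reduction only spans the processes in `sub`, so every element of the
  # result equals sub.size() rather than world.size().
  x = mx.distributed.all_sum(mx.ones((2, 2, 4)), group=sub)
  print(world.rank(), sub.size(), x)

Launched with four processes, every element of ``x`` is 2 rather than 4.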