Add docs for the distributed namespace (#1184)
commit 0163a8e57a (parent 578842954c)

@@ -43,6 +43,7 @@ are the CPU and GPU.
    usage/function_transforms
    usage/compile
    usage/numpy
+   usage/distributed
    usage/using_streams
 
 .. toctree::

@@ -69,6 +70,7 @@ are the CPU and GPU.
    python/metal
    python/nn
    python/optimizers
+   python/distributed
    python/tree_utils
 
 .. toctree::

docs/src/python/distributed.rst (new file):

.. _distributed:

.. currentmodule:: mlx.core.distributed

Distributed Communication
=========================

MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available, then distributed communication is also
made available.
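
As a quick sanity check (a minimal sketch, not itself part of the listed API
entries), availability can be probed at runtime before initializing the global
group:

.. code:: python

   import mlx.core as mx

   # True only if the MPI library was found and loaded at runtime.
   if mx.distributed.is_available():
       world = mx.distributed.init()
       print(world.rank(), "/", world.size())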

.. autosummary::
   :toctree: _autosummary

   Group
   is_available
   init
   all_sum
   all_gather

docs/src/usage/distributed.rst (new file):

.. _usage_distributed:

Distributed Communication
=========================

.. currentmodule:: mlx.core.distributed

MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.

.. note::
   Some operations may not be supported yet, or may not be as fast as they
   should be. We are adding more operations and tuning the existing ones as we
   figure out the best way to do distributed computing on Macs using MLX.

Getting Started
---------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:

.. code:: python

   import mlx.core as mx

   world = mx.distributed.init()
   x = mx.distributed.all_sum(mx.ones(10))
   print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all distributed
processes. If it is simply run with ``python``, however, only one process is
launched and no distributed communication takes place.

To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec``, depending on the MPI installation. The simplest way is the
following:

.. code:: shell

   $ mpirun -np 2 python test.py
   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum, which is printed. Launching with ``mpirun -np 4 ...``
would print 4, and so on.

Installing MPI
--------------

MPI can be installed with Homebrew, with the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

   $ conda install openmpi
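
If you prefer Homebrew, the equivalent install is typically the following
sketch; the formula name ``open-mpi`` is an assumption and can be confirmed
with ``brew search mpi``:

.. code:: shell

   $ brew install open-mpi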

Installing with Homebrew may require specifying the location of ``libmpi.dylib``
so that MLX can find it and load it at runtime. This can be achieved simply by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.

.. code:: shell

   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py

Setting up Remote Hosts
-----------------------

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist for
debugging connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for a
  password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
  full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines (see the example entry after the
  note below).

.. note::
   For an example hostname ``foo.bar.com``, MPI can use only ``foo`` as
   the hostname passed to ssh if the current hostname matches ``*.bar.com``.
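
For reference, a matching ``.ssh/config`` entry might look like the following
sketch; the alias, fully qualified name, user and key path are placeholders to
adapt to your setup.

.. code::

   Host host1
       HostName host1.example.com
       User mluser
       IdentityFile ~/.ssh/id_ed25519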

An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs of these hosts.

.. code::

   host1 slots=1
   host2 slots=1

When using MLX, it is very likely that you want to use one slot per host, i.e.
one process per host. The host file also needs to contain the current host if
you want to run on the local host. Passing the host file to ``mpirun`` is done
simply with the ``--hostfile`` command line argument, as shown below.
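
For example, assuming the host file above is saved as ``hosts.txt`` (a
placeholder name), a two-process launch across the two hosts could look like:

.. code:: shell

   $ mpirun --hostfile hosts.txt -np 2 python test.py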

Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())

All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we have
to :func:`mlx.utils.tree_map` the gradients with the following function.

.. code:: python

   def all_avg(x):
       return mx.distributed.all_sum(x) / mx.distributed.init().size()

Putting everything together, our training loop step looks as follows, with
everything else remaining the same.

.. code:: python

   from mlx.utils import tree_map

   def all_reduce_grads(grads):
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       return tree_map(
           lambda x: mx.distributed.all_sum(x) / N,
           grads)

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = all_reduce_grads(grads)  # <--- This line was added
       optimizer.update(model, grads)
       return loss

Tuning All Reduce
-----------------

We are working on improving the performance of all reduce on MLX, but for now
the two main things one can do to get the most out of distributed training with
MLX are:

1. Perform a few large reductions instead of many small ones to improve
   bandwidth and latency (see the sketch after this list).
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 TCP
   connections between each host to improve bandwidth.
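
As a rough illustration of the first point, the helper below (a sketch, not
part of MLX; the name ``fused_all_avg`` and the flatten-and-slice strategy are
assumptions) issues a single :func:`all_sum` over a concatenated buffer instead
of one call per array.

.. code:: python

   import mlx.core as mx

   def fused_all_avg(arrays, group):
       # Remember each array's shape and size so the reduced, flattened
       # buffer can be unflattened afterwards.
       shapes = [a.shape for a in arrays]
       sizes = [a.size for a in arrays]
       # One large reduction over the concatenated 1-D arrays.
       flat = mx.concatenate([a.reshape(-1) for a in arrays])
       flat = mx.distributed.all_sum(flat, group=group) / group.size()
       # Slice the reduced buffer back into the original shapes.
       outs, offset = [], 0
       for shape, size in zip(shapes, sizes):
           outs.append(flat[offset:offset + size].reshape(shape))
           offset += size
       return outs

In the training example above, such a helper could be applied to the flattened
gradient values (for instance via :func:`mlx.utils.tree_flatten`) before the
optimizer update, at the cost of some bookkeeping.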

@@ -16,7 +16,7 @@ int main() {
   std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
 
   array x = ones({10});
-  array out = distributed::all_reduce_sum(x, global_group);
+  array out = distributed::all_sum(x, global_group);
 
   std::cout << out << std::endl;
 }

@@ -56,7 +56,7 @@ namespace detail {
 Stream communication_stream();
 
 /* Perform an all reduce sum operation */
-void all_reduce_sum(Group group, const array& input, array& output);
+void all_sum(Group group, const array& input, array& output);
 
 /* Perform an all reduce sum operation */
 void all_gather(Group group, const array& input, array& output);

@@ -260,7 +260,7 @@ Stream communication_stream() {
   return comm_stream;
 }
 
-void all_reduce_sum(Group group, const array& input_, array& output) {
+void all_sum(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_reduce(
       (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE

@@ -31,7 +31,7 @@ Stream communication_stream() {
   return comm_stream;
 }
 
-void all_reduce_sum(Group group, const array& input, array& output) {}
+void all_sum(Group group, const array& input, array& output) {}
 void all_gather(Group group, const array& input, array& output) {}
 
 } // namespace detail

@@ -17,7 +17,7 @@ Group to_group(std::optional<Group> group) {
 
 } // namespace
 
-array all_reduce_sum(const array& x, std::optional<Group> group_) {
+array all_sum(const array& x, std::optional<Group> group_) {
   auto group = to_group(group_);
 
   if (group.size() == 1) {

@@ -8,7 +8,7 @@
 
 namespace mlx::core::distributed {
 
-array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
+array all_sum(const array& x, std::optional<Group> group = std::nullopt);
 array all_gather(const array& x, std::optional<Group> group = std::nullopt);
 
 } // namespace mlx::core::distributed

@@ -24,7 +24,7 @@ void AllReduce::eval_cpu(
 
   switch (reduce_type_) {
     case Sum:
-      distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
+      distributed::detail::all_sum(group(), inputs[0], outputs[0]);
       break;
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");

@@ -36,7 +36,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
     const std::vector<int>& axes) {
   switch (reduce_type_) {
     case Sum:
-      return {{all_reduce_sum(inputs[0], group())}, axes};
+      return {{all_sum(inputs[0], group())}, axes};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }

@@ -48,7 +48,7 @@ std::vector<array> AllReduce::jvp(
     const std::vector<int>& argnums) {
   switch (reduce_type_) {
     case Sum:
-      return {all_reduce_sum(tangents[0], group())};
+      return {all_sum(tangents[0], group())};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }

@@ -69,13 +69,13 @@ void init_distributed(nb::module_& parent_module) {
       )pbdoc");
 
   m.def(
-      "all_reduce_sum",
-      &distributed::all_reduce_sum,
+      "all_sum",
+      &distributed::all_sum,
       "x"_a,
       nb::kw_only(),
       "group"_a = nb::none(),
       nb::sig(
-          "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_sum(x: array, *, group: Optional[Group] = None) -> array"),
       R"pbdoc(
         All reduce sum.
 

@@ -37,13 +37,13 @@ class TestDistributed(mlx_tests.MLXTestCase):
         ]
         for dt in dtypes:
             x = mx.ones((2, 2, 4), dtype=dt)
-            y = mx.distributed.all_reduce_sum(x)
+            y = mx.distributed.all_sum(x)
             self.assertTrue(mx.all(y == world.size()))
 
         sub = world.split(world.rank() % 2)
         for dt in dtypes:
             x = mx.ones((2, 2, 4), dtype=dt)
-            y = mx.distributed.all_reduce_sum(x, group=sub)
+            y = mx.distributed.all_sum(x, group=sub)
             self.assertTrue(mx.all(y == sub.size()))
 
     def test_all_gather(self):

@@ -87,7 +87,7 @@ class TestDistributed(mlx_tests.MLXTestCase):
         sub_2 = world.split(world.rank() % 2)
 
         x = mx.ones((1, 8)) * world.rank()
-        y = mx.distributed.all_reduce_sum(x, group=sub_1)
+        y = mx.distributed.all_sum(x, group=sub_1)
         z = mx.distributed.all_gather(y, group=sub_2)
         z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)
 