Nccl reduce scatter, all gather (#2727)

* Added reduce scatter and all gather for nccl

* fix unused import, delete unused file

* small fix

* deleted useless condition

* fixed comments

* fix bug in eval_gpu, renamed to sum_scatter, fix docs

* final fix docs

* remove and

* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output

* fixes set output

* typo

* fix typo

* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Anastasiia Filippova
2025-11-05 17:21:11 +01:00
committed by GitHub
parent 761f901a41
commit 27778156dc
19 changed files with 351 additions and 311 deletions

View File

@@ -95,4 +95,9 @@ void Recv::eval_cpu(
distributed::detail::recv(group(), outputs[0], src_, stream());
}
void ReduceScatter::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
}
} // namespace mlx::core::distributed

View File

@@ -53,4 +53,69 @@ void AllReduce::eval_gpu(
"Only all reduce sum, max, and min are supported.");
}
}
void AllGather::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().row_contiguous) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
auto capture = encoder.capture_context();
distributed::detail::all_gather(group(), input, outputs[0], s);
}
void ReduceScatter::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().row_contiguous) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
auto capture = encoder.capture_context();
switch (reduce_type_) {
case Sum:
distributed::detail::sum_scatter(group(), input, outputs[0], s);
break;
default:
throw std::runtime_error("Only sum scatter is supported. ");
}
}
} // namespace mlx::core::distributed

View File

@@ -40,7 +40,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace distributed {
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed

View File

@@ -30,4 +30,9 @@ void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
}
void ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {
throw std::runtime_error(
"[ReduceScatter::eval_gpu] has no GPU implementation.");
}
} // namespace mlx::core::distributed

View File

@@ -138,6 +138,7 @@ NO_CPU_MULTI(AllReduce)
NO_CPU_MULTI(AllGather)
NO_CPU_MULTI(Send)
NO_CPU_MULTI(Recv)
NO_CPU_MULTI(ReduceScatter)
} // namespace distributed
} // namespace mlx::core

View File

@@ -164,6 +164,7 @@ NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
NO_GPU_MULTI(ReduceScatter)
} // namespace distributed
} // namespace mlx::core

View File

@@ -41,6 +41,14 @@ void recv(Group group, array& out, int src, Stream stream) {
group.raw_group()->recv(out, src, stream);
}
void sum_scatter(
Group group,
const array& input,
array& output,
Stream stream) {
group.raw_group()->sum_scatter(input, output, stream);
}
class EmptyGroup : public GroupImpl {
public:
Stream communication_stream(StreamOrDevice s) override {
@@ -85,6 +93,10 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void sum_scatter(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
};
} // namespace detail

View File

@@ -28,6 +28,8 @@ class GroupImpl {
virtual void recv(array& out, int src, Stream stream) = 0;
virtual void all_max(const array& input, array& output, Stream stream) = 0;
virtual void all_min(const array& input, array& output, Stream stream) = 0;
virtual void
sum_scatter(const array& input, array& output, Stream stream) = 0;
};
/* Define the MLX stream that the communication should happen in. */
@@ -51,4 +53,7 @@ void all_max(Group group, const array& input, array& output, Stream stream);
/** Min reduction */
void all_min(Group group, const array& input, array& output, Stream stream);
/** Reduce scatter with average operation */
void sum_scatter(Group group, const array& input, array& output, Stream stream);
} // namespace mlx::core::distributed::detail

View File

@@ -465,6 +465,10 @@ class MPIGroup : public GroupImpl {
});
}
void sum_scatter(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[mpi] sum_scatter not yet implemented.");
}
private:
MPI_Comm comm_;
bool global_;

View File

@@ -296,8 +296,17 @@ class NCCLGroup : public GroupImpl {
}
void all_gather(const array& input, array& output, Stream stream) override {
throw std::runtime_error(
"[nccl] All gather not supported in NCCL backend.");
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclAllGather(
input.data<T>(),
output.data<T>(),
input.size(),
dt,
comm_,
encoder.stream()));
});
}
void send(const array& input, int dst, Stream stream) override {
@@ -309,11 +318,24 @@ class NCCLGroup : public GroupImpl {
}
void all_max(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMax);
});
}
void all_min(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMin);
});
}
void sum_scatter(const array& input, array& output, Stream stream) override {
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
reduce_scatter_impl<T>(input, output, stream, dt, ncclSum);
});
}
template <typename T>
@@ -335,6 +357,25 @@ class NCCLGroup : public GroupImpl {
encoder.stream()));
}
template <typename T>
void reduce_scatter_impl(
const array& input,
array& output,
Stream stream,
ncclDataType_t dt,
ncclRedOp_t op) {
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclReduceScatter(
input.data<T>(),
output.data<T>(),
output.size(),
dt,
op,
comm_,
encoder.stream()));
}
int rank_, size_;
std::string initMethod_;
ncclUniqueId uniqueId_;

View File

@@ -157,4 +157,30 @@ array recv_like(
return recv(x.shape(), x.dtype(), src, group_, s);
}
array sum_scatter(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
if (x.shape()[0] % group.size() != 0) {
std::ostringstream msg;
msg << "[sum_scatter] Invalid shape=" << x.shape()
<< " for a group of size " << group.size()
<< ". The first dimension (axis 0) must be divisible by the group size.";
throw std::invalid_argument(msg.str());
}
auto result_shape = x.shape();
result_shape[0] /= group.size();
auto stream = detail::communication_stream(group, s);
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<ReduceScatter>(stream, group, ReduceScatter::Sum),
{x});
}
} // namespace mlx::core::distributed

View File

@@ -48,4 +48,9 @@ array all_min(
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array sum_scatter(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::distributed

View File

@@ -127,4 +127,30 @@ class Recv : public DistPrimitive {
int src_;
};
class ReduceScatter : public DistPrimitive {
public:
enum ReduceType { Sum, Min, Max };
ReduceScatter(Stream stream, Group group, ReduceType reduce_type)
: DistPrimitive(stream, group), reduce_type_(reduce_type) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
const char* name() const override {
switch (reduce_type_) {
case Sum:
return "Sum ReduceScatter";
case Min:
return "Min ReduceScatter";
case Max:
return "Max ReduceScatter";
}
return "<unknwon ReduceScatter>";
}
private:
ReduceType reduce_type_;
};
} // namespace mlx::core::distributed

View File

@@ -731,6 +731,10 @@ class RingGroup : public GroupImpl {
});
}
void sum_scatter(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[ring] sum_scatter not supported.");
}
private:
template <typename T, typename ReduceOp>
void all_reduce(

View File

@@ -229,7 +229,7 @@ void init_distributed(nb::module_& parent_module) {
x (array): Input array.
dst (int): Rank of the destination process in the group.
group (Group): The group of processes that will participate in the
sned. If set to ``None`` the global group is used. Default:
send. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
@@ -301,4 +301,36 @@ void init_distributed(nb::module_& parent_module) {
Returns:
array: The array that was received from ``src``.
)pbdoc");
m.def(
"sum_scatter",
[](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::sum_scatter(to_array(x), group, s);
},
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def sum_scatter(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Sum ``x`` across all processes in the group and shard the result along the first axis across ranks.
``x.shape[0]`` must be divisible by the group size.
The result is equivalent to ``all_sum(x)[rank*chunk_size:(rank+1)*chunk_size]``, where ``chunk_size = x.shape[0] // group.size()`` and ``rank`` is the rank of this process in the group.
Note: ``all_sum`` is mentioned only for illustration; the actual implementation does not perform ``all_sum`` and uses a single reduce-scatter collective instead.
Currently supported only for the NCCL backend.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
sum scatter. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The output array with shape ``[x.shape[0] // group.size(), *x.shape[1:]]``.
)pbdoc");
}

View File

@@ -1,7 +1,5 @@
# Copyright © 2025 Apple Inc.
import unittest
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
@@ -63,6 +61,48 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
finally:
mx.distributed.all_sum = original_all_sum
def test_all_reduce(self):
g = mx.distributed.init()
dtypes = [
(mx.int8, 0),
(mx.uint8, 0),
(mx.int32, 0),
(mx.uint32, 0),
(mx.float32, 1e-6),
(mx.float16, 5e-3),
(mx.bfloat16, 1e-1),
]
sizes = [
(7,),
(10,),
(1024,),
(1024, 1024),
]
key = mx.random.key(0)
for dt, rtol in dtypes:
for sh in sizes:
x = (mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10).astype(dt)
# All sum
y = mx.distributed.all_sum(x[g.rank()], group=g)
z = x.sum(0)
maxrelerror = (y - z).abs()
if rtol > 0:
maxrelerror /= z.abs()
maxrelerror = maxrelerror.max()
self.assertLessEqual(maxrelerror, rtol)
# All max
y = mx.distributed.all_max(x[g.rank()], group=g)
z = x.max(0)
self.assertTrue(mx.all(y == z))
# All min
y = mx.distributed.all_min(x[g.rank()], group=g)
z = x.min(0)
self.assertTrue(mx.all(y == z))
def test_donation(self):
x = mx.random.normal((1024,))
mx.eval(x)
@@ -103,18 +143,19 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
y = lin(x)
y1 = slin1(x)
y2 = slin2(x[part])
self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
self.assertTrue(mx.allclose(y[part], y1))
self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
self.assertTrue(mx.allclose(y[part], y1, atol=self.atol, rtol=self.rtol))
# And their quant versions
qlin = lin.to_quantized()
slin1 = shard_linear(qlin, "all-to-sharded")
slin2 = shard_linear(qlin, "sharded-to-all")
y = qlin(x)
y1 = slin1(x)
y2 = slin2(x[part])
self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
self.assertTrue(mx.allclose(y[part], y1))
# And their quant versions (QuintizedMatmul is not supported on CUDA)
if not mx.cuda.is_available():
qlin = lin.to_quantized()
slin1 = shard_linear(qlin, "all-to-sharded")
slin2 = shard_linear(qlin, "sharded-to-all")
y = qlin(x)
y1 = slin1(x)
y2 = slin2(x[part])
self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
self.assertTrue(mx.allclose(y[part], y1))
# Check the backward works as expected
def dummy_loss(model, x, y):
@@ -197,12 +238,18 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
g1["layers"][1]["bias"],
g2["layers"][1]["bias"],
atol=self.atol,
rtol=self.rtol,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
g1["layers"][3]["bias"],
g2["layers"][3]["bias"],
atol=self.atol,
rtol=self.rtol,
)
)
@@ -248,3 +295,20 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
y1 = mod(x)
y2 = smod(x)
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
def test_all_gather(self):
world = mx.distributed.init()
dtypes = [
mx.int8,
mx.uint8,
mx.int32,
mx.uint32,
mx.float32,
mx.float16,
mx.bfloat16,
]
for dt in dtypes:
x = mx.ones((2, 2, 4), dtype=dt)
y = mx.distributed.all_gather(x)
self.assertEqual(y.shape, (world.size() * 2, 2, 4))
self.assertTrue(mx.all(y == 1))

View File

@@ -1,15 +1,16 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_distributed_tests
import mlx_tests
class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
@classmethod
def setUpClass(cls):
world = mx.distributed.init(strict=True, backend="mpi")
_ = mx.distributed.init(strict=True, backend="mpi")
cls.atol = 1e-6
cls.rtol = 1e-4
def test_groups(self):
world = mx.distributed.init()
@@ -27,18 +28,11 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
sub = world.split(world.rank() // 2)
self.assertEqual(sub.size(), 2)
def test_all_reduce(self):
def test_all_reduce_extra(self):
world = mx.distributed.init()
dtypes = [
(mx.int8, 0),
(mx.uint8, 0),
(mx.int16, 0),
(mx.uint16, 0),
(mx.int32, 0),
(mx.uint32, 0),
(mx.float32, 1e-6),
(mx.float16, 5e-3),
(mx.bfloat16, 1e-1),
(mx.complex64, 1e-6),
]
sizes = [
@@ -76,16 +70,11 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
z = x.min(0)
self.assertTrue(mx.all(y == z))
def test_all_gather(self):
def test_all_gather_extra(self):
world = mx.distributed.init()
dtypes = [
mx.int8,
mx.uint8,
mx.int16,
mx.uint16,
mx.int32,
mx.uint32,
mx.float32,
mx.complex64,
]
for dt in dtypes:
@@ -150,4 +139,4 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -1,283 +1,52 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx_distributed_tests
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients
class TestNCCLDistributed(mlx_tests.MLXTestCase):
class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
@classmethod
def setUpClass(cls):
world = mx.distributed.init(strict=True, backend="nccl")
rank = world.rank()
mx.set_default_device(mx.Device(mx.gpu, rank % 8))
_ = mx.distributed.init(strict=True, backend="nccl")
cls.atol = 1e-4
cls.rtol = 1e-4
def test_sum_scatter(self):
def test_all_reduce(self):
world = mx.distributed.init()
dtypes = [
(mx.int8, 0),
(mx.uint8, 0),
(mx.int32, 0),
(mx.uint32, 0),
(mx.float32, 1e-6),
(mx.float16, 5e-3),
(mx.bfloat16, 1e-1),
]
sizes = [
(7,),
(10,),
(8,),
(64,),
(1024,),
(1024, 1024),
]
key = mx.random.key(0)
key = mx.random.key(world.rank())
for dt, rtol in dtypes:
for sh in sizes:
x = (
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
).astype(dt)
x = (mx.random.uniform(shape=sh, key=key) * 10).astype(dt) # shape=sh
# All sum
y = mx.distributed.all_sum(x[world.rank()])
z = x.sum(0)
maxrelerror = (y - z).abs()
# Sum scatter
y = mx.distributed.sum_scatter(x) # shape=sh/world.size()
z = mx.distributed.all_sum(x) # shape=sh
chunk = sh[0] // world.size()
start = world.rank() * chunk
stop = start + chunk
z_ref = z[start:stop]
maxrelerror = (y - z_ref).abs()
if rtol > 0:
maxrelerror /= z.abs()
maxrelerror /= z_ref.abs()
maxrelerror = maxrelerror.max()
self.assertLessEqual(maxrelerror, rtol)
def test_average_gradients(self):
original_all_sum = mx.distributed.all_sum
n_calls = 0
xtype = None
def new_all_sum(x, **kwargs):
nonlocal n_calls
nonlocal xtype
n_calls += 1
if xtype is not None:
self.assertEqual(xtype, x.dtype)
return original_all_sum(x, **kwargs)
mx.distributed.all_sum = new_all_sum
try:
grads = [mx.ones(10) for i in range(10)]
new_grads = average_gradients(grads)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 1)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=4 * 50)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=0)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 10)
n_calls = 0
xtype = mx.float16
new_grads = average_gradients(
grads,
all_reduce_size=2 * 50,
communication_type=mx.float16,
)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
finally:
mx.distributed.all_sum = original_all_sum
def test_donation(self):
x = mx.random.normal((1024,))
mx.eval(x)
mx.synchronize()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize()
all_sum_only = mx.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize()
all_sum_with_binary = mx.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)
def test_shard_linear(self):
# Seed the prng to have the same inputs and weights generated everywhere
mx.random.seed(0xF0F0F0F0)
# Prepare inputs
world = mx.distributed.init()
part = (
slice(None),
slice(
world.rank() * 1024 // world.size(),
(world.rank() + 1) * 1024 // world.size(),
),
)
x = mx.random.normal((4, 1024))
# Create and shard some linear layers
lin = nn.Linear(1024, 1024, bias=True)
slin1 = shard_linear(lin, "all-to-sharded")
slin2 = shard_linear(lin, "sharded-to-all")
y = lin(x)
y1 = slin1(x)
y2 = slin2(x[part])
self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
# Check the backward works as expected
def dummy_loss(model, x, y):
return (model(x) * y).sum()
mod = nn.Sequential(
nn.Linear(128, 128),
nn.Linear(128, 128),
nn.Linear(128, 128),
nn.Linear(128, 128),
)
smod = nn.Sequential(
shard_linear(mod.layers[0], "all-to-sharded"),
shard_linear(mod.layers[1], "sharded-to-all"),
shard_linear(mod.layers[2], "all-to-sharded"),
shard_linear(mod.layers[3], "sharded-to-all"),
)
grad1 = nn.value_and_grad(mod, dummy_loss)
grad2 = nn.value_and_grad(smod, dummy_loss)
x = mx.random.normal((4, 128))
y = mx.random.normal((4, 128))
l1, g1 = grad1(mod, x, y)
l2, g2 = grad2(smod, x, y)
mx.eval(l1, g1, l2, g2)
part = slice(
world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
)
self.assertTrue(mx.allclose(l1, l2))
self.assertTrue(
mx.allclose(
g1["layers"][0]["weight"][part],
g2["layers"][0]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][2]["weight"][part],
g2["layers"][2]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["weight"][:, part],
g2["layers"][1]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["weight"][:, part],
g2["layers"][3]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][0]["bias"][part],
g2["layers"][0]["bias"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][2]["bias"][part],
g2["layers"][2]["bias"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
)
)
def test_shard_predicate(self):
mx.random.seed(0xF0F0F0F0)
class MyConv(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.aggregate = kwargs.pop("aggregate", False)
self.conv = nn.Conv2d(*args, **kwargs)
def __call__(self, x):
x = self.conv(x)
if self.aggregate:
x = mx.distributed.all_sum(x)
return x
def sharding(path, weight):
parts = path.split(".")
even = int(parts[1]) % 2 == 0
if even:
return 0
else:
return -1 if parts[-1] != "bias" else None
mod = nn.Sequential(
MyConv(3, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3),
MyConv(128, 3, kernel_size=3),
)
smod = nn.Sequential(
MyConv(3, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3, aggregate=True),
MyConv(128, 128, kernel_size=3),
MyConv(128, 3, kernel_size=3, aggregate=True),
)
smod.update(mod.parameters())
shard_inplace(smod, sharding)
x = mx.random.normal((4, 16, 16, 3))
y1 = mod(x)
y2 = smod(x)
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()

View File

@@ -1,7 +1,5 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_distributed_tests
import mlx_tests
@@ -10,7 +8,9 @@ import mlx_tests
class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
@classmethod
def setUpClass(cls):
world = mx.distributed.init(strict=True, backend="ring")
_ = mx.distributed.init(strict=True, backend="ring")
cls.atol = 1e-6
cls.rtol = 1e-4
def test_groups(self):
world = mx.distributed.init()
@@ -24,18 +24,11 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
with self.assertRaises(RuntimeError):
sub = world.split(world.rank() % 2)
def test_all_reduce(self):
def test_all_reduce_extra(self):
world = mx.distributed.init()
dtypes = [
(mx.int8, 0),
(mx.uint8, 0),
(mx.int16, 0),
(mx.uint16, 0),
(mx.int32, 0),
(mx.uint32, 0),
(mx.float32, 1e-6),
(mx.float16, 5e-3),
(mx.bfloat16, 1e-1),
(mx.complex64, 1e-6),
]
sizes = [
@@ -45,7 +38,6 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
(1024, 1024),
]
key = mx.random.key(0)
reductions = ["min", "max", "sum"]
for dt, rtol in dtypes:
for sh in sizes:
@@ -72,16 +64,11 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
z = x.min(0)
self.assertTrue(mx.all(y == z))
def test_all_gather(self):
def test_all_gather_extra(self):
world = mx.distributed.init()
dtypes = [
mx.int8,
mx.uint8,
mx.int16,
mx.uint16,
mx.int32,
mx.uint32,
mx.float32,
mx.complex64,
]
for dt in dtypes: