mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Nccl reduce scatter, all gather (#2727)
* Added reduce scatter and all gather for nccl * fix unused import, delete unused file * small fix * deleted useless condition * fixed comments * fix bug in eval_gpu, renamed to sum_scatter, fix docs * final fix docs * remove and * Update mlx/distributed/mpi/mpi.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * fix broken set input output * fixes set output * typo * fix typo * no cpu, no gpu for reduce scatter --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
committed by
GitHub
parent
761f901a41
commit
27778156dc
@@ -95,4 +95,9 @@ void Recv::eval_cpu(
|
||||
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||
}
|
||||
|
||||
void ReduceScatter::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
|
||||
}
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
@@ -53,4 +53,69 @@ void AllReduce::eval_gpu(
|
||||
"Only all reduce sum, max, and min are supported.");
|
||||
}
|
||||
}
|
||||
|
||||
void AllGather::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto input = ensure_contiguous(inputs[0]);
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(outputs[0]);
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
distributed::detail::all_gather(group(), input, outputs[0], s);
|
||||
}
|
||||
|
||||
void ReduceScatter::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto input = ensure_contiguous(inputs[0]);
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(outputs[0]);
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::sum_scatter(group(), input, outputs[0], s);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only sum scatter is supported. ");
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
@@ -40,7 +40,6 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
} // namespace distributed
|
||||
|
||||
@@ -30,4 +30,9 @@ void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
|
||||
throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
|
||||
}
|
||||
|
||||
void ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {
|
||||
throw std::runtime_error(
|
||||
"[ReduceScatter::eval_gpu] has no GPU implementation.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
@@ -138,6 +138,7 @@ NO_CPU_MULTI(AllReduce)
|
||||
NO_CPU_MULTI(AllGather)
|
||||
NO_CPU_MULTI(Send)
|
||||
NO_CPU_MULTI(Recv)
|
||||
NO_CPU_MULTI(ReduceScatter)
|
||||
} // namespace distributed
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -164,6 +164,7 @@ NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
NO_GPU_MULTI(ReduceScatter)
|
||||
} // namespace distributed
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -41,6 +41,14 @@ void recv(Group group, array& out, int src, Stream stream) {
|
||||
group.raw_group()->recv(out, src, stream);
|
||||
}
|
||||
|
||||
void sum_scatter(
|
||||
Group group,
|
||||
const array& input,
|
||||
array& output,
|
||||
Stream stream) {
|
||||
group.raw_group()->sum_scatter(input, output, stream);
|
||||
}
|
||||
|
||||
class EmptyGroup : public GroupImpl {
|
||||
public:
|
||||
Stream communication_stream(StreamOrDevice s) override {
|
||||
@@ -85,6 +93,10 @@ class EmptyGroup : public GroupImpl {
|
||||
throw std::runtime_error(
|
||||
"Communication not implemented in an empty distributed group.");
|
||||
}
|
||||
void sum_scatter(const array&, array&, Stream) override {
|
||||
throw std::runtime_error(
|
||||
"Communication not implemented in an empty distributed group.");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@@ -28,6 +28,8 @@ class GroupImpl {
|
||||
virtual void recv(array& out, int src, Stream stream) = 0;
|
||||
virtual void all_max(const array& input, array& output, Stream stream) = 0;
|
||||
virtual void all_min(const array& input, array& output, Stream stream) = 0;
|
||||
virtual void
|
||||
sum_scatter(const array& input, array& output, Stream stream) = 0;
|
||||
};
|
||||
|
||||
/* Define the MLX stream that the communication should happen in. */
|
||||
@@ -51,4 +53,7 @@ void all_max(Group group, const array& input, array& output, Stream stream);
|
||||
/** Min reduction */
|
||||
void all_min(Group group, const array& input, array& output, Stream stream);
|
||||
|
||||
/** Reduce scatter with average operation */
|
||||
void sum_scatter(Group group, const array& input, array& output, Stream stream);
|
||||
|
||||
} // namespace mlx::core::distributed::detail
|
||||
|
||||
@@ -465,6 +465,10 @@ class MPIGroup : public GroupImpl {
|
||||
});
|
||||
}
|
||||
|
||||
void sum_scatter(const array& input, array& output, Stream stream) override {
|
||||
throw std::runtime_error("[mpi] sum_scatter not yet implemented.");
|
||||
}
|
||||
|
||||
private:
|
||||
MPI_Comm comm_;
|
||||
bool global_;
|
||||
|
||||
@@ -296,8 +296,17 @@ class NCCLGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void all_gather(const array& input, array& output, Stream stream) override {
|
||||
throw std::runtime_error(
|
||||
"[nccl] All gather not supported in NCCL backend.");
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
auto& encoder = cu::get_command_encoder(stream);
|
||||
CHECK_NCCL(ncclAllGather(
|
||||
input.data<T>(),
|
||||
output.data<T>(),
|
||||
input.size(),
|
||||
dt,
|
||||
comm_,
|
||||
encoder.stream()));
|
||||
});
|
||||
}
|
||||
|
||||
void send(const array& input, int dst, Stream stream) override {
|
||||
@@ -309,11 +318,24 @@ class NCCLGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void all_max(const array& input, array& output, Stream stream) override {
|
||||
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
all_reduce_impl<T>(input, output, stream, dt, ncclMax);
|
||||
});
|
||||
}
|
||||
|
||||
void all_min(const array& input, array& output, Stream stream) override {
|
||||
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
all_reduce_impl<T>(input, output, stream, dt, ncclMin);
|
||||
});
|
||||
}
|
||||
|
||||
void sum_scatter(const array& input, array& output, Stream stream) override {
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
reduce_scatter_impl<T>(input, output, stream, dt, ncclSum);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -335,6 +357,25 @@ class NCCLGroup : public GroupImpl {
|
||||
encoder.stream()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void reduce_scatter_impl(
|
||||
const array& input,
|
||||
array& output,
|
||||
Stream stream,
|
||||
ncclDataType_t dt,
|
||||
ncclRedOp_t op) {
|
||||
auto& encoder = cu::get_command_encoder(stream);
|
||||
|
||||
CHECK_NCCL(ncclReduceScatter(
|
||||
input.data<T>(),
|
||||
output.data<T>(),
|
||||
output.size(),
|
||||
dt,
|
||||
op,
|
||||
comm_,
|
||||
encoder.stream()));
|
||||
}
|
||||
|
||||
int rank_, size_;
|
||||
std::string initMethod_;
|
||||
ncclUniqueId uniqueId_;
|
||||
|
||||
@@ -157,4 +157,30 @@ array recv_like(
|
||||
return recv(x.shape(), x.dtype(), src, group_, s);
|
||||
}
|
||||
|
||||
array sum_scatter(
|
||||
const array& x,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
if (x.shape()[0] % group.size() != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[sum_scatter] Invalid shape=" << x.shape()
|
||||
<< " for a group of size " << group.size()
|
||||
<< ". The first dimension (axis 0) must be divisible by the group size.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto result_shape = x.shape();
|
||||
result_shape[0] /= group.size();
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
return array(
|
||||
std::move(result_shape),
|
||||
x.dtype(),
|
||||
std::make_shared<ReduceScatter>(stream, group, ReduceScatter::Sum),
|
||||
{x});
|
||||
}
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
@@ -48,4 +48,9 @@ array all_min(
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array sum_scatter(
|
||||
const array& x,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
@@ -127,4 +127,30 @@ class Recv : public DistPrimitive {
|
||||
int src_;
|
||||
};
|
||||
|
||||
class ReduceScatter : public DistPrimitive {
|
||||
public:
|
||||
enum ReduceType { Sum, Min, Max };
|
||||
ReduceScatter(Stream stream, Group group, ReduceType reduce_type)
|
||||
: DistPrimitive(stream, group), reduce_type_(reduce_type) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
const char* name() const override {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return "Sum ReduceScatter";
|
||||
case Min:
|
||||
return "Min ReduceScatter";
|
||||
case Max:
|
||||
return "Max ReduceScatter";
|
||||
}
|
||||
return "<unknwon ReduceScatter>";
|
||||
}
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
};
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
@@ -731,6 +731,10 @@ class RingGroup : public GroupImpl {
|
||||
});
|
||||
}
|
||||
|
||||
void sum_scatter(const array& input, array& output, Stream stream) override {
|
||||
throw std::runtime_error("[ring] sum_scatter not supported.");
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T, typename ReduceOp>
|
||||
void all_reduce(
|
||||
|
||||
@@ -229,7 +229,7 @@ void init_distributed(nb::module_& parent_module) {
|
||||
x (array): Input array.
|
||||
dst (int): Rank of the destination process in the group.
|
||||
group (Group): The group of processes that will participate in the
|
||||
sned. If set to ``None`` the global group is used. Default:
|
||||
send. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
@@ -301,4 +301,36 @@ void init_distributed(nb::module_& parent_module) {
|
||||
Returns:
|
||||
array: The array that was received from ``src``.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"sum_scatter",
|
||||
[](const ScalarOrArray& x,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::sum_scatter(to_array(x), group, s);
|
||||
},
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def sum_scatter(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Sum ``x`` across all processes in the group and shard the result along the first axis across ranks.
|
||||
``x.shape[0]`` must be divisible by the group size.
|
||||
|
||||
The result is equivalent to ``all_sum(x)[rank*chunk_size:(rank+1)*chunk_size]``, where ``chunk_size = x.shape[0] // group.size()`` and ``rank`` is the rank of this process in the group.
|
||||
Note: ``all_sum`` is mentioned only for illustration; the actual implementation does not perform ``all_sum`` and uses a single reduce-scatter collective instead.
|
||||
Currently supported only for the NCCL backend.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
group (Group): The group of processes that will participate in the
|
||||
sum scatter. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
Returns:
|
||||
array: The output array with shape ``[x.shape[0] // group.size(), *x.shape[1:]]``.
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx_tests
|
||||
@@ -63,6 +61,48 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
|
||||
finally:
|
||||
mx.distributed.all_sum = original_all_sum
|
||||
|
||||
def test_all_reduce(self):
|
||||
g = mx.distributed.init()
|
||||
dtypes = [
|
||||
(mx.int8, 0),
|
||||
(mx.uint8, 0),
|
||||
(mx.int32, 0),
|
||||
(mx.uint32, 0),
|
||||
(mx.float32, 1e-6),
|
||||
(mx.float16, 5e-3),
|
||||
(mx.bfloat16, 1e-1),
|
||||
]
|
||||
sizes = [
|
||||
(7,),
|
||||
(10,),
|
||||
(1024,),
|
||||
(1024, 1024),
|
||||
]
|
||||
key = mx.random.key(0)
|
||||
|
||||
for dt, rtol in dtypes:
|
||||
for sh in sizes:
|
||||
x = (mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10).astype(dt)
|
||||
|
||||
# All sum
|
||||
y = mx.distributed.all_sum(x[g.rank()], group=g)
|
||||
z = x.sum(0)
|
||||
maxrelerror = (y - z).abs()
|
||||
if rtol > 0:
|
||||
maxrelerror /= z.abs()
|
||||
maxrelerror = maxrelerror.max()
|
||||
self.assertLessEqual(maxrelerror, rtol)
|
||||
|
||||
# All max
|
||||
y = mx.distributed.all_max(x[g.rank()], group=g)
|
||||
z = x.max(0)
|
||||
self.assertTrue(mx.all(y == z))
|
||||
|
||||
# All min
|
||||
y = mx.distributed.all_min(x[g.rank()], group=g)
|
||||
z = x.min(0)
|
||||
self.assertTrue(mx.all(y == z))
|
||||
|
||||
def test_donation(self):
|
||||
x = mx.random.normal((1024,))
|
||||
mx.eval(x)
|
||||
@@ -103,18 +143,19 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
|
||||
y = lin(x)
|
||||
y1 = slin1(x)
|
||||
y2 = slin2(x[part])
|
||||
self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
|
||||
self.assertTrue(mx.allclose(y[part], y1))
|
||||
self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
|
||||
self.assertTrue(mx.allclose(y[part], y1, atol=self.atol, rtol=self.rtol))
|
||||
|
||||
# And their quant versions
|
||||
qlin = lin.to_quantized()
|
||||
slin1 = shard_linear(qlin, "all-to-sharded")
|
||||
slin2 = shard_linear(qlin, "sharded-to-all")
|
||||
y = qlin(x)
|
||||
y1 = slin1(x)
|
||||
y2 = slin2(x[part])
|
||||
self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
|
||||
self.assertTrue(mx.allclose(y[part], y1))
|
||||
# And their quant versions (QuintizedMatmul is not supported on CUDA)
|
||||
if not mx.cuda.is_available():
|
||||
qlin = lin.to_quantized()
|
||||
slin1 = shard_linear(qlin, "all-to-sharded")
|
||||
slin2 = shard_linear(qlin, "sharded-to-all")
|
||||
y = qlin(x)
|
||||
y1 = slin1(x)
|
||||
y2 = slin2(x[part])
|
||||
self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
|
||||
self.assertTrue(mx.allclose(y[part], y1))
|
||||
|
||||
# Check the backward works as expected
|
||||
def dummy_loss(model, x, y):
|
||||
@@ -197,12 +238,18 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
|
||||
g1["layers"][1]["bias"],
|
||||
g2["layers"][1]["bias"],
|
||||
atol=self.atol,
|
||||
rtol=self.rtol,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
|
||||
g1["layers"][3]["bias"],
|
||||
g2["layers"][3]["bias"],
|
||||
atol=self.atol,
|
||||
rtol=self.rtol,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -248,3 +295,20 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
|
||||
y1 = mod(x)
|
||||
y2 = smod(x)
|
||||
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
|
||||
|
||||
def test_all_gather(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
mx.uint8,
|
||||
mx.int32,
|
||||
mx.uint32,
|
||||
mx.float32,
|
||||
mx.float16,
|
||||
mx.bfloat16,
|
||||
]
|
||||
for dt in dtypes:
|
||||
x = mx.ones((2, 2, 4), dtype=dt)
|
||||
y = mx.distributed.all_gather(x)
|
||||
self.assertEqual(y.shape, (world.size() * 2, 2, 4))
|
||||
self.assertTrue(mx.all(y == 1))
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_distributed_tests
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
world = mx.distributed.init(strict=True, backend="mpi")
|
||||
_ = mx.distributed.init(strict=True, backend="mpi")
|
||||
cls.atol = 1e-6
|
||||
cls.rtol = 1e-4
|
||||
|
||||
def test_groups(self):
|
||||
world = mx.distributed.init()
|
||||
@@ -27,18 +28,11 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
sub = world.split(world.rank() // 2)
|
||||
self.assertEqual(sub.size(), 2)
|
||||
|
||||
def test_all_reduce(self):
|
||||
def test_all_reduce_extra(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
(mx.int8, 0),
|
||||
(mx.uint8, 0),
|
||||
(mx.int16, 0),
|
||||
(mx.uint16, 0),
|
||||
(mx.int32, 0),
|
||||
(mx.uint32, 0),
|
||||
(mx.float32, 1e-6),
|
||||
(mx.float16, 5e-3),
|
||||
(mx.bfloat16, 1e-1),
|
||||
(mx.complex64, 1e-6),
|
||||
]
|
||||
sizes = [
|
||||
@@ -76,16 +70,11 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
z = x.min(0)
|
||||
self.assertTrue(mx.all(y == z))
|
||||
|
||||
def test_all_gather(self):
|
||||
def test_all_gather_extra(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
mx.uint8,
|
||||
mx.int16,
|
||||
mx.uint16,
|
||||
mx.int32,
|
||||
mx.uint32,
|
||||
mx.float32,
|
||||
mx.complex64,
|
||||
]
|
||||
for dt in dtypes:
|
||||
@@ -150,4 +139,4 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
@@ -1,283 +1,52 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx_distributed_tests
|
||||
import mlx_tests
|
||||
from mlx.nn.layers.distributed import shard_inplace, shard_linear
|
||||
from mlx.nn.utils import average_gradients
|
||||
|
||||
|
||||
class TestNCCLDistributed(mlx_tests.MLXTestCase):
|
||||
class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
world = mx.distributed.init(strict=True, backend="nccl")
|
||||
rank = world.rank()
|
||||
mx.set_default_device(mx.Device(mx.gpu, rank % 8))
|
||||
_ = mx.distributed.init(strict=True, backend="nccl")
|
||||
cls.atol = 1e-4
|
||||
cls.rtol = 1e-4
|
||||
|
||||
def test_sum_scatter(self):
|
||||
|
||||
def test_all_reduce(self):
|
||||
world = mx.distributed.init()
|
||||
|
||||
dtypes = [
|
||||
(mx.int8, 0),
|
||||
(mx.uint8, 0),
|
||||
(mx.int32, 0),
|
||||
(mx.uint32, 0),
|
||||
(mx.float32, 1e-6),
|
||||
(mx.float16, 5e-3),
|
||||
(mx.bfloat16, 1e-1),
|
||||
]
|
||||
sizes = [
|
||||
(7,),
|
||||
(10,),
|
||||
(8,),
|
||||
(64,),
|
||||
(1024,),
|
||||
(1024, 1024),
|
||||
]
|
||||
key = mx.random.key(0)
|
||||
key = mx.random.key(world.rank())
|
||||
|
||||
for dt, rtol in dtypes:
|
||||
for sh in sizes:
|
||||
x = (
|
||||
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
|
||||
).astype(dt)
|
||||
x = (mx.random.uniform(shape=sh, key=key) * 10).astype(dt) # shape=sh
|
||||
|
||||
# All sum
|
||||
y = mx.distributed.all_sum(x[world.rank()])
|
||||
z = x.sum(0)
|
||||
maxrelerror = (y - z).abs()
|
||||
# Sum scatter
|
||||
y = mx.distributed.sum_scatter(x) # shape=sh/world.size()
|
||||
z = mx.distributed.all_sum(x) # shape=sh
|
||||
chunk = sh[0] // world.size()
|
||||
start = world.rank() * chunk
|
||||
stop = start + chunk
|
||||
z_ref = z[start:stop]
|
||||
|
||||
maxrelerror = (y - z_ref).abs()
|
||||
if rtol > 0:
|
||||
maxrelerror /= z.abs()
|
||||
maxrelerror /= z_ref.abs()
|
||||
maxrelerror = maxrelerror.max()
|
||||
self.assertLessEqual(maxrelerror, rtol)
|
||||
|
||||
def test_average_gradients(self):
|
||||
original_all_sum = mx.distributed.all_sum
|
||||
n_calls = 0
|
||||
xtype = None
|
||||
|
||||
def new_all_sum(x, **kwargs):
|
||||
nonlocal n_calls
|
||||
nonlocal xtype
|
||||
|
||||
n_calls += 1
|
||||
if xtype is not None:
|
||||
self.assertEqual(xtype, x.dtype)
|
||||
|
||||
return original_all_sum(x, **kwargs)
|
||||
|
||||
mx.distributed.all_sum = new_all_sum
|
||||
try:
|
||||
grads = [mx.ones(10) for i in range(10)]
|
||||
new_grads = average_gradients(grads)
|
||||
mx.eval(new_grads)
|
||||
self.assertEqual(len(new_grads), 10)
|
||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||
self.assertEqual(n_calls, 1)
|
||||
|
||||
n_calls = 0
|
||||
new_grads = average_gradients(grads, all_reduce_size=4 * 50)
|
||||
mx.eval(new_grads)
|
||||
self.assertEqual(len(new_grads), 10)
|
||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||
self.assertEqual(n_calls, 2)
|
||||
|
||||
n_calls = 0
|
||||
new_grads = average_gradients(grads, all_reduce_size=0)
|
||||
mx.eval(new_grads)
|
||||
self.assertEqual(len(new_grads), 10)
|
||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||
self.assertEqual(n_calls, 10)
|
||||
|
||||
n_calls = 0
|
||||
xtype = mx.float16
|
||||
new_grads = average_gradients(
|
||||
grads,
|
||||
all_reduce_size=2 * 50,
|
||||
communication_type=mx.float16,
|
||||
)
|
||||
mx.eval(new_grads)
|
||||
self.assertEqual(len(new_grads), 10)
|
||||
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
|
||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||
self.assertEqual(n_calls, 2)
|
||||
|
||||
finally:
|
||||
mx.distributed.all_sum = original_all_sum
|
||||
|
||||
def test_donation(self):
|
||||
x = mx.random.normal((1024,))
|
||||
mx.eval(x)
|
||||
mx.synchronize()
|
||||
|
||||
mx.reset_peak_memory()
|
||||
scale = mx.array(2.0)
|
||||
y = mx.distributed.all_sum(x)
|
||||
mx.eval(y)
|
||||
mx.synchronize()
|
||||
all_sum_only = mx.get_peak_memory()
|
||||
y = mx.distributed.all_sum(x) * scale
|
||||
mx.eval(y)
|
||||
mx.synchronize()
|
||||
all_sum_with_binary = mx.get_peak_memory()
|
||||
|
||||
self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||
|
||||
def test_shard_linear(self):
|
||||
# Seed the prng to have the same inputs and weights generated everywhere
|
||||
mx.random.seed(0xF0F0F0F0)
|
||||
|
||||
# Prepare inputs
|
||||
world = mx.distributed.init()
|
||||
part = (
|
||||
slice(None),
|
||||
slice(
|
||||
world.rank() * 1024 // world.size(),
|
||||
(world.rank() + 1) * 1024 // world.size(),
|
||||
),
|
||||
)
|
||||
x = mx.random.normal((4, 1024))
|
||||
|
||||
# Create and shard some linear layers
|
||||
lin = nn.Linear(1024, 1024, bias=True)
|
||||
slin1 = shard_linear(lin, "all-to-sharded")
|
||||
slin2 = shard_linear(lin, "sharded-to-all")
|
||||
y = lin(x)
|
||||
y1 = slin1(x)
|
||||
y2 = slin2(x[part])
|
||||
self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
|
||||
self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
|
||||
|
||||
# Check the backward works as expected
|
||||
def dummy_loss(model, x, y):
|
||||
return (model(x) * y).sum()
|
||||
|
||||
mod = nn.Sequential(
|
||||
nn.Linear(128, 128),
|
||||
nn.Linear(128, 128),
|
||||
nn.Linear(128, 128),
|
||||
nn.Linear(128, 128),
|
||||
)
|
||||
smod = nn.Sequential(
|
||||
shard_linear(mod.layers[0], "all-to-sharded"),
|
||||
shard_linear(mod.layers[1], "sharded-to-all"),
|
||||
shard_linear(mod.layers[2], "all-to-sharded"),
|
||||
shard_linear(mod.layers[3], "sharded-to-all"),
|
||||
)
|
||||
|
||||
grad1 = nn.value_and_grad(mod, dummy_loss)
|
||||
grad2 = nn.value_and_grad(smod, dummy_loss)
|
||||
|
||||
x = mx.random.normal((4, 128))
|
||||
y = mx.random.normal((4, 128))
|
||||
|
||||
l1, g1 = grad1(mod, x, y)
|
||||
l2, g2 = grad2(smod, x, y)
|
||||
mx.eval(l1, g1, l2, g2)
|
||||
|
||||
part = slice(
|
||||
world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
|
||||
)
|
||||
|
||||
self.assertTrue(mx.allclose(l1, l2))
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][0]["weight"][part],
|
||||
g2["layers"][0]["weight"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][2]["weight"][part],
|
||||
g2["layers"][2]["weight"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][1]["weight"][:, part],
|
||||
g2["layers"][1]["weight"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][3]["weight"][:, part],
|
||||
g2["layers"][3]["weight"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][0]["bias"][part],
|
||||
g2["layers"][0]["bias"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][2]["bias"][part],
|
||||
g2["layers"][2]["bias"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
|
||||
)
|
||||
)
|
||||
|
||||
def test_shard_predicate(self):
|
||||
mx.random.seed(0xF0F0F0F0)
|
||||
|
||||
class MyConv(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.aggregate = kwargs.pop("aggregate", False)
|
||||
self.conv = nn.Conv2d(*args, **kwargs)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv(x)
|
||||
if self.aggregate:
|
||||
x = mx.distributed.all_sum(x)
|
||||
return x
|
||||
|
||||
def sharding(path, weight):
|
||||
parts = path.split(".")
|
||||
even = int(parts[1]) % 2 == 0
|
||||
if even:
|
||||
return 0
|
||||
else:
|
||||
return -1 if parts[-1] != "bias" else None
|
||||
|
||||
mod = nn.Sequential(
|
||||
MyConv(3, 128, kernel_size=3),
|
||||
MyConv(128, 128, kernel_size=3),
|
||||
MyConv(128, 128, kernel_size=3),
|
||||
MyConv(128, 3, kernel_size=3),
|
||||
)
|
||||
smod = nn.Sequential(
|
||||
MyConv(3, 128, kernel_size=3),
|
||||
MyConv(128, 128, kernel_size=3, aggregate=True),
|
||||
MyConv(128, 128, kernel_size=3),
|
||||
MyConv(128, 3, kernel_size=3, aggregate=True),
|
||||
)
|
||||
smod.update(mod.parameters())
|
||||
shard_inplace(smod, sharding)
|
||||
|
||||
x = mx.random.normal((4, 16, 16, 3))
|
||||
y1 = mod(x)
|
||||
y2 = smod(x)
|
||||
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_distributed_tests
|
||||
import mlx_tests
|
||||
@@ -10,7 +8,9 @@ import mlx_tests
|
||||
class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
world = mx.distributed.init(strict=True, backend="ring")
|
||||
_ = mx.distributed.init(strict=True, backend="ring")
|
||||
cls.atol = 1e-6
|
||||
cls.rtol = 1e-4
|
||||
|
||||
def test_groups(self):
|
||||
world = mx.distributed.init()
|
||||
@@ -24,18 +24,11 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
sub = world.split(world.rank() % 2)
|
||||
|
||||
def test_all_reduce(self):
|
||||
def test_all_reduce_extra(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
(mx.int8, 0),
|
||||
(mx.uint8, 0),
|
||||
(mx.int16, 0),
|
||||
(mx.uint16, 0),
|
||||
(mx.int32, 0),
|
||||
(mx.uint32, 0),
|
||||
(mx.float32, 1e-6),
|
||||
(mx.float16, 5e-3),
|
||||
(mx.bfloat16, 1e-1),
|
||||
(mx.complex64, 1e-6),
|
||||
]
|
||||
sizes = [
|
||||
@@ -45,7 +38,6 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
(1024, 1024),
|
||||
]
|
||||
key = mx.random.key(0)
|
||||
reductions = ["min", "max", "sum"]
|
||||
|
||||
for dt, rtol in dtypes:
|
||||
for sh in sizes:
|
||||
@@ -72,16 +64,11 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
z = x.min(0)
|
||||
self.assertTrue(mx.all(y == z))
|
||||
|
||||
def test_all_gather(self):
|
||||
def test_all_gather_extra(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
mx.uint8,
|
||||
mx.int16,
|
||||
mx.uint16,
|
||||
mx.int32,
|
||||
mx.uint32,
|
||||
mx.float32,
|
||||
mx.complex64,
|
||||
]
|
||||
for dt in dtypes:
|
||||
|
||||
Reference in New Issue
Block a user