mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Nccl reduce scatter, all gather (#2727)

* Added reduce scatter and all gather for nccl
* fix unused import, delete unused file
* small fix
* deleted useless condition
* fixed comments
* fix bug in eval_gpu, renamed to sum_scatter, fix docs
* final fix docs
* remove and
* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output
* fixes set output
* typo
* fix typo
* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

committed by GitHub
parent 761f901a41
commit 27778156dc
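For orientation, here is a minimal usage sketch of the two collectives this commit enables for the NCCL backend. The launch setup (process count, launcher) and the shapes are illustrative assumptions, not part of the commit itself:

# Hypothetical usage sketch: assumes an MLX build with the NCCL backend and a
# multi-process launch; shapes are illustrative.
import mlx.core as mx

group = mx.distributed.init(backend="nccl")

x = mx.ones((8, 4))  # axis 0 must be divisible by group.size() for sum_scatter

# all_gather: every rank receives the inputs of all ranks concatenated on axis 0.
gathered = mx.distributed.all_gather(x, group=group)

# sum_scatter: every rank receives its own chunk of the element-wise sum, i.e.
# all_sum(x)[rank * chunk : (rank + 1) * chunk] with chunk = x.shape[0] // group.size().
sharded = mx.distributed.sum_scatter(x, group=group)

mx.eval(gathered, sharded)

The diff follows.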
@@ -95,4 +95,9 @@ void Recv::eval_cpu(
   distributed::detail::recv(group(), outputs[0], src_, stream());
 }
 
+void ReduceScatter::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  throw std::runtime_error("[ReduceScatter] Not implemented yet.");
+}
 } // namespace mlx::core::distributed
@@ -53,4 +53,69 @@ void AllReduce::eval_gpu(
           "Only all reduce sum, max, and min are supported.");
   }
 }
+
+void AllGather::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  auto ensure_contiguous = [&s, &encoder](const array& x) {
+    if (x.flags().row_contiguous) {
+      return x;
+    } else {
+      array x_copy = contiguous_copy_gpu(x, s);
+      encoder.add_temporary(x_copy);
+      return x_copy;
+    }
+  };
+
+  auto input = ensure_contiguous(inputs[0]);
+  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
+
+  encoder.set_input_array(input);
+  encoder.set_output_array(outputs[0]);
+
+  auto capture = encoder.capture_context();
+  distributed::detail::all_gather(group(), input, outputs[0], s);
+}
+
+void ReduceScatter::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  auto ensure_contiguous = [&s, &encoder](const array& x) {
+    if (x.flags().row_contiguous) {
+      return x;
+    } else {
+      array x_copy = contiguous_copy_gpu(x, s);
+      encoder.add_temporary(x_copy);
+      return x_copy;
+    }
+  };
+
+  auto input = ensure_contiguous(inputs[0]);
+  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
+
+  encoder.set_input_array(input);
+  encoder.set_output_array(outputs[0]);
+
+  auto capture = encoder.capture_context();
+
+  switch (reduce_type_) {
+    case Sum:
+      distributed::detail::sum_scatter(group(), input, outputs[0], s);
+      break;
+    default:
+      throw std::runtime_error("Only sum scatter is supported. ");
+  }
+}
 } // namespace mlx::core::distributed
@@ -40,7 +40,6 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace distributed {
-NO_GPU_MULTI(AllGather)
 NO_GPU_MULTI(Send)
 NO_GPU_MULTI(Recv)
 } // namespace distributed
@@ -30,4 +30,9 @@ void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
   throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
 }
 
+void ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {
+  throw std::runtime_error(
+      "[ReduceScatter::eval_gpu] has no GPU implementation.");
+}
+
 } // namespace mlx::core::distributed
@@ -138,6 +138,7 @@ NO_CPU_MULTI(AllReduce)
 NO_CPU_MULTI(AllGather)
 NO_CPU_MULTI(Send)
 NO_CPU_MULTI(Recv)
+NO_CPU_MULTI(ReduceScatter)
 } // namespace distributed
 
 } // namespace mlx::core
@@ -164,6 +164,7 @@ NO_GPU_MULTI(AllReduce)
 NO_GPU_MULTI(AllGather)
 NO_GPU_MULTI(Send)
 NO_GPU_MULTI(Recv)
+NO_GPU_MULTI(ReduceScatter)
 } // namespace distributed
 
 } // namespace mlx::core
@@ -41,6 +41,14 @@ void recv(Group group, array& out, int src, Stream stream) {
   group.raw_group()->recv(out, src, stream);
 }
 
+void sum_scatter(
+    Group group,
+    const array& input,
+    array& output,
+    Stream stream) {
+  group.raw_group()->sum_scatter(input, output, stream);
+}
+
 class EmptyGroup : public GroupImpl {
  public:
   Stream communication_stream(StreamOrDevice s) override {
@@ -85,6 +93,10 @@ class EmptyGroup : public GroupImpl {
     throw std::runtime_error(
        "Communication not implemented in an empty distributed group.");
   }
+  void sum_scatter(const array&, array&, Stream) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
 };
 
 } // namespace detail
@@ -28,6 +28,8 @@ class GroupImpl {
   virtual void recv(array& out, int src, Stream stream) = 0;
   virtual void all_max(const array& input, array& output, Stream stream) = 0;
   virtual void all_min(const array& input, array& output, Stream stream) = 0;
+  virtual void
+  sum_scatter(const array& input, array& output, Stream stream) = 0;
 };
 
 /* Define the MLX stream that the communication should happen in. */
@@ -51,4 +53,7 @@ void all_max(Group group, const array& input, array& output, Stream stream);
 /** Min reduction */
 void all_min(Group group, const array& input, array& output, Stream stream);
 
+/** Reduce scatter with average operation */
+void sum_scatter(Group group, const array& input, array& output, Stream stream);
+
 } // namespace mlx::core::distributed::detail
@@ -465,6 +465,10 @@ class MPIGroup : public GroupImpl {
     });
   }
 
+  void sum_scatter(const array& input, array& output, Stream stream) override {
+    throw std::runtime_error("[mpi] sum_scatter not yet implemented.");
+  }
+
  private:
   MPI_Comm comm_;
   bool global_;
@@ -296,8 +296,17 @@ class NCCLGroup : public GroupImpl {
   }
 
   void all_gather(const array& input, array& output, Stream stream) override {
-    throw std::runtime_error(
-        "[nccl] All gather not supported in NCCL backend.");
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      auto& encoder = cu::get_command_encoder(stream);
+      CHECK_NCCL(ncclAllGather(
+          input.data<T>(),
+          output.data<T>(),
+          input.size(),
+          dt,
+          comm_,
+          encoder.stream()));
+    });
   }
 
   void send(const array& input, int dst, Stream stream) override {
@@ -309,11 +318,24 @@ class NCCLGroup : public GroupImpl {
   }
 
   void all_max(const array& input, array& output, Stream stream) override {
-    throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      all_reduce_impl<T>(input, output, stream, dt, ncclMax);
+    });
   }
 
   void all_min(const array& input, array& output, Stream stream) override {
-    throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      all_reduce_impl<T>(input, output, stream, dt, ncclMin);
+    });
+  }
+
+  void sum_scatter(const array& input, array& output, Stream stream) override {
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      reduce_scatter_impl<T>(input, output, stream, dt, ncclSum);
+    });
   }
 
   template <typename T>
@@ -335,6 +357,25 @@ class NCCLGroup : public GroupImpl {
         encoder.stream()));
   }
 
+  template <typename T>
+  void reduce_scatter_impl(
+      const array& input,
+      array& output,
+      Stream stream,
+      ncclDataType_t dt,
+      ncclRedOp_t op) {
+    auto& encoder = cu::get_command_encoder(stream);
+
+    CHECK_NCCL(ncclReduceScatter(
+        input.data<T>(),
+        output.data<T>(),
+        output.size(),
+        dt,
+        op,
+        comm_,
+        encoder.stream()));
+  }
+
   int rank_, size_;
   std::string initMethod_;
   ncclUniqueId uniqueId_;
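As a point of reference for the ncclAllGather/ncclReduceScatter calls above: each rank passes its full input, and for reduce-scatter receives only its own chunk of the reduced result, which is why the receive count is output.size(). A NumPy sketch of those semantics, with illustrative helper names, might look like this:

import numpy as np

def all_gather_reference(inputs_per_rank):
    # Every rank ends up with the inputs of all ranks concatenated along axis 0.
    return np.concatenate(inputs_per_rank, axis=0)

def reduce_scatter_reference(inputs_per_rank):
    # Element-wise sum across ranks, then rank r keeps chunk r along axis 0.
    total = np.sum(inputs_per_rank, axis=0)
    return np.split(total, len(inputs_per_rank), axis=0)

ins = [np.full((4, 2), r, dtype=np.float32) for r in range(4)]
assert all_gather_reference(ins).shape == (16, 2)
assert reduce_scatter_reference(ins)[1].shape == (1, 2)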
@@ -157,4 +157,30 @@ array recv_like(
   return recv(x.shape(), x.dtype(), src, group_, s);
 }
 
+array sum_scatter(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto group = to_group(group_);
+  if (group.size() == 1) {
+    return x;
+  }
+  if (x.shape()[0] % group.size() != 0) {
+    std::ostringstream msg;
+    msg << "[sum_scatter] Invalid shape=" << x.shape()
+        << " for a group of size " << group.size()
+        << ". The first dimension (axis 0) must be divisible by the group size.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto result_shape = x.shape();
+  result_shape[0] /= group.size();
+  auto stream = detail::communication_stream(group, s);
+
+  return array(
+      std::move(result_shape),
+      x.dtype(),
+      std::make_shared<ReduceScatter>(stream, group, ReduceScatter::Sum),
+      {x});
+}
 } // namespace mlx::core::distributed
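A plain-Python restatement of the shape rule enforced by the op above (an illustrative helper, not MLX code): the group size must divide shape[0], a group of size one is a no-op, and all other axes are kept.

def sum_scatter_output_shape(shape, group_size):
    # Mirrors the checks in the C++ op above.
    if group_size == 1:
        return list(shape)
    if shape[0] % group_size != 0:
        raise ValueError(f"Invalid shape={shape} for a group of size {group_size}")
    return [shape[0] // group_size, *shape[1:]]

assert sum_scatter_output_shape((1024, 1024), 8) == [128, 1024]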
@@ -48,4 +48,9 @@ array all_min(
     std::optional<Group> group = std::nullopt,
     StreamOrDevice s = {});
 
+array sum_scatter(
+    const array& x,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core::distributed
@@ -127,4 +127,30 @@ class Recv : public DistPrimitive {
   int src_;
 };
 
+class ReduceScatter : public DistPrimitive {
+ public:
+  enum ReduceType { Sum, Min, Max };
+  ReduceScatter(Stream stream, Group group, ReduceType reduce_type)
+      : DistPrimitive(stream, group), reduce_type_(reduce_type) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  const char* name() const override {
+    switch (reduce_type_) {
+      case Sum:
+        return "Sum ReduceScatter";
+      case Min:
+        return "Min ReduceScatter";
+      case Max:
+        return "Max ReduceScatter";
+    }
+    return "<unknwon ReduceScatter>";
+  }
+
+ private:
+  ReduceType reduce_type_;
+};
 } // namespace mlx::core::distributed
@@ -731,6 +731,10 @@ class RingGroup : public GroupImpl {
     });
   }
 
+  void sum_scatter(const array& input, array& output, Stream stream) override {
+    throw std::runtime_error("[ring] sum_scatter not supported.");
+  }
+
  private:
   template <typename T, typename ReduceOp>
   void all_reduce(
@@ -229,7 +229,7 @@ void init_distributed(nb::module_& parent_module) {
           x (array): Input array.
           dst (int): Rank of the destination process in the group.
           group (Group): The group of processes that will participate in the
-            sned. If set to ``None`` the global group is used. Default:
+            send. If set to ``None`` the global group is used. Default:
             ``None``.
           stream (Stream, optional): Stream or device. Defaults to ``None``
             in which case the default stream of the default device is used.
@@ -301,4 +301,36 @@ void init_distributed(nb::module_& parent_module) {
         Returns:
           array: The array that was received from ``src``.
       )pbdoc");
+
+  m.def(
+      "sum_scatter",
+      [](const ScalarOrArray& x,
+         std::optional<mx::distributed::Group> group,
+         mx::StreamOrDevice s) {
+        return mx::distributed::sum_scatter(to_array(x), group, s);
+      },
+      "x"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def sum_scatter(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Sum ``x`` across all processes in the group and shard the result along the first axis across ranks.
+        ``x.shape[0]`` must be divisible by the group size.
+
+        The result is equivalent to ``all_sum(x)[rank*chunk_size:(rank+1)*chunk_size]``, where ``chunk_size = x.shape[0] // group.size()`` and ``rank`` is the rank of this process in the group.
+        Note: ``all_sum`` is mentioned only for illustration; the actual implementation does not perform ``all_sum`` and uses a single reduce-scatter collective instead.
+        Currently supported only for the NCCL backend.
+
+        Args:
+          x (array): Input array.
+          group (Group): The group of processes that will participate in the
+            sum scatter. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+        Returns:
+          array: The output array with shape ``[x.shape[0] // group.size(), *x.shape[1:]]``.
+      )pbdoc");
 }
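The equivalence stated in the docstring above can be checked directly from Python. A sketch, assuming two or more processes and the NCCL backend, with a loose tolerance since the reduction order may differ:

import mlx.core as mx

group = mx.distributed.init(backend="nccl")
rank, size = group.rank(), group.size()

x = mx.random.uniform(shape=(8, 16))
chunk = x.shape[0] // size

scattered = mx.distributed.sum_scatter(x, group=group)
reference = mx.distributed.all_sum(x, group=group)[rank * chunk : (rank + 1) * chunk]
assert mx.allclose(scattered, reference, atol=1e-5).item()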
@@ -1,7 +1,5 @@
 # Copyright © 2025 Apple Inc.
 
-import unittest
-
 import mlx.core as mx
 import mlx.nn as nn
 import mlx_tests
@@ -63,6 +61,48 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         finally:
             mx.distributed.all_sum = original_all_sum
 
+    def test_all_reduce(self):
+        g = mx.distributed.init()
+        dtypes = [
+            (mx.int8, 0),
+            (mx.uint8, 0),
+            (mx.int32, 0),
+            (mx.uint32, 0),
+            (mx.float32, 1e-6),
+            (mx.float16, 5e-3),
+            (mx.bfloat16, 1e-1),
+        ]
+        sizes = [
+            (7,),
+            (10,),
+            (1024,),
+            (1024, 1024),
+        ]
+        key = mx.random.key(0)
+
+        for dt, rtol in dtypes:
+            for sh in sizes:
+                x = (mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10).astype(dt)
+
+                # All sum
+                y = mx.distributed.all_sum(x[g.rank()], group=g)
+                z = x.sum(0)
+                maxrelerror = (y - z).abs()
+                if rtol > 0:
+                    maxrelerror /= z.abs()
+                maxrelerror = maxrelerror.max()
+                self.assertLessEqual(maxrelerror, rtol)
+
+                # All max
+                y = mx.distributed.all_max(x[g.rank()], group=g)
+                z = x.max(0)
+                self.assertTrue(mx.all(y == z))
+
+                # All min
+                y = mx.distributed.all_min(x[g.rank()], group=g)
+                z = x.min(0)
+                self.assertTrue(mx.all(y == z))
+
     def test_donation(self):
         x = mx.random.normal((1024,))
         mx.eval(x)
@@ -103,18 +143,19 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         y = lin(x)
         y1 = slin1(x)
         y2 = slin2(x[part])
-        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
-        self.assertTrue(mx.allclose(y[part], y1))
+        self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
+        self.assertTrue(mx.allclose(y[part], y1, atol=self.atol, rtol=self.rtol))
 
-        # And their quant versions
-        qlin = lin.to_quantized()
-        slin1 = shard_linear(qlin, "all-to-sharded")
-        slin2 = shard_linear(qlin, "sharded-to-all")
-        y = qlin(x)
-        y1 = slin1(x)
-        y2 = slin2(x[part])
-        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
-        self.assertTrue(mx.allclose(y[part], y1))
+        # And their quant versions (QuintizedMatmul is not supported on CUDA)
+        if not mx.cuda.is_available():
+            qlin = lin.to_quantized()
+            slin1 = shard_linear(qlin, "all-to-sharded")
+            slin2 = shard_linear(qlin, "sharded-to-all")
+            y = qlin(x)
+            y1 = slin1(x)
+            y2 = slin2(x[part])
+            self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
+            self.assertTrue(mx.allclose(y[part], y1))
 
         # Check the backward works as expected
         def dummy_loss(model, x, y):
@@ -197,12 +238,18 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         )
         self.assertTrue(
             mx.allclose(
-                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
+                g1["layers"][1]["bias"],
+                g2["layers"][1]["bias"],
+                atol=self.atol,
+                rtol=self.rtol,
             )
         )
         self.assertTrue(
             mx.allclose(
-                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
+                g1["layers"][3]["bias"],
+                g2["layers"][3]["bias"],
+                atol=self.atol,
+                rtol=self.rtol,
             )
         )
 
@@ -248,3 +295,20 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         y1 = mod(x)
         y2 = smod(x)
         self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
+
+    def test_all_gather(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.float16,
+            mx.bfloat16,
+        ]
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_gather(x)
+            self.assertEqual(y.shape, (world.size() * 2, 2, 4))
+            self.assertTrue(mx.all(y == 1))
@@ -1,15 +1,16 @@
 # Copyright © 2024 Apple Inc.
 
-import unittest
-
 import mlx.core as mx
 import mlx_distributed_tests
+import mlx_tests
 
 
 class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     @classmethod
     def setUpClass(cls):
-        world = mx.distributed.init(strict=True, backend="mpi")
+        _ = mx.distributed.init(strict=True, backend="mpi")
+        cls.atol = 1e-6
+        cls.rtol = 1e-4
 
     def test_groups(self):
         world = mx.distributed.init()
@@ -27,18 +28,11 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
         sub = world.split(world.rank() // 2)
         self.assertEqual(sub.size(), 2)
 
-    def test_all_reduce(self):
+    def test_all_reduce_extra(self):
         world = mx.distributed.init()
         dtypes = [
-            (mx.int8, 0),
-            (mx.uint8, 0),
             (mx.int16, 0),
             (mx.uint16, 0),
-            (mx.int32, 0),
-            (mx.uint32, 0),
-            (mx.float32, 1e-6),
-            (mx.float16, 5e-3),
-            (mx.bfloat16, 1e-1),
             (mx.complex64, 1e-6),
         ]
         sizes = [
@@ -76,16 +70,11 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
                 z = x.min(0)
                 self.assertTrue(mx.all(y == z))
 
-    def test_all_gather(self):
+    def test_all_gather_extra(self):
         world = mx.distributed.init()
         dtypes = [
-            mx.int8,
-            mx.uint8,
             mx.int16,
             mx.uint16,
-            mx.int32,
-            mx.uint32,
-            mx.float32,
             mx.complex64,
         ]
         for dt in dtypes:
@@ -150,4 +139,4 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()
@@ -1,283 +1,52 @@
 # Copyright © 2024 Apple Inc.
 
 import mlx.core as mx
-import mlx.nn as nn
+import mlx_distributed_tests
 import mlx_tests
-from mlx.nn.layers.distributed import shard_inplace, shard_linear
-from mlx.nn.utils import average_gradients
 
 
-class TestNCCLDistributed(mlx_tests.MLXTestCase):
+class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     @classmethod
     def setUpClass(cls):
-        world = mx.distributed.init(strict=True, backend="nccl")
-        rank = world.rank()
-        mx.set_default_device(mx.Device(mx.gpu, rank % 8))
+        _ = mx.distributed.init(strict=True, backend="nccl")
+        cls.atol = 1e-4
+        cls.rtol = 1e-4
 
-    def test_all_reduce(self):
+    def test_sum_scatter(self):
         world = mx.distributed.init()
         dtypes = [
-            (mx.int8, 0),
-            (mx.uint8, 0),
-            (mx.int32, 0),
-            (mx.uint32, 0),
             (mx.float32, 1e-6),
             (mx.float16, 5e-3),
             (mx.bfloat16, 1e-1),
         ]
         sizes = [
-            (7,),
-            (10,),
+            (8,),
+            (64,),
             (1024,),
             (1024, 1024),
         ]
-        key = mx.random.key(0)
+        key = mx.random.key(world.rank())
 
         for dt, rtol in dtypes:
             for sh in sizes:
-                x = (
-                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
-                ).astype(dt)
+                x = (mx.random.uniform(shape=sh, key=key) * 10).astype(dt)  # shape=sh
 
-                # All sum
-                y = mx.distributed.all_sum(x[world.rank()])
-                z = x.sum(0)
-                maxrelerror = (y - z).abs()
+                # Sum scatter
+                y = mx.distributed.sum_scatter(x)  # shape=sh/world.size()
+                z = mx.distributed.all_sum(x)  # shape=sh
+                chunk = sh[0] // world.size()
+                start = world.rank() * chunk
+                stop = start + chunk
+                z_ref = z[start:stop]
+
+                maxrelerror = (y - z_ref).abs()
                 if rtol > 0:
-                    maxrelerror /= z.abs()
+                    maxrelerror /= z_ref.abs()
                 maxrelerror = maxrelerror.max()
                 self.assertLessEqual(maxrelerror, rtol)
 
-    def test_average_gradients(self):
-        original_all_sum = mx.distributed.all_sum
-        n_calls = 0
-        xtype = None
-
-        def new_all_sum(x, **kwargs):
-            nonlocal n_calls
-            nonlocal xtype
-
-            n_calls += 1
-            if xtype is not None:
-                self.assertEqual(xtype, x.dtype)
-
-            return original_all_sum(x, **kwargs)
-
-        mx.distributed.all_sum = new_all_sum
-        try:
-            grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 1)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 10)
-
-            n_calls = 0
-            xtype = mx.float16
-            new_grads = average_gradients(
-                grads,
-                all_reduce_size=2 * 50,
-                communication_type=mx.float16,
-            )
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-        finally:
-            mx.distributed.all_sum = original_all_sum
-
-    def test_donation(self):
-        x = mx.random.normal((1024,))
-        mx.eval(x)
-        mx.synchronize()
-
-        mx.reset_peak_memory()
-        scale = mx.array(2.0)
-        y = mx.distributed.all_sum(x)
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_only = mx.get_peak_memory()
-        y = mx.distributed.all_sum(x) * scale
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_with_binary = mx.get_peak_memory()
-
-        self.assertEqual(all_sum_only, all_sum_with_binary)
-
-    def test_shard_linear(self):
-        # Seed the prng to have the same inputs and weights generated everywhere
-        mx.random.seed(0xF0F0F0F0)
-
-        # Prepare inputs
-        world = mx.distributed.init()
-        part = (
-            slice(None),
-            slice(
-                world.rank() * 1024 // world.size(),
-                (world.rank() + 1) * 1024 // world.size(),
-            ),
-        )
-        x = mx.random.normal((4, 1024))
-
-        # Create and shard some linear layers
-        lin = nn.Linear(1024, 1024, bias=True)
-        slin1 = shard_linear(lin, "all-to-sharded")
-        slin2 = shard_linear(lin, "sharded-to-all")
-        y = lin(x)
-        y1 = slin1(x)
-        y2 = slin2(x[part])
-        self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
-        self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
-
-        # Check the backward works as expected
-        def dummy_loss(model, x, y):
-            return (model(x) * y).sum()
-
-        mod = nn.Sequential(
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-        )
-        smod = nn.Sequential(
-            shard_linear(mod.layers[0], "all-to-sharded"),
-            shard_linear(mod.layers[1], "sharded-to-all"),
-            shard_linear(mod.layers[2], "all-to-sharded"),
-            shard_linear(mod.layers[3], "sharded-to-all"),
-        )
-
-        grad1 = nn.value_and_grad(mod, dummy_loss)
-        grad2 = nn.value_and_grad(smod, dummy_loss)
-
-        x = mx.random.normal((4, 128))
-        y = mx.random.normal((4, 128))
-
-        l1, g1 = grad1(mod, x, y)
-        l2, g2 = grad2(smod, x, y)
-        mx.eval(l1, g1, l2, g2)
-
-        part = slice(
-            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
-        )
-
-        self.assertTrue(mx.allclose(l1, l2))
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][0]["weight"][part],
-                g2["layers"][0]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][2]["weight"][part],
-                g2["layers"][2]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][1]["weight"][:, part],
-                g2["layers"][1]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][3]["weight"][:, part],
-                g2["layers"][3]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][0]["bias"][part],
-                g2["layers"][0]["bias"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][2]["bias"][part],
-                g2["layers"][2]["bias"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
-            )
-        )
-
-    def test_shard_predicate(self):
-        mx.random.seed(0xF0F0F0F0)
-
-        class MyConv(nn.Module):
-            def __init__(self, *args, **kwargs):
-                super().__init__()
-                self.aggregate = kwargs.pop("aggregate", False)
-                self.conv = nn.Conv2d(*args, **kwargs)
-
-            def __call__(self, x):
-                x = self.conv(x)
-                if self.aggregate:
-                    x = mx.distributed.all_sum(x)
-                return x
-
-        def sharding(path, weight):
-            parts = path.split(".")
-            even = int(parts[1]) % 2 == 0
-            if even:
-                return 0
-            else:
-                return -1 if parts[-1] != "bias" else None
-
-        mod = nn.Sequential(
-            MyConv(3, 128, kernel_size=3),
-            MyConv(128, 128, kernel_size=3),
-            MyConv(128, 128, kernel_size=3),
-            MyConv(128, 3, kernel_size=3),
-        )
-        smod = nn.Sequential(
-            MyConv(3, 128, kernel_size=3),
-            MyConv(128, 128, kernel_size=3, aggregate=True),
-            MyConv(128, 128, kernel_size=3),
-            MyConv(128, 3, kernel_size=3, aggregate=True),
-        )
-        smod.update(mod.parameters())
-        shard_inplace(smod, sharding)
-
-        x = mx.random.normal((4, 16, 16, 3))
-        y1 = mod(x)
-        y2 = smod(x)
-        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
-
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
@@ -1,7 +1,5 @@
 # Copyright © 2024 Apple Inc.
 
-import unittest
-
 import mlx.core as mx
 import mlx_distributed_tests
 import mlx_tests
@@ -10,7 +8,9 @@ import mlx_tests
 class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     @classmethod
     def setUpClass(cls):
-        world = mx.distributed.init(strict=True, backend="ring")
+        _ = mx.distributed.init(strict=True, backend="ring")
+        cls.atol = 1e-6
+        cls.rtol = 1e-4
 
     def test_groups(self):
         world = mx.distributed.init()
@@ -24,18 +24,11 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
         with self.assertRaises(RuntimeError):
             sub = world.split(world.rank() % 2)
 
-    def test_all_reduce(self):
+    def test_all_reduce_extra(self):
         world = mx.distributed.init()
         dtypes = [
-            (mx.int8, 0),
-            (mx.uint8, 0),
             (mx.int16, 0),
             (mx.uint16, 0),
-            (mx.int32, 0),
-            (mx.uint32, 0),
-            (mx.float32, 1e-6),
-            (mx.float16, 5e-3),
-            (mx.bfloat16, 1e-1),
             (mx.complex64, 1e-6),
         ]
         sizes = [
@@ -45,7 +38,6 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
             (1024, 1024),
         ]
         key = mx.random.key(0)
-        reductions = ["min", "max", "sum"]
 
         for dt, rtol in dtypes:
             for sh in sizes:
@@ -72,16 +64,11 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
                 z = x.min(0)
                 self.assertTrue(mx.all(y == z))
 
-    def test_all_gather(self):
+    def test_all_gather_extra(self):
         world = mx.distributed.init()
         dtypes = [
-            mx.int8,
-            mx.uint8,
             mx.int16,
             mx.uint16,
-            mx.int32,
-            mx.uint32,
-            mx.float32,
             mx.complex64,
         ]
         for dt in dtypes: