Min / max reductions (#2041)
commit 515f104926
parent 9ecefd56db
committed by GitHub
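This change adds ``all_max`` and ``all_min`` collective reductions to ``mx.distributed`` alongside the existing ``all_sum``: nanobind bindings with docstrings, an MPI test covering both new ops, and an extended ring test that now checks ``min``, ``max``, and ``sum`` against local reference reductions.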
@@ -117,7 +117,64 @@ void init_distributed(nb::module_& parent_module) {
         Returns:
           array: The sum of all ``x`` arrays.
       )pbdoc");
+  m.def(
+      "all_max",
+      [](const ScalarOrArray& x,
+         std::optional<mx::distributed::Group> group,
+         mx::StreamOrDevice s) {
+        return mx::distributed::all_max(to_array(x), group, s);
+      },
+      "x"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def all_max(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        All reduce max.
+
+        Find the maximum of the ``x`` arrays from all processes in the group.
+
+        Args:
+          x (array): Input array.
+          group (Group): The group of processes that will participate in the
+            reduction. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: The maximum of all ``x`` arrays.
+      )pbdoc");
+  m.def(
+      "all_min",
+      [](const ScalarOrArray& x,
+         std::optional<mx::distributed::Group> group,
+         mx::StreamOrDevice s) {
+        return mx::distributed::all_min(to_array(x), group, s);
+      },
+      "x"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def all_min(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        All reduce min.
+
+        Find the minimum of the ``x`` arrays from all processes in the group.
+
+        Args:
+          x (array): Input array.
+          group (Group): The group of processes that will participate in the
+            reduction. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: The minimum of all ``x`` arrays.
+      )pbdoc");
   m.def(
       "all_gather",
       [](const ScalarOrArray& x,
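For orientation, a minimal usage sketch of the two new bindings (a hypothetical script, assuming a distributed backend is available and the script is launched on several processes, e.g. via mpirun):

    import mlx.core as mx

    world = mx.distributed.init()
    x = mx.array([float(world.rank())])

    # Elementwise max / min of x across every process in the global group.
    hi = mx.distributed.all_max(x)
    lo = mx.distributed.all_min(x)
    mx.eval(hi, lo)

    # Every rank sees the same result: world.size() - 1 and 0.
    print(world.rank(), hi.item(), lo.item())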
@@ -124,6 +124,22 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
             x = mx.distributed.recv_like(x, neighbor, group=pairs)
             mx.eval(y, x)

+    def test_min_max(self):
+        world = mx.distributed.init()
+        base = mx.arange(16).reshape(4, 4)
+        x = base + world.rank() * 32
+
+        def _test_reduction(reduction: str = "all_max"):
+            target = base + ((world.size() - 1) * 32) * (reduction == "all_max")
+            reducer = getattr(mx.distributed, reduction)
+            y = reducer(x)
+            self.assertTrue(mx.allclose(y, target))
+
+        for reduction in ["all_max", "all_min"]:
+            _test_reduction(reduction)
+

 if __name__ == "__main__":
     unittest.main()
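The target arithmetic in this test follows directly from the setup: rank r holds base + 32 * r, so the elementwise maximum across ranks is base + 32 * (size - 1) while the minimum is base itself. A single-process sanity check of that reasoning (hypothetical, simulating the ranks with a stacked axis):

    import mlx.core as mx

    base = mx.arange(16).reshape(4, 4)
    size = 4  # pretend four ranks
    copies = mx.stack([base + r * 32 for r in range(size)])

    # max over the "rank" axis recovers the largest shifted copy,
    # min recovers the unshifted base.
    assert mx.array_equal(mx.max(copies, axis=0), base + 32 * (size - 1))
    assert mx.array_equal(mx.min(copies, axis=0), base)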
@@ -44,17 +44,23 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
             (1024, 1024),
         ]
         key = mx.random.key(0)
+        reductions = ["min", "max", "sum"]

         for dt, rtol in dtypes:
             for sh in sizes:
                 x = (
                     mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                 ).astype(dt)
-                y = mx.distributed.all_sum(x[world.rank()])
-                z = sum(
-                    x[i] for i in range(world.size())
-                )  # to ensure that we don't sum to int32
-                maxrelerror = ((y - z).abs() / z.abs()).max()
-                self.assertLessEqual(maxrelerror, rtol)
+                for reduction in reductions:
+                    reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
+                    y = reducer_distributed(x[world.rank()])
+
+                    reducer = getattr(mx, reduction)
+                    z = reducer(x, axis=0)
+                    mx.eval(y, z)
+
+                    maxrelerror = ((y - z).abs() / z.abs()).max()
+                    self.assertLessEqual(maxrelerror, rtol)

     def test_all_gather(self):
         world = mx.distributed.init()
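One detail worth noting about the ring test: min and max select one of the input values exactly, so for those reductions the relative-error check should come out at zero regardless of dtype; only sum actually accumulates rounding error, which is why the tolerance rtol is dtype-dependent.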