Min / max reductions (#2041)

This commit is contained in:
Anastasiia Filippova
2025-04-10 08:22:20 +02:00
committed by GitHub
parent 9ecefd56db
commit 515f104926
11 changed files with 276 additions and 27 deletions

View File

@@ -124,6 +124,22 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
x = mx.distributed.recv_like(x, neighbor, group=pairs)
mx.eval(y, x)
def test_min_max(self):
world = mx.distributed.init()
base = mx.arange(16).reshape(4, 4)
x = base + world.rank() * 32
def _test_reduction(reduction: str = "all_max"):
target = base + ((world.size() - 1) * 16) * (reduction == "max")
reducer = getattr(mx.distributed, reduction)
y = reducer(x)
self.assertTrue(mx.allclose(y, target))
for reduction in ["all_max", "all_min"]:
_test_reduction(reduction)
if __name__ == "__main__":
unittest.main()