mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:29:35 +08:00
Fix the test and add custom min/max reductions for uncommon MPI types (#2060)
This commit is contained in:

committed by
GitHub

parent
dfae2c6989
commit
ddaa4b7dcb
@@ -51,16 +51,25 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
x = (
|
||||
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
|
||||
).astype(dt)
|
||||
for reduction in reductions:
|
||||
reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
|
||||
y = reducer_distributed(x[world.rank()])
|
||||
|
||||
reducer = getattr(mx, reduction)
|
||||
z = reducer(x, axis=0)
|
||||
mx.eval(y, z)
|
||||
# All sum
|
||||
y = mx.distributed.all_sum(x[world.rank()])
|
||||
z = x.sum(0)
|
||||
maxrelerror = (y - z).abs()
|
||||
if rtol > 0:
|
||||
maxrelerror /= z.abs()
|
||||
maxrelerror = maxrelerror.max()
|
||||
self.assertLessEqual(maxrelerror, rtol)
|
||||
|
||||
maxrelerror = ((y - z).abs() / z.abs()).max()
|
||||
self.assertLessEqual(maxrelerror, rtol)
|
||||
# All max
|
||||
y = mx.distributed.all_max(x[world.rank()])
|
||||
z = x.max(0)
|
||||
self.assertTrue(mx.all(y == z))
|
||||
|
||||
# All min
|
||||
y = mx.distributed.all_min(x[world.rank()])
|
||||
z = x.min(0)
|
||||
self.assertTrue(mx.all(y == z))
|
||||
|
||||
def test_all_gather(self):
|
||||
world = mx.distributed.init()
|
||||
|
Reference in New Issue
Block a user