mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 01:19:21 +08:00
Nccl reduce scatter, all gather (#2727)
* Added reduce scatter and all gather for nccl
* fix unused import, delete unused file
* small fix
* deleted useless condition
* fixed comments
* fix bug in eval_gpu, renamed to sum_scatter, fix docs
* final fix docs
* remove and
* Update mlx/distributed/mpi/mpi.cpp

  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* fix broken set input output
* fixes set output
* typo
* fix typo
* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
committed by
GitHub
parent
761f901a41
commit
27778156dc
@@ -1,15 +1,16 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_distributed_tests
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
world = mx.distributed.init(strict=True, backend="mpi")
|
||||
_ = mx.distributed.init(strict=True, backend="mpi")
|
||||
cls.atol = 1e-6
|
||||
cls.rtol = 1e-4
|
||||
|
||||
def test_groups(self):
|
||||
world = mx.distributed.init()
|
||||
@@ -27,18 +28,11 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
sub = world.split(world.rank() // 2)
|
||||
self.assertEqual(sub.size(), 2)
|
||||
|
||||
def test_all_reduce(self):
|
||||
def test_all_reduce_extra(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
(mx.int8, 0),
|
||||
(mx.uint8, 0),
|
||||
(mx.int16, 0),
|
||||
(mx.uint16, 0),
|
||||
(mx.int32, 0),
|
||||
(mx.uint32, 0),
|
||||
(mx.float32, 1e-6),
|
||||
(mx.float16, 5e-3),
|
||||
(mx.bfloat16, 1e-1),
|
||||
(mx.complex64, 1e-6),
|
||||
]
|
||||
sizes = [
|
||||
@@ -76,16 +70,11 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
z = x.min(0)
|
||||
self.assertTrue(mx.all(y == z))
|
||||
|
||||
def test_all_gather(self):
|
||||
def test_all_gather_extra(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
mx.uint8,
|
||||
mx.int16,
|
||||
mx.uint16,
|
||||
mx.int32,
|
||||
mx.uint32,
|
||||
mx.float32,
|
||||
mx.complex64,
|
||||
]
|
||||
for dt in dtypes:
|
||||
@@ -150,4 +139,4 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
Reference in New Issue
Block a user