NCCL reduce scatter, all gather (#2727)
* Added reduce scatter and all gather for nccl
* fix unused import, delete unused file
* small fix
* deleted useless condition
* fixed comments
* fix bug in eval_gpu, renamed to sum_scatter, fix docs
* final fix docs
* remove and
* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output
* fixes set output
* typo
* fix typo
* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Committed by: GitHub
Parent: 761f901a41
Commit: 27778156dc
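Before the diff itself, a minimal usage sketch of the two collectives this change wires up for the NCCL backend. The launch mechanism, GPU count, and array shapes are illustrative assumptions, not part of the change.

# Assumes the script is started once per GPU by a distributed launcher
# (e.g. mlx.launch); the exact launch command is not part of this commit.
import mlx.core as mx

world = mx.distributed.init(backend="nccl")
rank, size = world.rank(), world.size()

x = mx.ones((8 * size, 16))

# all_gather: concatenate every rank's array along the first axis.
gathered = mx.distributed.all_gather(x)   # shape: (8 * size * size, 16)

# sum_scatter: sum across ranks, keep only this rank's shard of axis 0.
shard = mx.distributed.sum_scatter(x)     # shape: (8, 16)

mx.eval(gathered, shard)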
@@ -229,7 +229,7 @@ void init_distributed(nb::module_& parent_module) {
           x (array): Input array.
           dst (int): Rank of the destination process in the group.
           group (Group): The group of processes that will participate in the
-            sned. If set to ``None`` the global group is used. Default:
+            send. If set to ``None`` the global group is used. Default:
             ``None``.
           stream (Stream, optional): Stream or device. Defaults to ``None``
             in which case the default stream of the default device is used.
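The ``send`` docstring touched above pairs with a receive on the destination rank. A minimal point-to-point sketch; the two-process launch, the shapes, and the use of ``recv_like`` are assumptions made for illustration only.

import mlx.core as mx

world = mx.distributed.init()
x = mx.arange(4, dtype=mx.float32)

if world.rank() == 0:
    # Rank 0 sends to rank 1; the returned array must be evaluated to run.
    mx.eval(mx.distributed.send(x, dst=1))
elif world.rank() == 1:
    # Rank 1 receives an array with the same shape/dtype as x from rank 0.
    y = mx.distributed.recv_like(x, src=0)
    mx.eval(y)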
@@ -301,4 +301,36 @@ void init_distributed(nb::module_& parent_module) {
         Returns:
           array: The array that was received from ``src``.
       )pbdoc");
+
+  m.def(
+      "sum_scatter",
+      [](const ScalarOrArray& x,
+         std::optional<mx::distributed::Group> group,
+         mx::StreamOrDevice s) {
+        return mx::distributed::sum_scatter(to_array(x), group, s);
+      },
+      "x"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def sum_scatter(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Sum ``x`` across all processes in the group and shard the result
+        along the first axis across ranks. ``x.shape[0]`` must be divisible
+        by the group size.
+
+        The result is equivalent to
+        ``all_sum(x)[rank*chunk_size:(rank+1)*chunk_size]``, where
+        ``chunk_size = x.shape[0] // group.size()`` and ``rank`` is the rank
+        of this process in the group.
+        Note: ``all_sum`` is mentioned only for illustration; the actual
+        implementation does not perform ``all_sum`` and uses a single
+        reduce-scatter collective instead.
+        Currently supported only for the NCCL backend.
+
+        Args:
+          x (array): Input array.
+          group (Group): The group of processes that will participate in the
+            sum scatter. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: The output array with shape
+            ``[x.shape[0] // group.size(), *x.shape[1:]]``.
+      )pbdoc");
 }
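To make the documented equivalence concrete, a small Python sketch (the shapes and the multi-process NCCL launch are assumptions for illustration; the real op issues a single reduce-scatter rather than an all_sum plus slice):

import mlx.core as mx

world = mx.distributed.init(backend="nccl")
rank, size = world.rank(), world.size()

x = mx.ones((8 * size, 4))            # first axis divisible by the group size
y = mx.distributed.sum_scatter(x)     # this rank's shard, shape (8, 4)

# Reference result using the equivalence stated in the docstring.
chunk = x.shape[0] // size
z = mx.distributed.all_sum(x)[rank * chunk : (rank + 1) * chunk]
assert mx.allclose(y, z)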
@@ -1,7 +1,5 @@
 # Copyright © 2025 Apple Inc.
 
-import unittest
-
 import mlx.core as mx
 import mlx.nn as nn
 import mlx_tests
@@ -63,6 +61,48 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         finally:
             mx.distributed.all_sum = original_all_sum
 
+    def test_all_reduce(self):
+        g = mx.distributed.init()
+        dtypes = [
+            (mx.int8, 0),
+            (mx.uint8, 0),
+            (mx.int32, 0),
+            (mx.uint32, 0),
+            (mx.float32, 1e-6),
+            (mx.float16, 5e-3),
+            (mx.bfloat16, 1e-1),
+        ]
+        sizes = [
+            (7,),
+            (10,),
+            (1024,),
+            (1024, 1024),
+        ]
+        key = mx.random.key(0)
+
+        for dt, rtol in dtypes:
+            for sh in sizes:
+                x = (mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10).astype(dt)
+
+                # All sum
+                y = mx.distributed.all_sum(x[g.rank()], group=g)
+                z = x.sum(0)
+                maxrelerror = (y - z).abs()
+                if rtol > 0:
+                    maxrelerror /= z.abs()
+                maxrelerror = maxrelerror.max()
+                self.assertLessEqual(maxrelerror, rtol)
+
+                # All max
+                y = mx.distributed.all_max(x[g.rank()], group=g)
+                z = x.max(0)
+                self.assertTrue(mx.all(y == z))
+
+                # All min
+                y = mx.distributed.all_min(x[g.rank()], group=g)
+                z = x.min(0)
+                self.assertTrue(mx.all(y == z))
+
     def test_donation(self):
         x = mx.random.normal((1024,))
         mx.eval(x)
@@ -103,18 +143,19 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         y = lin(x)
         y1 = slin1(x)
         y2 = slin2(x[part])
-        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
-        self.assertTrue(mx.allclose(y[part], y1))
+        self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
+        self.assertTrue(mx.allclose(y[part], y1, atol=self.atol, rtol=self.rtol))
 
-        # And their quant versions
-        qlin = lin.to_quantized()
-        slin1 = shard_linear(qlin, "all-to-sharded")
-        slin2 = shard_linear(qlin, "sharded-to-all")
-        y = qlin(x)
-        y1 = slin1(x)
-        y2 = slin2(x[part])
-        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
-        self.assertTrue(mx.allclose(y[part], y1))
+        # And their quant versions (QuantizedMatmul is not supported on CUDA)
+        if not mx.cuda.is_available():
+            qlin = lin.to_quantized()
+            slin1 = shard_linear(qlin, "all-to-sharded")
+            slin2 = shard_linear(qlin, "sharded-to-all")
+            y = qlin(x)
+            y1 = slin1(x)
+            y2 = slin2(x[part])
+            self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
+            self.assertTrue(mx.allclose(y[part], y1))
 
         # Check the backward works as expected
         def dummy_loss(model, x, y):
@@ -197,12 +238,18 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         )
         self.assertTrue(
             mx.allclose(
-                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
+                g1["layers"][1]["bias"],
+                g2["layers"][1]["bias"],
+                atol=self.atol,
+                rtol=self.rtol,
             )
         )
         self.assertTrue(
             mx.allclose(
-                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
+                g1["layers"][3]["bias"],
+                g2["layers"][3]["bias"],
+                atol=self.atol,
+                rtol=self.rtol,
             )
         )
 
@@ -248,3 +295,20 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         y1 = mod(x)
         y2 = smod(x)
         self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
+
+    def test_all_gather(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.float16,
+            mx.bfloat16,
+        ]
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_gather(x)
+            self.assertEqual(y.shape, (world.size() * 2, 2, 4))
+            self.assertTrue(mx.all(y == 1))
@@ -1,15 +1,16 @@
 # Copyright © 2024 Apple Inc.
 
-import unittest
-
 import mlx.core as mx
 import mlx_distributed_tests
 import mlx_tests
 
 
 class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     @classmethod
     def setUpClass(cls):
-        world = mx.distributed.init(strict=True, backend="mpi")
+        _ = mx.distributed.init(strict=True, backend="mpi")
+        cls.atol = 1e-6
+        cls.rtol = 1e-4
 
     def test_groups(self):
         world = mx.distributed.init()
@@ -27,18 +28,11 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
         sub = world.split(world.rank() // 2)
         self.assertEqual(sub.size(), 2)
 
-    def test_all_reduce(self):
+    def test_all_reduce_extra(self):
         world = mx.distributed.init()
         dtypes = [
-            (mx.int8, 0),
-            (mx.uint8, 0),
             (mx.int16, 0),
             (mx.uint16, 0),
-            (mx.int32, 0),
-            (mx.uint32, 0),
-            (mx.float32, 1e-6),
-            (mx.float16, 5e-3),
-            (mx.bfloat16, 1e-1),
             (mx.complex64, 1e-6),
         ]
         sizes = [
@@ -76,16 +70,11 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
                 z = x.min(0)
                 self.assertTrue(mx.all(y == z))
 
-    def test_all_gather(self):
+    def test_all_gather_extra(self):
         world = mx.distributed.init()
         dtypes = [
-            mx.int8,
-            mx.uint8,
             mx.int16,
             mx.uint16,
-            mx.int32,
-            mx.uint32,
-            mx.float32,
             mx.complex64,
         ]
         for dt in dtypes:
@@ -150,4 +139,4 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()
@@ -1,283 +1,52 @@
 # Copyright © 2024 Apple Inc.
 
 import mlx.core as mx
-import mlx.nn as nn
+import mlx_distributed_tests
 import mlx_tests
-from mlx.nn.layers.distributed import shard_inplace, shard_linear
-from mlx.nn.utils import average_gradients
 
 
-class TestNCCLDistributed(mlx_tests.MLXTestCase):
+class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     @classmethod
     def setUpClass(cls):
-        world = mx.distributed.init(strict=True, backend="nccl")
-        rank = world.rank()
-        mx.set_default_device(mx.Device(mx.gpu, rank % 8))
+        _ = mx.distributed.init(strict=True, backend="nccl")
+        cls.atol = 1e-4
+        cls.rtol = 1e-4
 
-    def test_all_reduce(self):
+    def test_sum_scatter(self):
         world = mx.distributed.init()
 
         dtypes = [
             (mx.int8, 0),
             (mx.uint8, 0),
             (mx.int32, 0),
             (mx.uint32, 0),
             (mx.float32, 1e-6),
             (mx.float16, 5e-3),
             (mx.bfloat16, 1e-1),
         ]
         sizes = [
-            (7,),
-            (10,),
+            (8,),
+            (64,),
             (1024,),
             (1024, 1024),
         ]
-        key = mx.random.key(0)
+        key = mx.random.key(world.rank())
 
         for dt, rtol in dtypes:
             for sh in sizes:
-                x = (
-                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
-                ).astype(dt)
+                x = (mx.random.uniform(shape=sh, key=key) * 10).astype(dt)  # shape=sh
 
-                # All sum
-                y = mx.distributed.all_sum(x[world.rank()])
-                z = x.sum(0)
-                maxrelerror = (y - z).abs()
+                # Sum scatter
+                y = mx.distributed.sum_scatter(x)  # shape=sh/world.size()
+                z = mx.distributed.all_sum(x)  # shape=sh
+                chunk = sh[0] // world.size()
+                start = world.rank() * chunk
+                stop = start + chunk
+                z_ref = z[start:stop]
+
+                maxrelerror = (y - z_ref).abs()
                 if rtol > 0:
-                    maxrelerror /= z.abs()
+                    maxrelerror /= z_ref.abs()
                 maxrelerror = maxrelerror.max()
                 self.assertLessEqual(maxrelerror, rtol)
 
-    def test_average_gradients(self):
-        original_all_sum = mx.distributed.all_sum
-        n_calls = 0
-        xtype = None
-
-        def new_all_sum(x, **kwargs):
-            nonlocal n_calls
-            nonlocal xtype
-
-            n_calls += 1
-            if xtype is not None:
-                self.assertEqual(xtype, x.dtype)
-
-            return original_all_sum(x, **kwargs)
-
-        mx.distributed.all_sum = new_all_sum
-        try:
-            grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 1)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 10)
-
-            n_calls = 0
-            xtype = mx.float16
-            new_grads = average_gradients(
-                grads,
-                all_reduce_size=2 * 50,
-                communication_type=mx.float16,
-            )
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-        finally:
-            mx.distributed.all_sum = original_all_sum
-
-    def test_donation(self):
-        x = mx.random.normal((1024,))
-        mx.eval(x)
-        mx.synchronize()
-
-        mx.reset_peak_memory()
-        scale = mx.array(2.0)
-        y = mx.distributed.all_sum(x)
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_only = mx.get_peak_memory()
-        y = mx.distributed.all_sum(x) * scale
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_with_binary = mx.get_peak_memory()
-
-        self.assertEqual(all_sum_only, all_sum_with_binary)
-
-    def test_shard_linear(self):
-        # Seed the prng to have the same inputs and weights generated everywhere
-        mx.random.seed(0xF0F0F0F0)
-
-        # Prepare inputs
-        world = mx.distributed.init()
-        part = (
-            slice(None),
-            slice(
-                world.rank() * 1024 // world.size(),
-                (world.rank() + 1) * 1024 // world.size(),
-            ),
-        )
-        x = mx.random.normal((4, 1024))
-
-        # Create and shard some linear layers
-        lin = nn.Linear(1024, 1024, bias=True)
-        slin1 = shard_linear(lin, "all-to-sharded")
-        slin2 = shard_linear(lin, "sharded-to-all")
-        y = lin(x)
-        y1 = slin1(x)
-        y2 = slin2(x[part])
-        self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
-        self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
-
-        # Check the backward works as expected
-        def dummy_loss(model, x, y):
-            return (model(x) * y).sum()
-
-        mod = nn.Sequential(
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-        )
-        smod = nn.Sequential(
-            shard_linear(mod.layers[0], "all-to-sharded"),
-            shard_linear(mod.layers[1], "sharded-to-all"),
-            shard_linear(mod.layers[2], "all-to-sharded"),
-            shard_linear(mod.layers[3], "sharded-to-all"),
-        )
-
-        grad1 = nn.value_and_grad(mod, dummy_loss)
-        grad2 = nn.value_and_grad(smod, dummy_loss)
-
-        x = mx.random.normal((4, 128))
-        y = mx.random.normal((4, 128))
-
-        l1, g1 = grad1(mod, x, y)
-        l2, g2 = grad2(smod, x, y)
-        mx.eval(l1, g1, l2, g2)
-
-        part = slice(
-            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
-        )
-
-        self.assertTrue(mx.allclose(l1, l2))
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][0]["weight"][part],
-                g2["layers"][0]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][2]["weight"][part],
-                g2["layers"][2]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][1]["weight"][:, part],
-                g2["layers"][1]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][3]["weight"][:, part],
-                g2["layers"][3]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][0]["bias"][part],
-                g2["layers"][0]["bias"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][2]["bias"][part],
-                g2["layers"][2]["bias"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
-            )
-        )
-
-    def test_shard_predicate(self):
-        mx.random.seed(0xF0F0F0F0)
-
-        class MyConv(nn.Module):
-            def __init__(self, *args, **kwargs):
-                super().__init__()
-                self.aggregate = kwargs.pop("aggregate", False)
-                self.conv = nn.Conv2d(*args, **kwargs)
-
-            def __call__(self, x):
-                x = self.conv(x)
-                if self.aggregate:
-                    x = mx.distributed.all_sum(x)
-                return x
-
-        def sharding(path, weight):
-            parts = path.split(".")
-            even = int(parts[1]) % 2 == 0
-            if even:
-                return 0
-            else:
-                return -1 if parts[-1] != "bias" else None
-
-        mod = nn.Sequential(
-            MyConv(3, 128, kernel_size=3),
-            MyConv(128, 128, kernel_size=3),
-            MyConv(128, 128, kernel_size=3),
-            MyConv(128, 3, kernel_size=3),
-        )
-        smod = nn.Sequential(
-            MyConv(3, 128, kernel_size=3),
-            MyConv(128, 128, kernel_size=3, aggregate=True),
-            MyConv(128, 128, kernel_size=3),
-            MyConv(128, 3, kernel_size=3, aggregate=True),
-        )
-        smod.update(mod.parameters())
-        shard_inplace(smod, sharding)
-
-        x = mx.random.normal((4, 16, 16, 3))
-        y1 = mod(x)
-        y2 = smod(x)
-        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
 
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
@@ -1,7 +1,5 @@
 # Copyright © 2024 Apple Inc.
 
-import unittest
-
 import mlx.core as mx
 import mlx_distributed_tests
 import mlx_tests
@@ -10,7 +8,9 @@ import mlx_tests
 class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     @classmethod
     def setUpClass(cls):
-        world = mx.distributed.init(strict=True, backend="ring")
+        _ = mx.distributed.init(strict=True, backend="ring")
+        cls.atol = 1e-6
+        cls.rtol = 1e-4
 
     def test_groups(self):
         world = mx.distributed.init()
@@ -24,18 +24,11 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
         with self.assertRaises(RuntimeError):
             sub = world.split(world.rank() % 2)
 
-    def test_all_reduce(self):
+    def test_all_reduce_extra(self):
         world = mx.distributed.init()
         dtypes = [
-            (mx.int8, 0),
-            (mx.uint8, 0),
             (mx.int16, 0),
             (mx.uint16, 0),
-            (mx.int32, 0),
-            (mx.uint32, 0),
-            (mx.float32, 1e-6),
-            (mx.float16, 5e-3),
-            (mx.bfloat16, 1e-1),
             (mx.complex64, 1e-6),
         ]
         sizes = [
@@ -45,7 +38,6 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
             (1024, 1024),
         ]
         key = mx.random.key(0)
         reductions = ["min", "max", "sum"]
 
         for dt, rtol in dtypes:
             for sh in sizes:
@@ -72,16 +64,11 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
                 z = x.min(0)
                 self.assertTrue(mx.all(y == z))
 
-    def test_all_gather(self):
+    def test_all_gather_extra(self):
         world = mx.distributed.init()
         dtypes = [
-            mx.int8,
-            mx.uint8,
             mx.int16,
             mx.uint16,
-            mx.int32,
-            mx.uint32,
-            mx.float32,
             mx.complex64,
         ]
         for dt in dtypes: