# Copyright © 2025 Apple Inc.

import unittest

import mlx.core as mx
import mlx.nn as nn
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients


class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
    def test_average_gradients(self):
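        # Monkeypatch all_sum with a wrapper that counts how many reductions
        # are issued and, when `xtype` is set, checks the dtype used for the
        # communication.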
        original_all_sum = mx.distributed.all_sum
        n_calls = 0
        xtype = None

        def new_all_sum(x, **kwargs):
            nonlocal n_calls
            nonlocal xtype

            n_calls += 1
            if xtype is not None:
                self.assertEqual(xtype, x.dtype)

            return original_all_sum(x, **kwargs)

        mx.distributed.all_sum = new_all_sum

        try:
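            # With the default all_reduce_size, the ten small gradients are
            # averaged with a single aggregated all_sum.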
            grads = [mx.ones(10) for i in range(10)]
            new_grads = average_gradients(grads)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 1)

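            # Capping all_reduce_size (here 200 bytes) groups the ten 40-byte
            # gradients into two reductions instead of one.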
            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

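            # An all_reduce_size of 0 disables grouping entirely: one all_sum
            # per gradient.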
            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=0)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 10)

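            # Communication can happen in float16 while the averaged gradients
            # are still returned in float32.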
            n_calls = 0
            xtype = mx.float16
            new_grads = average_gradients(
                grads, all_reduce_size=2 * 50, communication_type=mx.float16
            )
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

        finally:
            mx.distributed.all_sum = original_all_sum

    def test_donation(self):
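        # The peak memory of all_sum followed by a binary op should match
        # all_sum alone, i.e. the all_sum output is donated to the multiply.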
        x = mx.random.normal((1024,))
        mx.eval(x)
        mx.synchronize()

        mx.reset_peak_memory()
        scale = mx.array(2.0)
        y = mx.distributed.all_sum(x)
        mx.eval(y)
        mx.synchronize()
        all_sum_only = mx.get_peak_memory()
        y = mx.distributed.all_sum(x) * scale
        mx.eval(y)
        mx.synchronize()
        all_sum_with_binary = mx.get_peak_memory()

        self.assertEqual(all_sum_only, all_sum_with_binary)

    def test_shard_linear(self):
        # Seed the prng to have the same inputs and weights generated everywhere
        mx.random.seed(0xF0F0F0F0)

        # Prepare inputs
        world = mx.distributed.init()
        part = (
            slice(None),
            slice(
                world.rank() * 1024 // world.size(),
                (world.rank() + 1) * 1024 // world.size(),
            ),
        )
        x = mx.random.normal((4, 1024))

        # Create and shard some linear layers
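        # "all-to-sharded" takes the full input and produces this rank's slice
        # of the output, while "sharded-to-all" takes the sharded input and
        # recovers the full output.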
        lin = nn.Linear(1024, 1024, bias=True)
        slin1 = shard_linear(lin, "all-to-sharded")
        slin2 = shard_linear(lin, "sharded-to-all")
        y = lin(x)
        y1 = slin1(x)
        y2 = slin2(x[part])
        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
        self.assertTrue(mx.allclose(y[part], y1))

        # And their quant versions
        qlin = lin.to_quantized()
        slin1 = shard_linear(qlin, "all-to-sharded")
        slin2 = shard_linear(qlin, "sharded-to-all")
        y = qlin(x)
        y1 = slin1(x)
        y2 = slin2(x[part])
        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
        self.assertTrue(mx.allclose(y[part], y1))

        # Check that the backward pass works as expected
        def dummy_loss(model, x, y):
            return (model(x) * y).sum()

        mod = nn.Sequential(
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
        )
        smod = nn.Sequential(
            shard_linear(mod.layers[0], "all-to-sharded"),
            shard_linear(mod.layers[1], "sharded-to-all"),
            shard_linear(mod.layers[2], "all-to-sharded"),
            shard_linear(mod.layers[3], "sharded-to-all"),
        )

        grad1 = nn.value_and_grad(mod, dummy_loss)
        grad2 = nn.value_and_grad(smod, dummy_loss)

        x = mx.random.normal((4, 128))
        y = mx.random.normal((4, 128))

        l1, g1 = grad1(mod, x, y)
        l2, g2 = grad2(smod, x, y)
        mx.eval(l1, g1, l2, g2)

        part = slice(
            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
        )
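        # The losses must match and the sharded gradients must equal the
        # corresponding slices of the reference gradients: rows for the
        # all-to-sharded layers (0 and 2), columns for the sharded-to-all
        # layers (1 and 3), whose biases are not sharded.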
        self.assertTrue(mx.allclose(l1, l2))
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["weight"][part],
                g2["layers"][0]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["weight"][part],
                g2["layers"][2]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["weight"][:, part],
                g2["layers"][1]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["weight"][:, part],
                g2["layers"][3]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["bias"][part],
                g2["layers"][0]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["bias"][part],
                g2["layers"][2]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
            )
        )

    def test_shard_predicate(self):
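        # Shard a small convolutional model in place using a custom predicate.
        # MyConv wraps a Conv2d and optionally all-reduces its output.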
        mx.random.seed(0xF0F0F0F0)

        class MyConv(nn.Module):
            def __init__(self, *args, **kwargs):
                super().__init__()
                self.aggregate = kwargs.pop("aggregate", False)
                self.conv = nn.Conv2d(*args, **kwargs)

            def __call__(self, x):
                x = self.conv(x)
                if self.aggregate:
                    x = mx.distributed.all_sum(x)
                return x

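        # The predicate maps a parameter path (e.g. "layers.1.conv.weight") to
        # how it is sharded: axis 0 for even layers, the last axis for odd
        # layers, and None to leave the odd-layer biases unsharded.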
        def sharding(path, weight):
            parts = path.split(".")
            even = int(parts[1]) % 2 == 0
            if even:
                return 0
            else:
                return -1 if parts[-1] != "bias" else None

        mod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3),
        )
        smod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3, aggregate=True),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3, aggregate=True),
        )
        smod.update(mod.parameters())
        shard_inplace(smod, sharding)

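        # The sharded model should reproduce the reference model's output.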
        x = mx.random.normal((4, 16, 16, 3))
        y1 = mod(x)
        y2 = smod(x)
        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))