MoE backward improvements (#2335)

This commit is contained in:
Angelos Katharopoulos
2025-07-07 17:59:53 -07:00
committed by GitHub
parent a4fcc893cd
commit 4a9b29a875
22 changed files with 1130 additions and 60 deletions

View File

@@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) {
array: The result of the multiplication of ``x`` with ``w``
after gathering using ``lhs_indices`` and ``rhs_indices``.
)pbdoc");
// Python binding for mx::segmented_mm: a matmul whose inner (K) dimension is
// split into segments given by offset pairs; one output slice per segment.
// The first two arguments are positional-only (unnamed nb::arg()), matching
// the `a, b, /` in the signature string below.
m.def(
"segmented_mm",
&mx::segmented_mm,
nb::arg(),
nb::arg(),
"segments"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform a matrix multiplication but segment the inner dimension and
save the result for each segment separately.
Args:
a (array): Input array of shape ``MxK``.
b (array): Input array of shape ``KxN``.
segments (array): The offsets into the inner dimension for each segment.
Returns:
array: The result per segment of shape ``MxN``.
)pbdoc");
m.def(
"tensordot",
[](const mx::array& a,

View File

@@ -8,6 +8,9 @@ cuda_skip = {
# Gather matmul NYI
"TestBlas.test_gather_matmul",
"TestBlas.test_gather_matmul_grad",
"TestBlas.test_gather_mm_sorted",
# Segmented matmul NYI
"TestBlas.test_segmented_mm",
# Scan NYI
"TestArray.test_api",
"TestAutograd.test_cumprod_grad",
@@ -76,6 +79,7 @@ cuda_skip = {
"TestQuantized.test_gather_matmul_grad",
"TestQuantized.test_gather_qmm",
"TestQuantized.test_gather_qmm_sorted",
"TestQuantized.test_gather_qmm_grad",
"TestQuantized.test_non_multiples",
"TestQuantized.test_qmm",
"TestQuantized.test_qmm_jvp",

View File

@@ -1163,6 +1163,99 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
def test_gather_mm_sorted(self):
    """gather_mm with sorted_indices=True matches a take-then-matmul
    reference, in the forward pass and in the vjp for both operands."""

    def reference(x, w, idx):
        # Materialize the gather, then use a plain batched matmul.
        return x @ w[idx]

    def fast_path(x, w, idx):
        return mx.gather_mm(x, w, rhs_indices=idx, sorted_indices=True)

    lhs = mx.random.normal((100, 1, 100))
    weights = mx.random.normal((8, 100, 100))
    idx = mx.sort(mx.random.randint(0, 8, shape=(100,)))

    # Forward agreement.
    ref_out = reference(lhs, weights, idx)
    fast_out = fast_path(lhs, weights, idx)
    self.assertTrue(mx.allclose(ref_out, fast_out, atol=1e-4))

    # Gradient agreement via vjp with a shared cotangent.
    cotan = mx.random.normal(ref_out.shape)
    ref_vals, ref_grads = mx.vjp(
        lambda x, w: reference(x, w, idx),
        [lhs, weights],
        [cotan],
    )
    fast_vals, fast_grads = mx.vjp(
        lambda x, w: fast_path(x, w, idx),
        [lhs, weights],
        [cotan],
    )
    self.assertTrue(mx.allclose(ref_vals[0], fast_vals[0], atol=1e-4))
    self.assertTrue(mx.allclose(ref_grads[0], fast_grads[0], atol=1e-4))
    self.assertTrue(mx.allclose(ref_grads[1], fast_grads[1], atol=1e-4))
def test_segmented_mm(self):
    """segmented_mm matches a slice-and-stack reference across transposed
    layouts, rejects batched matrix inputs, and accepts segment arrays
    with leading batch dimensions."""

    def reference(x, y, seg):
        # One plain matmul per [start, end) slice of the inner dimension.
        pieces = [x[:, lo:hi] @ y[lo:hi, :] for lo, hi in seg.tolist()]
        return mx.stack(pieces, axis=0)

    shapes = [
        (10, 10, 10),
        (10, 10, 1000),
        (1000, 1000, 1000),
    ]
    all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]

    for M, N, K in shapes:
        for fractions in all_segments:
            # Consecutive fraction pairs -> integer [start, end) offsets,
            # clamped into the valid index range of the inner dimension.
            seg = mx.array(
                [[lo, hi] for lo, hi in zip(fractions[:-1], fractions[1:])]
            )
            seg = mx.minimum(K - 1, (K * seg).astype(mx.uint32))

            # Plain layout.
            x = mx.random.normal((M, K))
            y = mx.random.normal((K, N))
            self.assertTrue(
                mx.allclose(
                    reference(x, y, seg), mx.segmented_mm(x, y, seg), atol=1e-4
                )
            )

            # Transposed lhs.
            x = mx.random.normal((K, M))
            y = mx.random.normal((K, N))
            self.assertTrue(
                mx.allclose(
                    reference(x.T, y, seg), mx.segmented_mm(x.T, y, seg), atol=1e-4
                )
            )

            # Transposed rhs.
            x = mx.random.normal((M, K))
            y = mx.random.normal((N, K))
            self.assertTrue(
                mx.allclose(
                    reference(x, y.T, seg), mx.segmented_mm(x, y.T, seg), atol=1e-4
                )
            )

            # Both transposed.
            x = mx.random.normal((K, M))
            y = mx.random.normal((N, K))
            self.assertTrue(
                mx.allclose(
                    reference(x.T, y.T, seg),
                    mx.segmented_mm(x.T, y.T, seg),
                    atol=1e-4,
                )
            )

    # Matrix inputs with ndim > 2 are rejected.
    with self.assertRaises(ValueError):
        bad = mx.ones((2, 10, 10))
        seg = mx.array([[0, 5], [5, 10]]).astype(mx.uint32)
        mx.segmented_mm(bad, bad, seg)

    # Segment arrays may carry leading batch dimensions; the output gains
    # the same leading dimensions.
    x = mx.ones((10, 1000))
    counts = mx.random.randint(0, 16, shape=(1000,))
    counts = mx.zeros(16, dtype=counts.dtype).at[counts].add(1)
    counts = mx.sort(counts)
    offsets = mx.cumsum(counts)
    offsets = mx.concatenate([mx.array([0]), offsets])
    # Overlapping stride-1 windows turn the offsets into [start, end) pairs.
    seg = mx.as_strided(offsets, (16, 2), (1, 1))
    seg = mx.reshape(seg, (2, 2, 4, 2))
    out = mx.segmented_mm(x, x.T, seg)
    self.assertEqual(out.shape, (2, 2, 4, 10, 10))
def test_gemv_gemm_same_precision(self):
mx.random.seed(0)
N = 256

View File

@@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
def test_gather_qmm_grad(self):
    """gather_qmm with sorted rhs indices matches a take-then-
    quantized_matmul reference, in the output and in the vjp w.r.t.
    the activations, scales, and biases."""

    def reference(x, w, s, b, lhs, rhs, trans, sort):
        # Materialize the gathers, then run the dense quantized matmul.
        if lhs is not None:
            x = x[lhs]
        if rhs is not None:
            w = w[rhs]
            s = s[rhs]
            b = b[rhs]
        return mx.quantized_matmul(x, w, s, b, transpose=trans)

    def fused(x, w, s, b, lhs, rhs, trans, sort):
        return mx.gather_qmm(
            x,
            w,
            s,
            b,
            transpose=trans,
            lhs_indices=lhs,
            rhs_indices=rhs,
            sorted_indices=sort,
        )

    x = mx.random.normal((16, 1, 256))
    w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
    indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
    cotan = mx.random.normal((16, 1, 256))

    ref_out, ref_grads = mx.vjp(
        lambda x, s, b: reference(x, w, s, b, None, indices, True, True),
        [x, s, b],
        [cotan],
    )
    fused_out, fused_grads = mx.vjp(
        lambda x, s, b: fused(x, w, s, b, None, indices, True, True),
        [x, s, b],
        [cotan],
    )

    self.assertTrue(mx.allclose(ref_out[0], fused_out[0], atol=1e-4))
    self.assertTrue(mx.allclose(ref_grads[0], fused_grads[0], atol=1e-4))
    # Scale/bias gradients accumulate more rounding; use a looser tolerance.
    self.assertTrue(mx.allclose(ref_grads[1], fused_grads[1], atol=1e-3))
    self.assertTrue(mx.allclose(ref_grads[2], fused_grads[2], atol=1e-3))
def test_vjp_scales_biases(self):
mx.random.seed(0)
x = mx.random.normal(shape=(2, 2, 512))