mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add rudimentary test for gather_mm with sorted indices
This commit is contained in:
@@ -1164,7 +1164,35 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||||
|
|
||||||
def test_gather_mm_sorted(self):
|
def test_gather_mm_sorted(self):
|
||||||
pass
|
def gather_mm_ref(a, b, rhs):
|
||||||
|
b = b[rhs]
|
||||||
|
return a @ b
|
||||||
|
|
||||||
|
def gather_mm_test(a, b, rhs):
|
||||||
|
return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True)
|
||||||
|
|
||||||
|
a = mx.random.normal((100, 1, 100))
|
||||||
|
b = mx.random.normal((8, 100, 100))
|
||||||
|
rhs = mx.sort(mx.random.randint(0, 8, shape=(100,)))
|
||||||
|
|
||||||
|
c1 = gather_mm_ref(a, b, rhs)
|
||||||
|
c2 = gather_mm_test(a, b, rhs)
|
||||||
|
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||||
|
|
||||||
|
cotan = mx.random.normal(c1.shape)
|
||||||
|
c1, dc1 = mx.vjp(
|
||||||
|
lambda a, b: gather_mm_ref(a, b, rhs),
|
||||||
|
[a, b],
|
||||||
|
[cotan],
|
||||||
|
)
|
||||||
|
c2, dc2 = mx.vjp(
|
||||||
|
lambda a, b: gather_mm_test(a, b, rhs),
|
||||||
|
[a, b],
|
||||||
|
[cotan],
|
||||||
|
)
|
||||||
|
self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4))
|
||||||
|
self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4))
|
||||||
|
self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4))
|
||||||
|
|
||||||
def test_segmented_mm(self):
|
def test_segmented_mm(self):
|
||||||
def segmented_mm_ref(a, b, s):
|
def segmented_mm_ref(a, b, s):
|
||||||
|
|||||||
Reference in New Issue
Block a user