mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add rudimentary test for gather_mm with sorted indices
This commit is contained in:
@@ -1164,7 +1164,35 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||
|
||||
def test_gather_mm_sorted(self):
|
||||
pass
|
||||
def gather_mm_ref(a, b, rhs):
|
||||
b = b[rhs]
|
||||
return a @ b
|
||||
|
||||
def gather_mm_test(a, b, rhs):
|
||||
return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True)
|
||||
|
||||
a = mx.random.normal((100, 1, 100))
|
||||
b = mx.random.normal((8, 100, 100))
|
||||
rhs = mx.sort(mx.random.randint(0, 8, shape=(100,)))
|
||||
|
||||
c1 = gather_mm_ref(a, b, rhs)
|
||||
c2 = gather_mm_test(a, b, rhs)
|
||||
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||
|
||||
cotan = mx.random.normal(c1.shape)
|
||||
c1, dc1 = mx.vjp(
|
||||
lambda a, b: gather_mm_ref(a, b, rhs),
|
||||
[a, b],
|
||||
[cotan],
|
||||
)
|
||||
c2, dc2 = mx.vjp(
|
||||
lambda a, b: gather_mm_test(a, b, rhs),
|
||||
[a, b],
|
||||
[cotan],
|
||||
)
|
||||
self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4))
|
||||
self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4))
|
||||
self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4))
|
||||
|
||||
def test_segmented_mm(self):
|
||||
def segmented_mm_ref(a, b, s):
|
||||
|
||||
Reference in New Issue
Block a user