Masked gemv (#1211)

This commit is contained in:
Jagrit Digani
2024-06-14 09:52:26 -07:00
committed by GitHub
parent fe3167d7ea
commit 2d6cd47713
6 changed files with 1729 additions and 421 deletions

View File

@@ -263,7 +263,9 @@ class TestBlas(mlx_tests.MLXTestCase):
mlx_mat_f=lambda x: x,
mlx_vec_f=lambda x: x,
):
with self.subTest(shape=shape_mat):
with self.subTest(
shape_mat=shape_mat, shape_vec=shape_vec, mat_first=mat_first
):
np.random.seed(42)
scale = max(np.sum(shape_mat), 32)
mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype)
@@ -794,10 +796,12 @@ class TestBlas(mlx_tests.MLXTestCase):
out_ref, dout_ref = mx.vjp(f_ref, [a, b], [cotan])
out_test, dout_test = mx.vjp(f_test, [a, b], [cotan])
mx.eval((out_ref, dout_ref, out_test, dout_test))
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
def run_test_mask_vjp(a, b, block_size, out_mask, a_mask, b_mask, cotan):
def f_ref(a_, b_, a_mask_, b_mask_):
return ref_block_masked_mm(
@@ -896,6 +900,8 @@ class TestBlas(mlx_tests.MLXTestCase):
(64, 64, 16, 32),
(128, 128, 128, 32),
(256, 256, 128, 64),
(1, 128, 128, 32),
(256, 1, 128, 64),
)
for M, N, K, block_size in shapes:
@@ -903,21 +909,51 @@ class TestBlas(mlx_tests.MLXTestCase):
# Test broadcasting
test_shape(64, 64, 64, 32, batch_A=(1, 2), batch_B=(2, 2))
test_shape(1, 128, 128, 32, batch_A=(1, 2), batch_B=(2, 2))
test_shape(128, 1, 128, 32, batch_A=(1, 2), batch_B=(2, 2))
# Test gemv
a_np = np.random.normal(size=(64, 64)).astype(np.float32)
b_np = np.random.normal(size=(64,)).astype(np.float32)
mask_np = np.array([True, False]).astype(np.bool_)
a_np = np.ones((128, 256)).astype(np.float32)
b_np = np.ones((128, 1)).astype(np.float32)
d_np = np.ones((1, 256)).astype(np.float32)
a_mask_np = np.random.normal(size=(4, 8)).astype(np.float32)
b_mask_np = np.ones((4, 1)).astype(np.bool_)
d_mask_np = np.ones((1, 8)).astype(np.bool_)
c_mask_np = np.random.normal(size=(8, 1)).astype(np.float32)
e_mask_np = np.random.normal(size=(1, 4)).astype(np.float32)
a_mask_np[a_mask_np < 0.0] = 0.0
e_mask_np[e_mask_np < 0.0] = 0.0
c_mask_np[c_mask_np < 0.0] = 0.0
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
mask_mx = mx.array(mask_np)
d_mx = mx.array(d_np)
a_mask_mx = mx.array(a_mask_np)
b_mask_mx = mx.array(b_mask_np)
d_mask_mx = mx.array(d_mask_np)
e_mask_mx = mx.array(e_mask_np)
c_mask_mx = mx.array(c_mask_np)
c_mx = mx.block_masked_mm(a_mx, b_mx, 32, mask_mx)
c_np = a_np @ b_np
c_np[32:] = 0.0
c_mx = mx.block_masked_mm(a_mx.T, b_mx, 32, c_mask_mx, a_mask_mx.T, b_mask_mx)
e_mx = mx.block_masked_mm(d_mx, a_mx.T, 32, e_mask_mx, d_mask_mx, a_mask_mx.T)
a_mask_np = np.broadcast_to(np.expand_dims(a_mask_np, (-3, -1)), (4, 32, 8, 32))
a_mask_np = a_mask_np.reshape((128, 256))
a_np *= a_mask_np
c_np = a_np.T @ b_np
e_np = d_np @ a_np.T
c_mask_np = np.broadcast_to(np.expand_dims(c_mask_np, (-2)), (8, 32, 1))
c_mask_np = c_mask_np.reshape((256, 1))
c_np *= c_mask_np
e_mask_np = np.broadcast_to(np.expand_dims(e_mask_np, (-1)), (1, 4, 32))
e_mask_np = e_mask_np.reshape((1, 128))
e_np *= e_mask_np
self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
self.assertTrue(np.allclose(e_mx, e_np, atol=1e-5))
def test_gather_matmul(self):
def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):