mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Masked gemv (#1211)
This commit is contained in:
@@ -263,7 +263,9 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
mlx_mat_f=lambda x: x,
|
||||
mlx_vec_f=lambda x: x,
|
||||
):
|
||||
with self.subTest(shape=shape_mat):
|
||||
with self.subTest(
|
||||
shape_mat=shape_mat, shape_vec=shape_vec, mat_first=mat_first
|
||||
):
|
||||
np.random.seed(42)
|
||||
scale = max(np.sum(shape_mat), 32)
|
||||
mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype)
|
||||
@@ -794,10 +796,12 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
out_ref, dout_ref = mx.vjp(f_ref, [a, b], [cotan])
|
||||
out_test, dout_test = mx.vjp(f_test, [a, b], [cotan])
|
||||
|
||||
mx.eval((out_ref, dout_ref, out_test, dout_test))
|
||||
|
||||
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
|
||||
|
||||
for r, t in zip(dout_ref, dout_test):
|
||||
self.assertEqual(r.shape, t.shape)
|
||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||
|
||||
def run_test_mask_vjp(a, b, block_size, out_mask, a_mask, b_mask, cotan):
|
||||
def f_ref(a_, b_, a_mask_, b_mask_):
|
||||
return ref_block_masked_mm(
|
||||
@@ -896,6 +900,8 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
(64, 64, 16, 32),
|
||||
(128, 128, 128, 32),
|
||||
(256, 256, 128, 64),
|
||||
(1, 128, 128, 32),
|
||||
(256, 1, 128, 64),
|
||||
)
|
||||
|
||||
for M, N, K, block_size in shapes:
|
||||
@@ -903,21 +909,51 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
|
||||
# Test broadcasting
|
||||
test_shape(64, 64, 64, 32, batch_A=(1, 2), batch_B=(2, 2))
|
||||
test_shape(1, 128, 128, 32, batch_A=(1, 2), batch_B=(2, 2))
|
||||
test_shape(128, 1, 128, 32, batch_A=(1, 2), batch_B=(2, 2))
|
||||
|
||||
# Test gemv
|
||||
a_np = np.random.normal(size=(64, 64)).astype(np.float32)
|
||||
b_np = np.random.normal(size=(64,)).astype(np.float32)
|
||||
mask_np = np.array([True, False]).astype(np.bool_)
|
||||
a_np = np.ones((128, 256)).astype(np.float32)
|
||||
b_np = np.ones((128, 1)).astype(np.float32)
|
||||
d_np = np.ones((1, 256)).astype(np.float32)
|
||||
a_mask_np = np.random.normal(size=(4, 8)).astype(np.float32)
|
||||
b_mask_np = np.ones((4, 1)).astype(np.bool_)
|
||||
d_mask_np = np.ones((1, 8)).astype(np.bool_)
|
||||
c_mask_np = np.random.normal(size=(8, 1)).astype(np.float32)
|
||||
e_mask_np = np.random.normal(size=(1, 4)).astype(np.float32)
|
||||
|
||||
a_mask_np[a_mask_np < 0.0] = 0.0
|
||||
e_mask_np[e_mask_np < 0.0] = 0.0
|
||||
c_mask_np[c_mask_np < 0.0] = 0.0
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
mask_mx = mx.array(mask_np)
|
||||
d_mx = mx.array(d_np)
|
||||
a_mask_mx = mx.array(a_mask_np)
|
||||
b_mask_mx = mx.array(b_mask_np)
|
||||
d_mask_mx = mx.array(d_mask_np)
|
||||
e_mask_mx = mx.array(e_mask_np)
|
||||
c_mask_mx = mx.array(c_mask_np)
|
||||
|
||||
c_mx = mx.block_masked_mm(a_mx, b_mx, 32, mask_mx)
|
||||
c_np = a_np @ b_np
|
||||
c_np[32:] = 0.0
|
||||
c_mx = mx.block_masked_mm(a_mx.T, b_mx, 32, c_mask_mx, a_mask_mx.T, b_mask_mx)
|
||||
e_mx = mx.block_masked_mm(d_mx, a_mx.T, 32, e_mask_mx, d_mask_mx, a_mask_mx.T)
|
||||
|
||||
a_mask_np = np.broadcast_to(np.expand_dims(a_mask_np, (-3, -1)), (4, 32, 8, 32))
|
||||
a_mask_np = a_mask_np.reshape((128, 256))
|
||||
a_np *= a_mask_np
|
||||
|
||||
c_np = a_np.T @ b_np
|
||||
e_np = d_np @ a_np.T
|
||||
|
||||
c_mask_np = np.broadcast_to(np.expand_dims(c_mask_np, (-2)), (8, 32, 1))
|
||||
c_mask_np = c_mask_np.reshape((256, 1))
|
||||
c_np *= c_mask_np
|
||||
|
||||
e_mask_np = np.broadcast_to(np.expand_dims(e_mask_np, (-1)), (1, 4, 32))
|
||||
e_mask_np = e_mask_np.reshape((1, 128))
|
||||
e_np *= e_mask_np
|
||||
|
||||
self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
|
||||
self.assertTrue(np.allclose(e_mx, e_np, atol=1e-5))
|
||||
|
||||
def test_gather_matmul(self):
|
||||
def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):
|
||||
|
Reference in New Issue
Block a user