mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:10:15 +08:00
Rename block sparse (#1149)
* block_sparse_mm to gather_mm
* rename
* nit
* nit
This commit is contained in:
@@ -408,9 +408,9 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
with self.subTest(
|
||||
B=B, # Batch size
|
||||
D=D, # Dimension of mm
|
||||
n_kv_heads=n_kv_heads, # key-value heads
|
||||
factor=factor, # factor to get query heads
|
||||
qsl=qsl, # Query sequence length
|
||||
n_kv_heads=n_kv_heads, # key-value heads
|
||||
factor=factor, # factor to get query heads
|
||||
qsl=qsl, # Query sequence length
|
||||
ksl=ksl, # Key sequence length
|
||||
dtype=dtype # Data type
|
||||
):
|
||||
@@ -432,22 +432,22 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype)
|
||||
v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype)
|
||||
|
||||
# Rearrange to move heads up
|
||||
# Rearrange to move heads up
|
||||
q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
|
||||
k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
|
||||
v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
|
||||
|
||||
# Do attn style matmul
|
||||
s_np = q_np_reshape @ k_np_reshape
|
||||
o_np = s_np @ v_np_reshape
|
||||
o_np = s_np @ v_np_reshape
|
||||
o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
|
||||
|
||||
# Test mlx
|
||||
# Test mlx
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
# Rearrange to move heads up
|
||||
# Rearrange to move heads up
|
||||
q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
|
||||
k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
|
||||
v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
|
||||
@@ -810,8 +810,8 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
|
||||
|
||||
def test_block_sparse_matmul(self):
|
||||
def np_block_sparse_mm(a, b, lhs_indices=None, rhs_indices=None):
|
||||
def test_gather_matmul(self):
|
||||
def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):
|
||||
a = a.reshape((-1, a.shape[-2], a.shape[-1]))
|
||||
b = b.reshape((-1, b.shape[-2], b.shape[-1]))
|
||||
lhs_indices = lhs_indices or np.arange(a.shape[0])
|
||||
@@ -848,12 +848,12 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)
|
||||
out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)
|
||||
|
||||
lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices)
|
||||
rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices)
|
||||
|
||||
out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
|
||||
out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
|
||||
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
|
||||
|
||||
@@ -920,7 +920,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
lhs_indices = [0, 13, 12]
|
||||
rhs_indices = [0, 3, 5]
|
||||
|
||||
out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)
|
||||
out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)
|
||||
|
||||
# MLX
|
||||
a_mx = a_mx.reshape((5, 1, 32, 32))
|
||||
@@ -932,7 +932,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
lhs_indices_mx = mx.array(lhs_indices)
|
||||
rhs_indices_mx = mx.array(rhs_indices)
|
||||
|
||||
out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
|
||||
out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
|
||||
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
|
||||
|
||||
@@ -946,17 +946,17 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
rhs_indices = [0, 2]
|
||||
|
||||
b_np_t = np.swapaxes(b_np, -1, -2)
|
||||
out_np = np_block_sparse_mm(a_np, b_np_t, lhs_indices, rhs_indices)
|
||||
out_np = np_gather_mm(a_np, b_np_t, lhs_indices, rhs_indices)
|
||||
|
||||
lhs_indices_mx = mx.array(lhs_indices)
|
||||
rhs_indices_mx = mx.array(rhs_indices)
|
||||
|
||||
b_mx_t = mx.swapaxes(b_mx, -1, -2)
|
||||
out_mx = mx.block_sparse_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)
|
||||
out_mx = mx.gather_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)
|
||||
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
|
||||
|
||||
def test_block_sparse_matmul_grad(self):
|
||||
def test_gather_matmul_grad(self):
|
||||
|
||||
lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
|
||||
rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
|
||||
@@ -977,7 +977,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
return a @ b
|
||||
|
||||
def f_test(a, b):
|
||||
return mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)
|
||||
return mx.gather_mm(a, b, lhs_indices, rhs_indices)
|
||||
|
||||
a_mx = mx.random.normal((4, 2, 32, 32))
|
||||
b_mx = mx.random.normal((4, 1, 32, 32))
|
||||
|
Reference in New Issue
Block a user