mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Rename block sparse (#1149)
* block_sparse_mm to gather_mm * rename * nit * nit
This commit is contained in:
		| @@ -408,9 +408,9 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|                 with self.subTest( | ||||
|                         B=B, # Batch size | ||||
|                         D=D, # Dimension of mm | ||||
|                         n_kv_heads=n_kv_heads, # key-value heads  | ||||
|                         factor=factor, # factor to get query heads  | ||||
|                         qsl=qsl, # Query sequence length  | ||||
|                         n_kv_heads=n_kv_heads, # key-value heads | ||||
|                         factor=factor, # factor to get query heads | ||||
|                         qsl=qsl, # Query sequence length | ||||
|                         ksl=ksl, # Key sequence length | ||||
|                         dtype=dtype # Data type | ||||
|                     ): | ||||
| @@ -432,22 +432,22 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|                     k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype) | ||||
|                     v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype) | ||||
|  | ||||
|                     # Rearrange to move heads up  | ||||
|                     # Rearrange to move heads up | ||||
|                     q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4) | ||||
|                     k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1) | ||||
|                     v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4) | ||||
|  | ||||
|                     # Do attn style matmul | ||||
|                     s_np = q_np_reshape @ k_np_reshape | ||||
|                     o_np = s_np @ v_np_reshape  | ||||
|                     o_np = s_np @ v_np_reshape | ||||
|                     o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1) | ||||
|  | ||||
|                     # Test mlx  | ||||
|                     # Test mlx | ||||
|                     q_mx = mx.array(q_np) | ||||
|                     k_mx = mx.array(k_np) | ||||
|                     v_mx = mx.array(v_np) | ||||
|  | ||||
|                     # Rearrange to move heads up  | ||||
|                     # Rearrange to move heads up | ||||
|                     q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4) | ||||
|                     k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1) | ||||
|                     v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4) | ||||
| @@ -810,8 +810,8 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5)) | ||||
|  | ||||
|     def test_block_sparse_matmul(self): | ||||
|         def np_block_sparse_mm(a, b, lhs_indices=None, rhs_indices=None): | ||||
|     def test_gather_matmul(self): | ||||
|         def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None): | ||||
|             a = a.reshape((-1, a.shape[-2], a.shape[-1])) | ||||
|             b = b.reshape((-1, b.shape[-2], b.shape[-1])) | ||||
|             lhs_indices = lhs_indices or np.arange(a.shape[0]) | ||||
| @@ -848,12 +848,12 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|                 a_mx = mx.array(a_np) | ||||
|                 b_mx = mx.array(b_np) | ||||
|  | ||||
|                 out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices) | ||||
|                 out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices) | ||||
|  | ||||
|                 lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices) | ||||
|                 rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices) | ||||
|  | ||||
|                 out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx) | ||||
|                 out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx) | ||||
|  | ||||
|                 self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) | ||||
|  | ||||
| @@ -920,7 +920,7 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|         lhs_indices = [0, 13, 12] | ||||
|         rhs_indices = [0, 3, 5] | ||||
|  | ||||
|         out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices) | ||||
|         out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices) | ||||
|  | ||||
|         # MLX | ||||
|         a_mx = a_mx.reshape((5, 1, 32, 32)) | ||||
| @@ -932,7 +932,7 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|         lhs_indices_mx = mx.array(lhs_indices) | ||||
|         rhs_indices_mx = mx.array(rhs_indices) | ||||
|  | ||||
|         out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx) | ||||
|         out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx) | ||||
|  | ||||
|         self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) | ||||
|  | ||||
| @@ -946,17 +946,17 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|         rhs_indices = [0, 2] | ||||
|  | ||||
|         b_np_t = np.swapaxes(b_np, -1, -2) | ||||
|         out_np = np_block_sparse_mm(a_np, b_np_t, lhs_indices, rhs_indices) | ||||
|         out_np = np_gather_mm(a_np, b_np_t, lhs_indices, rhs_indices) | ||||
|  | ||||
|         lhs_indices_mx = mx.array(lhs_indices) | ||||
|         rhs_indices_mx = mx.array(rhs_indices) | ||||
|  | ||||
|         b_mx_t = mx.swapaxes(b_mx, -1, -2) | ||||
|         out_mx = mx.block_sparse_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx) | ||||
|         out_mx = mx.gather_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx) | ||||
|  | ||||
|         self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) | ||||
|  | ||||
|     def test_block_sparse_matmul_grad(self): | ||||
|     def test_gather_matmul_grad(self): | ||||
|  | ||||
|         lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32) | ||||
|         rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32) | ||||
| @@ -977,7 +977,7 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|             return a @ b | ||||
|  | ||||
|         def f_test(a, b): | ||||
|             return mx.block_sparse_mm(a, b, lhs_indices, rhs_indices) | ||||
|             return mx.gather_mm(a, b, lhs_indices, rhs_indices) | ||||
|  | ||||
|         a_mx = mx.random.normal((4, 2, 32, 32)) | ||||
|         b_mx = mx.random.normal((4, 1, 32, 32)) | ||||
|   | ||||
| @@ -277,7 +277,7 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual(y_q.shape, y_hat.shape) | ||||
|         self.assertLess((y_q - y_hat).abs().max(), 1e-3) | ||||
|  | ||||
|     def test_block_sparse_qmm(self): | ||||
|     def test_gather_qmm(self): | ||||
|         def quantize(w, transpose=True, group_size=64, bits=4): | ||||
|             qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) | ||||
|             w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) | ||||
| @@ -322,8 +322,8 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|                 if rhs_indices is not None: | ||||
|                     rhs_indices = mx.array(rhs_indices) | ||||
|  | ||||
|                 c1 = mx.block_sparse_mm(x, w_hat, lhs_indices, rhs_indices) | ||||
|                 c2 = mx.block_sparse_qmm( | ||||
|                 c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices) | ||||
|                 c2 = mx.gather_qmm( | ||||
|                     x, | ||||
|                     qw, | ||||
|                     s, | ||||
| @@ -390,7 +390,7 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|             test_shape(32, 512, 32, transpose=False, **kwargs) | ||||
|             test_shape(1, 512, 32, transpose=False, **kwargs) | ||||
|  | ||||
|     def test_block_sparse_matmul_grad(self): | ||||
|     def test_gather_matmul_grad(self): | ||||
|         def quantize(w, transpose=True, group_size=64, bits=4): | ||||
|             qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) | ||||
|             w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) | ||||
| @@ -406,10 +406,10 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|         w_hat, qw, s, b = quantize(w) | ||||
|  | ||||
|         def f_ref(x, w, i1, i2): | ||||
|             return mx.block_sparse_mm(x, w, i1, i2).sum() | ||||
|             return mx.gather_mm(x, w, i1, i2).sum() | ||||
|  | ||||
|         def f_test(x, qw, s, b, i1, i2): | ||||
|             return mx.block_sparse_qmm(x, qw, s, b, i1, i2, transpose=True).sum() | ||||
|             return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum() | ||||
|  | ||||
|         r1 = f_ref(x, w_hat, lhs_indices, rhs_indices) | ||||
|         r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun