Rename block sparse (#1149)

* block_sparse_mm to gather_mm

* rename

* nit

* nit
This commit is contained in:
Awni Hannun
2024-05-22 07:48:34 -07:00
committed by GitHub
parent e6fecbb3e1
commit d568c7ee36
16 changed files with 120 additions and 111 deletions

View File

@@ -277,7 +277,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_block_sparse_qmm(self):
def test_gather_qmm(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
@@ -322,8 +322,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
if rhs_indices is not None:
rhs_indices = mx.array(rhs_indices)
c1 = mx.block_sparse_mm(x, w_hat, lhs_indices, rhs_indices)
c2 = mx.block_sparse_qmm(
c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)
c2 = mx.gather_qmm(
x,
qw,
s,
@@ -390,7 +390,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
test_shape(32, 512, 32, transpose=False, **kwargs)
test_shape(1, 512, 32, transpose=False, **kwargs)
def test_block_sparse_matmul_grad(self):
def test_gather_matmul_grad(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
@@ -406,10 +406,10 @@ class TestQuantized(mlx_tests.MLXTestCase):
w_hat, qw, s, b = quantize(w)
def f_ref(x, w, i1, i2):
return mx.block_sparse_mm(x, w, i1, i2).sum()
return mx.gather_mm(x, w, i1, i2).sum()
def f_test(x, qw, s, b, i1, i2):
return mx.block_sparse_qmm(x, qw, s, b, i1, i2, transpose=True).sum()
return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum()
r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)