Rename block sparse (#1149)

* block_sparse_mm to gather_mm

* rename

* nit

* nit
This commit is contained in:
Awni Hannun
2024-05-22 07:48:34 -07:00
committed by GitHub
parent e6fecbb3e1
commit d568c7ee36
16 changed files with 120 additions and 111 deletions

View File

@@ -40,6 +40,15 @@ double scalar_to_double(Scalar s) {
}
void init_ops(nb::module_& m) {
// TODO, remove deprecation errors in a future release
m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
throw std::invalid_argument(
"block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
});
m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
throw std::invalid_argument(
"block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
});
m.def(
"reshape",
&reshape,
@@ -3748,8 +3757,8 @@ void init_ops(nb::module_& m) {
array: The dequantized version of ``w``
)pbdoc");
m.def(
"block_sparse_qmm",
&block_sparse_qmm,
"gater_qmm",
&gather_qmm,
nb::arg(),
nb::arg(),
"scales"_a,
@@ -3762,12 +3771,12 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def block_sparse_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
This operation is the quantized equivalent to :func:`block_sparse_mm`.
Similar to :func:`block_sparse_mm`, the indices ``lhs_indices`` and
This operation is the quantized equivalent to :func:`gather_mm`.
Similar to :func:`gather_mm`, the indices ``lhs_indices`` and
``rhs_indices`` contain flat indices along the batch dimensions (i.e.
all but the last two dimensions) of ``x`` and ``w`` respectively.
@@ -3965,8 +3974,8 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"block_sparse_mm",
&block_sparse_mm,
"gather_mm",
&gather_mm,
nb::arg(),
nb::arg(),
"lhs_indices"_a = nb::none(),
@@ -3974,20 +3983,24 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def block_sparse_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`.
Performs a gather of the operands with the given indices followed by a
(possibly batched) matrix multiplication of two arrays. This operation
is more efficient than explicitly applying a :func:`take` followed by a
:func:`matmul`.
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices
along the batch dimensions (i.e. all but the last two dimensions) of
``a`` and ``b`` respectively.
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)``
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices``
contains indices from the range ``[0, A1 * A2 * ... * AS)``
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``,
``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)``
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
contains indices from the range ``[0, B1 * B2 * ... * BS)``
Args:
a (array): Input array.

View File

@@ -408,9 +408,9 @@ class TestBlas(mlx_tests.MLXTestCase):
with self.subTest(
B=B, # Batch size
D=D, # Dimension of mm
n_kv_heads=n_kv_heads, # key-value heads
factor=factor, # factor to get query heads
qsl=qsl, # Query sequence length
n_kv_heads=n_kv_heads, # key-value heads
factor=factor, # factor to get query heads
qsl=qsl, # Query sequence length
ksl=ksl, # Key sequence length
dtype=dtype # Data type
):
@@ -432,22 +432,22 @@ class TestBlas(mlx_tests.MLXTestCase):
k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype)
v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype)
# Rearrange to move heads up
# Rearrange to move heads up
q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
# Do attn style matmul
s_np = q_np_reshape @ k_np_reshape
o_np = s_np @ v_np_reshape
o_np = s_np @ v_np_reshape
o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
# Test mlx
# Test mlx
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
# Rearrange to move heads up
# Rearrange to move heads up
q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
@@ -810,8 +810,8 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
def test_block_sparse_matmul(self):
def np_block_sparse_mm(a, b, lhs_indices=None, rhs_indices=None):
def test_gather_matmul(self):
def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):
a = a.reshape((-1, a.shape[-2], a.shape[-1]))
b = b.reshape((-1, b.shape[-2], b.shape[-1]))
lhs_indices = lhs_indices or np.arange(a.shape[0])
@@ -848,12 +848,12 @@ class TestBlas(mlx_tests.MLXTestCase):
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)
out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)
lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices)
rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices)
out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
@@ -920,7 +920,7 @@ class TestBlas(mlx_tests.MLXTestCase):
lhs_indices = [0, 13, 12]
rhs_indices = [0, 3, 5]
out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)
out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)
# MLX
a_mx = a_mx.reshape((5, 1, 32, 32))
@@ -932,7 +932,7 @@ class TestBlas(mlx_tests.MLXTestCase):
lhs_indices_mx = mx.array(lhs_indices)
rhs_indices_mx = mx.array(rhs_indices)
out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
@@ -946,17 +946,17 @@ class TestBlas(mlx_tests.MLXTestCase):
rhs_indices = [0, 2]
b_np_t = np.swapaxes(b_np, -1, -2)
out_np = np_block_sparse_mm(a_np, b_np_t, lhs_indices, rhs_indices)
out_np = np_gather_mm(a_np, b_np_t, lhs_indices, rhs_indices)
lhs_indices_mx = mx.array(lhs_indices)
rhs_indices_mx = mx.array(rhs_indices)
b_mx_t = mx.swapaxes(b_mx, -1, -2)
out_mx = mx.block_sparse_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)
out_mx = mx.gather_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
def test_block_sparse_matmul_grad(self):
def test_gather_matmul_grad(self):
lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
@@ -977,7 +977,7 @@ class TestBlas(mlx_tests.MLXTestCase):
return a @ b
def f_test(a, b):
return mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)
return mx.gather_mm(a, b, lhs_indices, rhs_indices)
a_mx = mx.random.normal((4, 2, 32, 32))
b_mx = mx.random.normal((4, 1, 32, 32))

View File

@@ -277,7 +277,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_block_sparse_qmm(self):
def test_gather_qmm(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
@@ -322,8 +322,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
if rhs_indices is not None:
rhs_indices = mx.array(rhs_indices)
c1 = mx.block_sparse_mm(x, w_hat, lhs_indices, rhs_indices)
c2 = mx.block_sparse_qmm(
c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)
c2 = mx.gather_qmm(
x,
qw,
s,
@@ -390,7 +390,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
test_shape(32, 512, 32, transpose=False, **kwargs)
test_shape(1, 512, 32, transpose=False, **kwargs)
def test_block_sparse_matmul_grad(self):
def test_gather_matmul_grad(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
@@ -406,10 +406,10 @@ class TestQuantized(mlx_tests.MLXTestCase):
w_hat, qw, s, b = quantize(w)
def f_ref(x, w, i1, i2):
return mx.block_sparse_mm(x, w, i1, i2).sum()
return mx.gather_mm(x, w, i1, i2).sum()
def f_test(x, qw, s, b, i1, i2):
return mx.block_sparse_qmm(x, qw, s, b, i1, i2, transpose=True).sum()
return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum()
r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)