Gather qmm batched kernel and refactoring of quantized (#2078)

Angelos Katharopoulos
2025-04-17 13:53:11 -07:00
committed by GitHub
parent 99eefd2ec0
commit 5de6d94a90
15 changed files with 1479 additions and 449 deletions


@@ -4250,9 +4250,10 @@ void init_ops(nb::module_& m) {
"group_size"_a = 64,
"bits"_a = 4,
nb::kw_only(),
"sorted_indices"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
@@ -4265,23 +4266,25 @@ void init_ops(nb::module_& m) {
as ``w`` since they represent the same quantized matrix.
Args:
- x (array): Input array
- w (array): Quantized matrix packed in unsigned integers
- scales (array): The scales to use per ``group_size`` elements of ``w``
- biases (array): The biases to use per ``group_size`` elements of ``w``
- lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
- rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
- transpose (bool, optional): Defines whether to multiply with the
-     transposed ``w`` or not, namely whether we are performing
-     ``x @ w.T`` or ``x @ w``. Default: ``True``.
- group_size (int, optional): The size of the group in ``w`` that
-     shares a scale and bias. Default: ``64``.
- bits (int, optional): The number of bits occupied by each element in
-     ``w``. Default: ``4``.
+ x (array): Input array
+ w (array): Quantized matrix packed in unsigned integers
+ scales (array): The scales to use per ``group_size`` elements of ``w``
+ biases (array): The biases to use per ``group_size`` elements of ``w``
+ lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
+ rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
+ transpose (bool, optional): Defines whether to multiply with the
+     transposed ``w`` or not, namely whether we are performing
+     ``x @ w.T`` or ``x @ w``. Default: ``True``.
+ group_size (int, optional): The size of the group in ``w`` that
+     shares a scale and bias. Default: ``64``.
+ bits (int, optional): The number of bits occupied by each element in
+     ``w``. Default: ``4``.
+ sorted_indices (bool, optional): May allow a faster implementation
+     if the passed indices are sorted. Default: ``False``.
Returns:
- array: The result of the multiplication of ``x`` with ``w``
-     after gathering using ``lhs_indices`` and ``rhs_indices``.
+ array: The result of the multiplication of ``x`` with ``w``
+     after gathering using ``lhs_indices`` and ``rhs_indices``.
)pbdoc");
m.def(
"tensordot",
@@ -4311,16 +4314,16 @@ void init_ops(nb::module_& m) {
Compute the tensor dot product along the specified axes.
Args:
- a (array): Input array
- b (array): Input array
- axes (int or list(list(int)), optional): The number of dimensions to
-     sum over. If an integer is provided, then sum over the last
-     ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
-     ``b``. If a list of lists is provided, then sum over the
-     corresponding dimensions of ``a`` and ``b``. Default: 2.
+ a (array): Input array
+ b (array): Input array
+ axes (int or list(list(int)), optional): The number of dimensions to
+     sum over. If an integer is provided, then sum over the last
+     ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
+     ``b``. If a list of lists is provided, then sum over the
+     corresponding dimensions of ``a`` and ``b``. Default: 2.
Returns:
- array: The tensor dot product.
+ array: The tensor dot product.
)pbdoc");
m.def(
"inner",


@@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
tests = product(
[128, 64, 32], # group_size
[2, 3, 4, 6, 8], # bits
- [128, 256], # M
+ [32, 128, 256], # M
[128, 256, 67], # N
[0, 1, 3, 8], # B
)
for group_size, bits, M, N, B in tests:
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
+ if M < group_size:
+     continue
x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1)
@@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
)
for kwargs in inputs:
test_shape(1, 32, 128, **kwargs)
+ test_shape(32, 32, 256, **kwargs)
test_shape(1, 32, 256, **kwargs)
test_shape(32, 256, 32, transpose=False, **kwargs)
@@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase):
g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
+ def test_gather_qmm_sorted(self):
+     def quantize(w, transpose=True, group_size=64, bits=4):
+         qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
+         w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
+         if transpose:
+             w_hat = w_hat.swapaxes(-1, -2)
+         return w_hat, qw, s, b
+
+     def gather_sort(x, indices):
+         N, M = indices.shape
+         indices = indices.flatten()
+         order = mx.argsort(indices)
+         inv_order = mx.argsort(order)
+         return x.flatten(0, -3)[order // M], indices[order], inv_order
+
+     def scatter_unsort(x, inv_order, shape=None):
+         x = x[inv_order]
+         if shape is not None:
+             x = mx.unflatten(x, 0, shape)
+         return x
+
+     parameters = [
+         # L, K, D, E, I, transpose
+         (128, 1024, 1024, 32, 4, True),
+         (128, 1024, 544, 32, 4, True),
+         (433, 1024, 1024, 32, 4, True),
+         (433, 1024, 555, 32, 4, True),
+         (433, 2048, 1024, 32, 4, True),
+         (128, 1024, 1024, 32, 4, False),
+         (128, 1024, 544, 32, 4, False),
+         (433, 1024, 1024, 32, 4, False),
+         (433, 1024, 544, 32, 4, False),
+         (433, 1024, 555, 32, 4, False),
+         (433, 2048, 1024, 32, 4, False),
+     ]
+     for L, K, D, E, I, transpose in parameters:
+         K, D = (K, D) if transpose else (D, K)
+         ishape = (L, I)
+         xshape = (L, 1, 1, K)
+         wshape = (E, D, K) if transpose else (E, K, D)
+
+         indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
+         x = mx.random.normal(xshape) / K**0.5
+         w = mx.random.normal(wshape) / K**0.5
+         w, *wq = quantize(w, transpose=transpose)
+
+         y1 = mx.gather_mm(x, w, rhs_indices=indices)
+         y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
+
+         xs, idx, inv_order = gather_sort(x, indices)
+         y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
+         y4 = mx.gather_qmm(
+             xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
+         )
+         y3 = scatter_unsort(y3, inv_order, indices.shape)
+         y4 = scatter_unsort(y4, inv_order, indices.shape)
+
+         self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
+         self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
+         self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
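The ``gather_sort``/``scatter_unsort`` helpers above implement the standard sort/unsort pattern: sorting tokens by expert index makes rows routed to the same expert contiguous, which is the contract behind ``sorted_indices=True``. A standalone sketch of the round trip, with illustrative shapes:

    import mlx.core as mx

    indices = (mx.random.uniform(shape=(6,)) * 4).astype(mx.uint32)
    order = mx.argsort(indices)    # permutation into expert-sorted order
    inv_order = mx.argsort(order)  # inverse permutation

    x = mx.random.normal(shape=(6, 8))
    xs = x[order]                  # gather: same-expert rows now contiguous
    # ... run the matmul on xs with indices[order], sorted_indices=True ...
    restored = xs[inv_order]       # scatter back to original token order

    assert mx.allclose(restored, x).item()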
if __name__ == "__main__":
unittest.main()