mirror of https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00

commit e78a6518fa
parent 1873ffda01
Block sparse qmm (#1124)
committed by GitHub
@@ -4,6 +4,7 @@ import math
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.quantized import QuantizedEmbedding
 
 
 class Embedding(Module):
@@ -37,3 +38,7 @@ class Embedding(Module):
         weights are tied.
         """
         return x @ self.weight.T
+
+    def to_quantized(self, group_size: int = 64, bits: int = 4):
+        """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
+        return QuantizedEmbedding.from_embedding(self, group_size, bits)
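For illustration, the new hook on Embedding can be called directly. A minimal sketch (layer sizes are arbitrary and not part of this commit):

    import mlx.core as mx
    import mlx.nn as nn

    emb = nn.Embedding(1024, 512)           # vocabulary of 1024, 512-dim embeddings
    qemb = emb.to_quantized(group_size=64, bits=4)
    # to_quantized delegates to QuantizedEmbedding.from_embedding, so the
    # result is a drop-in replacement for the original layer.
    print(type(qemb).__name__)              # QuantizedEmbedding
    print(qemb(mx.array([1, 2, 3])).shape)  # (3, 512)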
@@ -5,6 +5,7 @@ from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.quantized import QuantizedLinear
 
 
 class Identity(Module):
@@ -69,6 +70,10 @@ class Linear(Module):
             x = x @ self["weight"].T
         return x
 
+    def to_quantized(self, group_size: int = 64, bits: int = 4):
+        """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
+        return QuantizedLinear.from_linear(self, group_size, bits)
+
 
 class Bilinear(Module):
     r"""Applies a bilinear transformation to the inputs.
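Linear gains the same hook; a quick sketch of the round trip (shapes arbitrary):

    import mlx.core as mx
    import mlx.nn as nn

    lin = nn.Linear(256, 128)
    qlin = lin.to_quantized(group_size=64, bits=4)
    x = mx.random.normal((4, 256))
    print(type(qlin).__name__)  # QuantizedLinear
    print(qlin(x).shape)        # (4, 128), approximating lin(x)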
@@ -5,8 +5,6 @@ from typing import Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
-from mlx.nn.layers.embedding import Embedding
-from mlx.nn.layers.linear import Linear
 from mlx.utils import tree_map_with_path
 
 
@@ -18,8 +16,9 @@ def quantize(
 ):
     """Quantize the sub-modules of a module according to a predicate.
 
-    By default all :obj:`Linear` and :obj:`Embedding` layers will be
-    quantized. Note also, the module is updated in-place.
+    By default all layers that define a ``to_quantized(group_size, bits)``
+    method will be quantized. Both :obj:`Linear` and :obj:`Embedding` layers
+    will be quantized. Note also, the module is updated in-place.
 
     Args:
         model (mlx.nn.Module): The model whose leaf modules may be quantized.
@@ -30,18 +29,15 @@ def quantize(
         class_predicate (Optional[Callable]): A callable which receives the
             :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
             it should be quantized and ``False`` otherwise. If ``None``, then
-            all linear and embedding layers are quantized. Default: ``None``.
+            all layers that define a ``to_quantized(group_size, bits)`` method
+            are quantized. Default: ``None``.
     """
-    class_predicate = class_predicate or (
-        lambda _, m: isinstance(m, (Linear, Embedding))
-    )
+    class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
 
     def _maybe_quantize(path, m):
         if class_predicate(path, m):
-            if isinstance(m, Linear):
-                return QuantizedLinear.from_linear(m, group_size, bits)
-            elif isinstance(m, Embedding):
-                return QuantizedEmbedding.from_embedding(m, group_size, bits)
+            if hasattr(m, "to_quantized"):
+                return m.to_quantized(group_size, bits)
             else:
                 raise ValueError(f"Unable to quantize model of type {type(m)}")
         else:
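Since the default predicate is now a duck-typed hasattr check, nn.quantize picks up any leaf layer that implements to_quantized and leaves the rest alone. A small sketch of the resulting behavior (layers arbitrary; not from the commit):

    import mlx.nn as nn

    model = nn.Sequential(nn.Embedding(100, 64), nn.Linear(64, 64), nn.GELU())
    nn.quantize(model, group_size=64, bits=4)  # updates the model in-place
    # Hook-bearing layers are swapped; GELU has no to_quantized and is skipped.
    print([type(l).__name__ for l in model.layers])
    # ['QuantizedEmbedding', 'QuantizedLinear', 'GELU']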
@@ -129,7 +125,7 @@ class QuantizedEmbedding(Module):
 
     @classmethod
     def from_embedding(
-        cls, embedding_layer: Embedding, group_size: int = 64, bits: int = 4
+        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
@@ -220,7 +216,7 @@ class QuantizedLinear(Module):
         return x
 
     @classmethod
-    def from_linear(cls, linear_layer: Linear, group_size: int = 64, bits: int = 4):
+    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits)
@@ -3747,6 +3747,52 @@ void init_ops(nb::module_& m) {
         Returns:
             result (array): The dequantized version of ``w``
       )pbdoc");
+  m.def(
+      "block_sparse_qmm",
+      &block_sparse_qmm,
+      nb::arg(),
+      nb::arg(),
+      "scales"_a,
+      "biases"_a,
+      "lhs_indices"_a = nb::none(),
+      "rhs_indices"_a = nb::none(),
+      "transpose"_a = true,
+      "group_size"_a = 64,
+      "bits"_a = 4,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def block_sparse_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Perform quantized matrix multiplication with matrix-level gather.
+
+        This operation is the quantized equivalent to :func:`block_sparse_mm`.
+        Similar to :func:`block_sparse_mm`, the indices ``lhs_indices`` and
+        ``rhs_indices`` contain flat indices along the batch dimensions (i.e.
+        all but the last two dimensions) of ``x`` and ``w`` respectively.
+
+        Note that ``scales`` and ``biases`` must have the same batch dimensions
+        as ``w`` since they represent the same quantized matrix.
+
+        Args:
+            x (array): Input array
+            w (array): Quantized matrix packed in unsigned integers
+            scales (array): The scales to use per ``group_size`` elements of ``w``
+            biases (array): The biases to use per ``group_size`` elements of ``w``
+            lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``)
+            rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``)
+            transpose (bool, optional): Defines whether to multiply with the
+              transposed ``w`` or not, namely whether we are performing
+              ``x @ w.T`` or ``x @ w``. (default: ``True``)
+            group_size (int, optional): The size of the group in ``w`` that
+              shares a scale and bias. (default: ``64``)
+            bits (int, optional): The number of bits occupied by each element in
+              ``w``. (default: ``4``)
+
+        Returns:
+            result (array): The result of the multiplication of ``x`` with ``w``
+              after gathering using ``lhs_indices`` and ``rhs_indices``.
+      )pbdoc");
   m.def(
       "tensordot",
       [](const array& a,
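Following the pattern in the tests below, a usage sketch for the new op (dimensions arbitrary; group_size must divide K):

    import mlx.core as mx

    x = mx.random.normal((5, 4, 64))   # 5 input matrices of shape (M=4, K=64)
    w = mx.random.normal((3, 32, 64))  # 3 weight matrices of shape (N=32, K=64)
    qw, scales, biases = mx.quantize(w, group_size=64, bits=4)

    lhs_indices = mx.array([0, 2])     # flat batch indices into x
    rhs_indices = mx.array([2, 1])     # flat batch indices into w
    out = mx.block_sparse_qmm(
        x, qw, scales, biases, lhs_indices, rhs_indices,
        transpose=True, group_size=64, bits=4,
    )
    print(out.shape)  # (2, 4, 32): one (M, N) product per gathered index pair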
@@ -3933,7 +3979,7 @@ void init_ops(nb::module_& m) {
         Matrix multiplication with matrix-level gather.
 
         Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
-        This operation is more efficient than explicitly applying a :func:``take`` followed by a :func:``matmul``.
+        This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`.
 
         The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
 
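The flat-index convention reads like this in practice (a sketch; shapes arbitrary):

    import mlx.core as mx

    a = mx.random.normal((2, 3, 8, 16))  # batch dims (2, 3) flatten to 6 matrices
    b = mx.random.normal((6, 16, 4))
    lhs_indices = mx.array([5, 0])       # flat index 5 addresses a[1, 2] (row-major)
    rhs_indices = mx.array([1, 4])
    c = mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)
    print(c.shape)  # (2, 8, 4)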
@@ -277,6 +277,148 @@ class TestQuantized(mlx_tests.MLXTestCase):
             self.assertEqual(y_q.shape, y_hat.shape)
             self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
+    def test_block_sparse_qmm(self):
+        def quantize(w, transpose=True, group_size=64, bits=4):
+            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
+            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
+            if transpose:
+                w_hat = w_hat.swapaxes(-1, -2)
+            return w_hat, qw, s, b
+
+        def test_shape(
+            M,
+            N,
+            K,
+            dtype=mx.float32,
+            batch_A=(),
+            batch_B=(),
+            lhs_indices=None,
+            rhs_indices=None,
+            transpose=True,
+            group_size=64,
+            bits=4,
+        ):
+            with self.subTest(
+                M=M,
+                N=N,
+                K=K,
+                dtype=dtype,
+                batch_A=batch_A,
+                batch_B=batch_B,
+                lhs_indices=lhs_indices,
+                rhs_indices=rhs_indices,
+                transpose=transpose,
+                group_size=group_size,
+                bits=bits,
+            ):
+                x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
+                w = mx.random.normal(
+                    shape=batch_B + ((N, K) if transpose else (K, N))
+                ).astype(dtype)
+                w_hat, qw, s, b = quantize(w, transpose, group_size, bits)
+
+                if lhs_indices is not None:
+                    lhs_indices = mx.array(lhs_indices)
+                if rhs_indices is not None:
+                    rhs_indices = mx.array(rhs_indices)
+
+                c1 = mx.block_sparse_mm(x, w_hat, lhs_indices, rhs_indices)
+                c2 = mx.block_sparse_qmm(
+                    x,
+                    qw,
+                    s,
+                    b,
+                    lhs_indices,
+                    rhs_indices,
+                    transpose=transpose,
+                    group_size=group_size,
+                    bits=bits,
+                )
+
+                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
+
+        inputs = (
+            {
+                "batch_A": (1,),
+                "lhs_indices": (0,),
+                "batch_B": (3,),
+                "rhs_indices": (2, 1),
+            },
+            {
+                "batch_A": (1,),
+                "lhs_indices": None,
+                "batch_B": (3,),
+                "rhs_indices": (2, 1),
+            },
+            {
+                "batch_A": (2,),
+                "lhs_indices": None,
+                "batch_B": (3,),
+                "rhs_indices": (2, 1),
+            },
+            {
+                "batch_A": (3,),
+                "lhs_indices": (0, 2),
+                "batch_B": (1,),
+                "rhs_indices": (0,),
+            },
+            {
+                "batch_A": (5,),
+                "lhs_indices": (0, 2),
+                "batch_B": (3,),
+                "rhs_indices": (2, 1),
+            },
+            {
+                "batch_A": (4, 2),
+                "lhs_indices": (
+                    (7, 6),
+                    (5, 4),
+                    (1, 2),
+                ),
+                "batch_B": (4, 1),
+                "rhs_indices": ((2,), (0,), (1,)),
+            },
+        )
+
+        for kwargs in inputs:
+            test_shape(32, 32, 256, **kwargs)
+            test_shape(1, 32, 256, **kwargs)
+            test_shape(32, 256, 32, transpose=False, **kwargs)
+            test_shape(1, 256, 32, transpose=False, **kwargs)
+            test_shape(32, 32, 512, **kwargs)
+            test_shape(1, 32, 512, **kwargs)
+            test_shape(32, 512, 32, transpose=False, **kwargs)
+            test_shape(1, 512, 32, transpose=False, **kwargs)
+
+    def test_block_sparse_matmul_grad(self):
+        def quantize(w, transpose=True, group_size=64, bits=4):
+            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
+            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
+            if transpose:
+                w_hat = w_hat.swapaxes(-1, -2)
+            return w_hat, qw, s, b
+
+        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
+        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
+
+        x = mx.random.normal((4, 2, 32, 256))
+        w = mx.random.normal((4, 1, 32, 256))
+        w_hat, qw, s, b = quantize(w)
+
+        def f_ref(x, w, i1, i2):
+            return mx.block_sparse_mm(x, w, i1, i2).sum()
+
+        def f_test(x, qw, s, b, i1, i2):
+            return mx.block_sparse_qmm(x, qw, s, b, i1, i2, transpose=True).sum()
+
+        r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
+        r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)
+        self.assertTrue(mx.allclose(r1, r2, atol=1e-4))
+
+        g1 = mx.grad(f_ref)(x, w_hat, lhs_indices, rhs_indices)
+        g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
+        self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
+
 
 if __name__ == "__main__":
     unittest.main()