Block sparse qmm (#1124)

Angelos Katharopoulos
2024-05-16 15:24:14 -07:00
committed by GitHub
parent 1873ffda01
commit e78a6518fa
15 changed files with 1724 additions and 164 deletions

python/mlx/nn/layers/embedding.py

@@ -4,6 +4,7 @@ import math
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.quantized import QuantizedEmbedding
 
 
 class Embedding(Module):
@@ -37,3 +38,7 @@ class Embedding(Module):
         weights are tied.
         """
         return x @ self.weight.T
+
+    def to_quantized(self, group_size: int = 64, bits: int = 4):
+        """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
+        return QuantizedEmbedding.from_embedding(self, group_size, bits)
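
For context, the new hook makes the conversion a one-liner on the layer itself. A minimal usage sketch, assuming the post-commit API; the vocabulary and feature sizes are illustrative, not from the diff:

    import mlx.nn as nn

    emb = nn.Embedding(1000, 512)

    # Equivalent to QuantizedEmbedding.from_embedding(emb, 64, 4), but the
    # call site no longer needs to import the quantized layer explicitly.
    qemb = emb.to_quantized(group_size=64, bits=4)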

python/mlx/nn/layers/linear.py

@@ -5,6 +5,7 @@ from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.quantized import QuantizedLinear
 
 
 class Identity(Module):
@@ -69,6 +70,10 @@ class Linear(Module):
             x = x @ self["weight"].T
         return x
 
+    def to_quantized(self, group_size: int = 64, bits: int = 4):
+        """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
+        return QuantizedLinear.from_linear(self, group_size, bits)
+
 
 class Bilinear(Module):
     r"""Applies a bilinear transformation to the inputs.
python/mlx/nn/layers/quantized.py

@@ -5,8 +5,6 @@ from typing import Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
-from mlx.nn.layers.embedding import Embedding
-from mlx.nn.layers.linear import Linear
 from mlx.utils import tree_map_with_path
@@ -18,8 +16,9 @@ def quantize(
 ):
     """Quantize the sub-modules of a module according to a predicate.
 
-    By default all :obj:`Linear` and :obj:`Embedding` layers will be
-    quantized. Note also, the module is updated in-place.
+    By default all layers that define a ``to_quantized(group_size, bits)``
+    method will be quantized. Both :obj:`Linear` and :obj:`Embedding` layers
+    will be quantized. Note also, the module is updated in-place.
 
     Args:
         model (mlx.nn.Module): The model whose leaf modules may be quantized.
@@ -30,18 +29,15 @@ def quantize(
         class_predicate (Optional[Callable]): A callable which receives the
             :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
             it should be quantized and ``False`` otherwise. If ``None``, then
-            all linear and embedding layers are quantized. Default: ``None``.
+            all layers that define a ``to_quantized(group_size, bits)`` method
+            are quantized. Default: ``None``.
     """
-    class_predicate = class_predicate or (
-        lambda _, m: isinstance(m, (Linear, Embedding))
-    )
+    class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
 
     def _maybe_quantize(path, m):
         if class_predicate(path, m):
-            if isinstance(m, Linear):
-                return QuantizedLinear.from_linear(m, group_size, bits)
-            elif isinstance(m, Embedding):
-                return QuantizedEmbedding.from_embedding(m, group_size, bits)
+            if hasattr(m, "to_quantized"):
+                return m.to_quantized(group_size, bits)
             else:
                 raise ValueError(f"Unable to quantize model of type {type(m)}")
         else:
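
With dispatch now based on ``hasattr(m, "to_quantized")``, custom layers can opt into ``nn.quantize`` without touching ``class_predicate``. A minimal sketch under that assumption; ``MyLinear`` is a hypothetical layer, and delegating to :obj:`QuantizedLinear` is just one possible implementation of the hook:

    import mlx.core as mx
    import mlx.nn as nn

    class MyLinear(nn.Module):  # hypothetical custom layer, not in the commit
        def __init__(self, in_dims: int, out_dims: int):
            super().__init__()
            self.weight = mx.random.normal((out_dims, in_dims))

        def __call__(self, x):
            return x @ self.weight.T

        def to_quantized(self, group_size: int = 64, bits: int = 4):
            # Reuse the stock quantized layer; any Module-returning
            # conversion would work here.
            return nn.QuantizedLinear.from_linear(self, group_size, bits)

    model = nn.Sequential(MyLinear(512, 512), nn.ReLU(), MyLinear(512, 512))
    # Replaces both MyLinear instances in-place; ReLU has no
    # to_quantized method, so the default predicate skips it.
    nn.quantize(model)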
@@ -129,7 +125,7 @@ class QuantizedEmbedding(Module):
 
     @classmethod
     def from_embedding(
-        cls, embedding_layer: Embedding, group_size: int = 64, bits: int = 4
+        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
@@ -220,7 +216,7 @@ class QuantizedLinear(Module):
         return x
 
     @classmethod
-    def from_linear(cls, linear_layer: Linear, group_size: int = 64, bits: int = 4):
+    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits)
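
Loosening the annotation from :obj:`Linear` to :obj:`Module` matches the new duck-typed dispatch: judging from the body shown above, any module exposing a compatible ``weight`` (and optional ``bias``) should convert. A speculative sketch; ``WeightOnly`` is a hypothetical module, not part of the commit:

    import mlx.core as mx
    import mlx.nn as nn

    class WeightOnly(nn.Module):  # hypothetical weight-only module
        def __init__(self):
            super().__init__()
            self.weight = mx.random.normal((256, 128))

    # from_linear only reads `weight` (and `bias` when present), so a
    # non-Linear module with matching attributes converts as well.
    ql = nn.QuantizedLinear.from_linear(WeightOnly(), group_size=64, bits=4)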