Block sparse qmm (#1124)
parent 1873ffda01
commit e78a6518fa
mlx.nn.layers.embedding:

@@ -4,6 +4,7 @@ import math
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.quantized import QuantizedEmbedding
 
 
 class Embedding(Module):
@@ -37,3 +38,7 @@ class Embedding(Module):
         weights are tied.
         """
         return x @ self.weight.T
+
+    def to_quantized(self, group_size: int = 64, bits: int = 4):
+        """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
+        return QuantizedEmbedding.from_embedding(self, group_size, bits)
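For reference, a minimal usage sketch of the new method (not part of this commit; the layer sizes are illustrative):

    import mlx.nn as nn

    emb = nn.Embedding(num_embeddings=1000, dims=128)
    # Equivalent to QuantizedEmbedding.from_embedding(emb, 64, 4)
    qemb = emb.to_quantized(group_size=64, bits=4)
    assert isinstance(qemb, nn.QuantizedEmbedding)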
mlx.nn.layers.linear:

@@ -5,6 +5,7 @@ from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.quantized import QuantizedLinear
 
 
 class Identity(Module):
@@ -69,6 +70,10 @@ class Linear(Module):
             x = x @ self["weight"].T
         return x
 
+    def to_quantized(self, group_size: int = 64, bits: int = 4):
+        """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
+        return QuantizedLinear.from_linear(self, group_size, bits)
+
 
 class Bilinear(Module):
     r"""Applies a bilinear transformation to the inputs.
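A matching sketch for the Linear variant (again not part of the diff; the dimensions are made up):

    import mlx.core as mx
    import mlx.nn as nn

    lin = nn.Linear(input_dims=512, output_dims=256)
    qlin = lin.to_quantized(group_size=64, bits=4)  # same result as QuantizedLinear.from_linear(lin, 64, 4)
    assert isinstance(qlin, nn.QuantizedLinear)
    y = qlin(mx.zeros((1, 512)))  # drop-in replacement: same input/output dims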
mlx.nn.layers.quantized:

@@ -5,8 +5,6 @@ from typing import Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
-from mlx.nn.layers.embedding import Embedding
-from mlx.nn.layers.linear import Linear
 from mlx.utils import tree_map_with_path
 
 
@@ -18,8 +16,9 @@ def quantize(
 ):
     """Quantize the sub-modules of a module according to a predicate.
 
-    By default all :obj:`Linear` and :obj:`Embedding` layers will be
-    quantized. Note also, the module is updated in-place.
+    By default all layers that define a ``to_quantized(group_size, bits)``
+    method will be quantized. Both :obj:`Linear` and :obj:`Embedding` layers
+    will be quantized. Note also, the module is updated in-place.
 
     Args:
         model (mlx.nn.Module): The model whose leaf modules may be quantized.
@@ -30,18 +29,15 @@ def quantize(
         class_predicate (Optional[Callable]): A callable which receives the
             :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
             it should be quantized and ``False`` otherwise. If ``None``, then
-            all linear and embedding layers are quantized. Default: ``None``.
+            all layers that define a ``to_quantized(group_size, bits)`` method
+            are quantized. Default: ``None``.
     """
-    class_predicate = class_predicate or (
-        lambda _, m: isinstance(m, (Linear, Embedding))
-    )
+    class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
 
     def _maybe_quantize(path, m):
         if class_predicate(path, m):
-            if isinstance(m, Linear):
-                return QuantizedLinear.from_linear(m, group_size, bits)
-            elif isinstance(m, Embedding):
-                return QuantizedEmbedding.from_embedding(m, group_size, bits)
+            if hasattr(m, "to_quantized"):
+                return m.to_quantized(group_size, bits)
             else:
                 raise ValueError(f"Unable to quantize model of type {type(m)}")
         else:
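The effect of the reworked predicate, sketched on a hypothetical two-layer model (layer names and sizes are illustrative, not from this commit):

    import mlx.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(1000, 128)
            self.proj = nn.Linear(128, 10)

    model = TinyModel()

    # Default predicate: every sub-module that defines to_quantized() is converted in place.
    nn.quantize(model, group_size=64, bits=4)
    assert isinstance(model.embed, nn.QuantizedEmbedding)
    assert isinstance(model.proj, nn.QuantizedLinear)

    # A custom predicate still works, e.g. leave the output projection unquantized:
    model = TinyModel()
    nn.quantize(model, class_predicate=lambda path, m: path != "proj" and hasattr(m, "to_quantized"))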
@@ -129,7 +125,7 @@ class QuantizedEmbedding(Module):
 
     @classmethod
     def from_embedding(
-        cls, embedding_layer: Embedding, group_size: int = 64, bits: int = 4
+        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
@@ -220,7 +216,7 @@ class QuantizedLinear(Module):
         return x
 
     @classmethod
-    def from_linear(cls, linear_layer: Linear, group_size: int = 64, bits: int = 4):
+    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits)
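Since quantize() now dispatches on an attribute rather than on isinstance checks, and the from_* constructors accept any Module, a user-defined layer can opt into quantization by providing its own to_quantized. A sketch under that assumption (SharedLinear is hypothetical and not part of MLX):

    import mlx.core as mx
    import mlx.nn as nn

    class SharedLinear(nn.Module):
        """Hypothetical layer that stores a plain (output_dims, input_dims) weight."""

        def __init__(self, input_dims, output_dims):
            super().__init__()
            scale = 1.0 / (input_dims ** 0.5)
            self.weight = mx.random.uniform(-scale, scale, (output_dims, input_dims))

        def __call__(self, x):
            return x @ self.weight.T

        def to_quantized(self, group_size: int = 64, bits: int = 4):
            # The weight layout matches nn.Linear, so QuantizedLinear can be reused.
            return nn.QuantizedLinear.from_linear(self, group_size, bits)

    model = nn.Sequential(SharedLinear(256, 256), nn.ReLU(), SharedLinear(256, 64))
    nn.quantize(model)  # picks up SharedLinear via hasattr(m, "to_quantized")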