diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 08910467d..192a28fd0 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -129,7 +129,7 @@ class QuantizedEmbedding(Module): @classmethod def from_embedding( - cls, embedding_layer: Module, group_size: int = 64, bits: int = 4 + cls, embedding_layer: Embedding, group_size: int = 64, bits: int = 4 ): """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.""" embedding_dims, dims = embedding_layer.weight.shape @@ -220,7 +220,7 @@ class QuantizedLinear(Module): return x @classmethod - def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4): + def from_linear(cls, linear_layer: Linear, group_size: int = 64, bits: int = 4): """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape ql = cls(input_dims, output_dims, False, group_size, bits)