From 60cb11764ec06bb881beb9b88ab811742c4a33ef Mon Sep 17 00:00:00 2001
From: Cheng
Date: Tue, 14 May 2024 21:25:42 +0800
Subject: [PATCH] Use correct module type in quantized.py (#1115)

---
 python/mlx/nn/layers/quantized.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py
index 08910467d..192a28fd0 100644
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -129,7 +129,7 @@ class QuantizedEmbedding(Module):
 
     @classmethod
     def from_embedding(
-        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+        cls, embedding_layer: Embedding, group_size: int = 64, bits: int = 4
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
@@ -220,7 +220,7 @@ class QuantizedLinear(Module):
         return x
 
     @classmethod
-    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
+    def from_linear(cls, linear_layer: Linear, group_size: int = 64, bits: int = 4):
        """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits)
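
For reference, a minimal usage sketch of the two classmethods whose annotations this patch corrects. The layer dimensions, random input, and token ids below are illustrative assumptions, not taken from the patch; only the classmethod names and their defaults (group_size=64, bits=4) come from the code above.

import mlx.core as mx
import mlx.nn as nn

# A regular float Linear layer; input_dims is chosen to be divisible by the
# default quantization group_size of 64 (dimensions are illustrative).
linear = nn.Linear(512, 256)

# Convert it with the classmethod now typed as `linear_layer: Linear`.
qlinear = nn.QuantizedLinear.from_linear(linear, group_size=64, bits=4)
y = qlinear(mx.random.normal((8, 512)))  # same output shape as linear(x): (8, 256)

# Likewise for embeddings, via the classmethod now typed as `embedding_layer: Embedding`.
embedding = nn.Embedding(1000, 512)
qembedding = nn.QuantizedEmbedding.from_embedding(embedding)
e = qembedding(mx.array([1, 2, 3]))  # (3, 512)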