Use correct module type in quantized.py (#1115)

This commit is contained in:
Cheng 2024-05-14 21:25:42 +08:00 committed by GitHub
parent cbd5445ea7
commit 60cb11764e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -129,7 +129,7 @@ class QuantizedEmbedding(Module):
@classmethod
def from_embedding(
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
cls, embedding_layer: Embedding, group_size: int = 64, bits: int = 4
):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape
@ -220,7 +220,7 @@ class QuantizedLinear(Module):
return x
@classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
def from_linear(cls, linear_layer: Linear, group_size: int = 64, bits: int = 4):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)