mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
Use correct module type in quantized.py (#1115)
This commit is contained in:
parent
cbd5445ea7
commit
60cb11764e
@ -129,7 +129,7 @@ class QuantizedEmbedding(Module):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_embedding(
|
def from_embedding(
|
||||||
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
|
cls, embedding_layer: Embedding, group_size: int = 64, bits: int = 4
|
||||||
):
|
):
|
||||||
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
|
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
|
||||||
embedding_dims, dims = embedding_layer.weight.shape
|
embedding_dims, dims = embedding_layer.weight.shape
|
||||||
@ -220,7 +220,7 @@ class QuantizedLinear(Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
|
def from_linear(cls, linear_layer: Linear, group_size: int = 64, bits: int = 4):
|
||||||
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
||||||
output_dims, input_dims = linear_layer.weight.shape
|
output_dims, input_dims = linear_layer.weight.shape
|
||||||
ql = cls(input_dims, output_dims, False, group_size, bits)
|
ql = cls(input_dims, output_dims, False, group_size, bits)
|
||||||
|
Loading…
Reference in New Issue
Block a user