Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
Allow quant layer to be unfrozen
Commit 81e8aed49d (parent 87720a8908)
@@ -193,12 +193,6 @@ class QuantizedLinear(Module):
         # Freeze this model's parameters
         self.freeze()
 
-    def unfreeze(self, *args, **kwargs):
-        """Wrap unfreeze so that we unfreeze any layers we might contain but
-        our parameters will remain frozen."""
-        super().unfreeze(*args, **kwargs)
-        self.freeze(recurse=False)
-
     def _extra_repr(self):
         out_dims, in_dims = self.weight.shape
         in_dims *= 32 // self.bits
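With the override removed, QuantizedLinear no longer re-freezes itself after a selective unfreeze; only the self.freeze() call in the constructor keeps the quantized parameters non-trainable by default. A minimal sketch of the behavior this enables, using only the APIs exercised by the new test below (to_quantized, unfreeze(keys=...), trainable_parameters); the printed result is illustrative, not output captured from this commit:

import mlx.nn as nn

lin = nn.Linear(512, 512)
qlin = lin.to_quantized()       # the quantized layer starts out fully frozen

# Before this commit, QuantizedLinear.unfreeze immediately called
# self.freeze(recurse=False) again, so this had no effect on the layer's
# own parameters; now the requested keys stay trainable.
qlin.unfreeze(keys=["scales"])

print(qlin.trainable_parameters())  # should now contain the "scales" array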
@@ -8,7 +8,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import mlx_tests
 import numpy as np
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_flatten, tree_map, tree_reduce
 
 
 class TestBase(mlx_tests.MLXTestCase):
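The new tree_reduce import supports the test added in the next hunk: it folds a function over every leaf array in a parameter tree. A standalone sketch of that call pattern, with a made-up model whose parameter count is easy to verify by hand:

import mlx.nn as nn
from mlx.utils import tree_reduce

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Sum the number of elements across all parameter arrays in the tree.
n_params = tree_reduce(lambda acc, p: acc + p.size, model.parameters(), 0)
print(n_params)  # 16*32 + 32 + 32*4 + 4 = 676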
@@ -198,6 +198,13 @@ class TestBase(mlx_tests.MLXTestCase):
         self.assertTrue(isinstance(m.layers[1], nn.ReLU))
         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
 
+    def test_quantize_freeze(self):
+        lin = nn.Linear(512, 512)
+        qlin = lin.to_quantized()
+        qlin.unfreeze(keys=["scales"])
+        size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0)
+        self.assertTrue(size > 0)
+
     def test_grad_of_module(self):
         class Model(nn.Module):
             def __init__(self):
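As a follow-up to the new test, the already-imported tree_flatten can confirm exactly which entries became trainable after the selective unfreeze. The sketch below assumes the qlin object from the test body:

from mlx.utils import tree_flatten

# List the names and shapes of the trainable parameters; with only
# "scales" unfrozen, the packed weight and the other quantization
# parameters should not appear in this listing.
for name, param in tree_flatten(qlin.trainable_parameters()):
    print(name, param.shape)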