Allow quant layer to be unfrozen (#2142)

Awni Hannun 2025-04-30 09:08:29 -07:00 committed by GitHub
parent f1606486d2
commit aa5d84f102
2 changed files with 8 additions and 7 deletions


@@ -193,12 +193,6 @@ class QuantizedLinear(Module):
         # Freeze this model's parameters
         self.freeze()
 
-    def unfreeze(self, *args, **kwargs):
-        """Wrap unfreeze so that we unfreeze any layers we might contain but
-        our parameters will remain frozen."""
-        super().unfreeze(*args, **kwargs)
-        self.freeze(recurse=False)
-
     def _extra_repr(self):
         out_dims, in_dims = self.weight.shape
         in_dims *= 32 // self.bits
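
With the override removed, QuantizedLinear falls back to the inherited Module.unfreeze, so its own parameters can be made trainable again after quantization. A minimal usage sketch, mirroring the test added below (the 512x512 shape is just illustrative):

import mlx.nn as nn
from mlx.utils import tree_reduce

# Quantizing a layer still freezes its parameters by default.
lin = nn.Linear(512, 512)
qlin = lin.to_quantized()

# Before this change, unfreeze immediately re-froze the layer's own
# parameters; now the requested keys actually become trainable.
qlin.unfreeze(keys=["scales"])

# The trainable parameter count is nonzero once "scales" is unfrozen.
size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0)
print(size)  # > 0, since the scales are now trainable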


@@ -8,7 +8,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import mlx_tests
 import numpy as np
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_flatten, tree_map, tree_reduce
 
 
 class TestBase(mlx_tests.MLXTestCase):
@@ -198,6 +198,13 @@ class TestBase(mlx_tests.MLXTestCase):
         self.assertTrue(isinstance(m.layers[1], nn.ReLU))
         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
 
+    def test_quantize_freeze(self):
+        lin = nn.Linear(512, 512)
+        qlin = lin.to_quantized()
+        qlin.unfreeze(keys=["scales"])
+        size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0)
+        self.assertTrue(size > 0)
+
     def test_grad_of_module(self):
         class Model(nn.Module):
             def __init__(self):