From aa5d84f102e8e2fff0f3db3f1d23b61fb1e1a2d7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 30 Apr 2025 09:08:29 -0700 Subject: [PATCH] Allow quant layer to be unfrozen (#2142) --- python/mlx/nn/layers/quantized.py | 6 ------ python/tests/test_nn.py | 9 ++++++++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 823a0084f..2d6dc0882 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -193,12 +193,6 @@ class QuantizedLinear(Module): # Freeze this model's parameters self.freeze() - def unfreeze(self, *args, **kwargs): - """Wrap unfreeze so that we unfreeze any layers we might contain but - our parameters will remain frozen.""" - super().unfreeze(*args, **kwargs) - self.freeze(recurse=False) - def _extra_repr(self): out_dims, in_dims = self.weight.shape in_dims *= 32 // self.bits diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 9cfa25dae..826d53d96 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -8,7 +8,7 @@ import mlx.core as mx import mlx.nn as nn import mlx_tests import numpy as np -from mlx.utils import tree_flatten, tree_map +from mlx.utils import tree_flatten, tree_map, tree_reduce class TestBase(mlx_tests.MLXTestCase): @@ -198,6 +198,13 @@ class TestBase(mlx_tests.MLXTestCase): self.assertTrue(isinstance(m.layers[1], nn.ReLU)) self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) + def test_quantize_freeze(self): + lin = nn.Linear(512, 512) + qlin = lin.to_quantized() + qlin.unfreeze(keys=["scales"]) + size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0) + self.assertTrue(size > 0) + def test_grad_of_module(self): class Model(nn.Module): def __init__(self):