Allow quant layer to be unfrozen (#2142)

This commit is contained in:
Awni Hannun
2025-04-30 09:08:29 -07:00
committed by GitHub
parent f1606486d2
commit aa5d84f102
2 changed files with 8 additions and 7 deletions

View File

@@ -8,7 +8,7 @@ import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np
from mlx.utils import tree_flatten, tree_map
from mlx.utils import tree_flatten, tree_map, tree_reduce
class TestBase(mlx_tests.MLXTestCase):
@@ -198,6 +198,13 @@ class TestBase(mlx_tests.MLXTestCase):
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
def test_quantize_freeze(self):
lin = nn.Linear(512, 512)
qlin = lin.to_quantized()
qlin.unfreeze(keys=["scales"])
size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0)
self.assertTrue(size > 0)
def test_grad_of_module(self):
class Model(nn.Module):
def __init__(self):