mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Allow quant layer to be unfrozen (#2142)
This commit is contained in:
parent
f1606486d2
commit
aa5d84f102
@ -193,12 +193,6 @@ class QuantizedLinear(Module):
|
|||||||
# Freeze this model's parameters
|
# Freeze this model's parameters
|
||||||
self.freeze()
|
self.freeze()
|
||||||
|
|
||||||
def unfreeze(self, *args, **kwargs):
|
|
||||||
"""Wrap unfreeze so that we unfreeze any layers we might contain but
|
|
||||||
our parameters will remain frozen."""
|
|
||||||
super().unfreeze(*args, **kwargs)
|
|
||||||
self.freeze(recurse=False)
|
|
||||||
|
|
||||||
def _extra_repr(self):
|
def _extra_repr(self):
|
||||||
out_dims, in_dims = self.weight.shape
|
out_dims, in_dims = self.weight.shape
|
||||||
in_dims *= 32 // self.bits
|
in_dims *= 32 // self.bits
|
||||||
|
@ -8,7 +8,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mlx.utils import tree_flatten, tree_map
|
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||||
|
|
||||||
|
|
||||||
class TestBase(mlx_tests.MLXTestCase):
|
class TestBase(mlx_tests.MLXTestCase):
|
||||||
@ -198,6 +198,13 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
|
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
|
||||||
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
|
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
|
||||||
|
|
||||||
|
def test_quantize_freeze(self):
|
||||||
|
lin = nn.Linear(512, 512)
|
||||||
|
qlin = lin.to_quantized()
|
||||||
|
qlin.unfreeze(keys=["scales"])
|
||||||
|
size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0)
|
||||||
|
self.assertTrue(size > 0)
|
||||||
|
|
||||||
def test_grad_of_module(self):
|
def test_grad_of_module(self):
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user