mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Allow quant layer to be unfrozen (#2142)
This commit is contained in:
@@ -8,7 +8,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten, tree_map
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
|
||||
|
||||
class TestBase(mlx_tests.MLXTestCase):
|
||||
@@ -198,6 +198,13 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
|
||||
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
|
||||
|
||||
def test_quantize_freeze(self):
|
||||
lin = nn.Linear(512, 512)
|
||||
qlin = lin.to_quantized()
|
||||
qlin.unfreeze(keys=["scales"])
|
||||
size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0)
|
||||
self.assertTrue(size > 0)
|
||||
|
||||
def test_grad_of_module(self):
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
|
Reference in New Issue
Block a user