mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
fix modules with dict (#819)
This commit is contained in:
parent
8e5600022a
commit
366478c560
@ -20,7 +20,7 @@ def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
|
|||||||
|
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
nd = {}
|
nd = {}
|
||||||
for k, v in v.items():
|
for k, v in value.items():
|
||||||
tk = f"{value_key}.{k}"
|
tk = f"{value_key}.{k}"
|
||||||
nd[k] = (
|
nd[k] = (
|
||||||
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
|
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
|
||||||
|
@ -54,6 +54,18 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
m.apply_to_modules(assert_training)
|
m.apply_to_modules(assert_training)
|
||||||
|
|
||||||
|
def test_model_with_dict(self):
|
||||||
|
class DictModule(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
||||||
|
|
||||||
|
model = DictModule()
|
||||||
|
params = dict(tree_flatten(model.parameters()))
|
||||||
|
self.assertEqual(len(params), 2)
|
||||||
|
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
||||||
|
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
||||||
|
|
||||||
def test_save_npz_weights(self):
|
def test_save_npz_weights(self):
|
||||||
def make_model():
|
def make_model():
|
||||||
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
||||||
|
Loading…
Reference in New Issue
Block a user