fix modules with dict (#819)

This commit is contained in:
Awni Hannun 2024-03-12 08:54:06 -07:00 committed by GitHub
parent 8e5600022a
commit 366478c560
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 1 deletions

View File

@ -20,7 +20,7 @@ def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
elif isinstance(value, dict):
nd = {}
for k, v in v.items():
for k, v in value.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)

View File

@ -54,6 +54,18 @@ class TestBase(mlx_tests.MLXTestCase):
m.apply_to_modules(assert_training)
def test_model_with_dict(self):
class DictModule(nn.Module):
def __init__(self):
super().__init__()
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
model = DictModule()
params = dict(tree_flatten(model.parameters()))
self.assertEqual(len(params), 2)
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
def test_save_npz_weights(self):
def make_model():
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))