diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py
index 447002594..50808475d 100644
--- a/python/mlx/nn/layers/base.py
+++ b/python/mlx/nn/layers/base.py
@@ -20,7 +20,7 @@ def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
 
     elif isinstance(value, dict):
         nd = {}
-        for k, v in v.items():
+        for k, v in value.items():
             tk = f"{value_key}.{k}"
             nd[k] = (
                 _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index d704d4004..e3e676c6e 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -54,6 +54,18 @@ class TestBase(mlx_tests.MLXTestCase):
 
         m.apply_to_modules(assert_training)
 
+    def test_model_with_dict(self):
+        class DictModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
+
+        model = DictModule()
+        params = dict(tree_flatten(model.parameters()))
+        self.assertEqual(len(params), 2)
+        self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
+        self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
+
     def test_save_npz_weights(self):
         def make_model():
             return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))