From e5e816a5efa7f639469737a25ca947a40e8bf76a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 7 Feb 2024 13:22:27 -0800 Subject: [PATCH] fix sequential with empty modules at end (#647) --- python/mlx/nn/layers/base.py | 2 +- python/tests/test_nn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 3da1993ec..febbafa78 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -312,7 +312,7 @@ class Module(dict): elif isinstance(current_value, (dict, list)): apply(current_value, new_value) elif isinstance(parameters, list): - for i in range(len(dst)): + for i in range(len(parameters)): current_value = dst[i] new_value = parameters[i] if isinstance(current_value, mx.array): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 7749e159a..d7b84bbf6 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -71,7 +71,7 @@ class TestBase(mlx_tests.MLXTestCase): def test_save_safetensors_weights(self): def make_model(): - return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) + return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2), nn.ReLU()) m = make_model() tdir = tempfile.TemporaryDirectory()