diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 3da1993ec..febbafa78 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -312,7 +312,7 @@ class Module(dict): elif isinstance(current_value, (dict, list)): apply(current_value, new_value) elif isinstance(parameters, list): - for i in range(len(dst)): + for i in range(len(parameters)): current_value = dst[i] new_value = parameters[i] if isinstance(current_value, mx.array): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 7749e159a..d7b84bbf6 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -71,7 +71,7 @@ class TestBase(mlx_tests.MLXTestCase): def test_save_safetensors_weights(self): def make_model(): - return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) + return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2), nn.ReLU()) m = make_model() tdir = tempfile.TemporaryDirectory()