fix sequential with empty modules at end (#647)

This commit is contained in:
Awni Hannun 2024-02-07 13:22:27 -08:00 committed by GitHub
parent 28eac18571
commit e5e816a5ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -312,7 +312,7 @@ class Module(dict):
elif isinstance(current_value, (dict, list)): elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif isinstance(parameters, list): elif isinstance(parameters, list):
for i in range(len(dst)): for i in range(len(parameters)):
current_value = dst[i] current_value = dst[i]
new_value = parameters[i] new_value = parameters[i]
if isinstance(current_value, mx.array): if isinstance(current_value, mx.array):

View File

@ -71,7 +71,7 @@ class TestBase(mlx_tests.MLXTestCase):
def test_save_safetensors_weights(self): def test_save_safetensors_weights(self):
def make_model(): def make_model():
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2), nn.ReLU())
m = make_model() m = make_model()
tdir = tempfile.TemporaryDirectory() tdir = tempfile.TemporaryDirectory()