mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fix sequential with empty modules at end (#647)
This commit is contained in:
parent
28eac18571
commit
e5e816a5ef
@ -312,7 +312,7 @@ class Module(dict):
|
|||||||
elif isinstance(current_value, (dict, list)):
|
elif isinstance(current_value, (dict, list)):
|
||||||
apply(current_value, new_value)
|
apply(current_value, new_value)
|
||||||
elif isinstance(parameters, list):
|
elif isinstance(parameters, list):
|
||||||
for i in range(len(dst)):
|
for i in range(len(parameters)):
|
||||||
current_value = dst[i]
|
current_value = dst[i]
|
||||||
new_value = parameters[i]
|
new_value = parameters[i]
|
||||||
if isinstance(current_value, mx.array):
|
if isinstance(current_value, mx.array):
|
||||||
|
@ -71,7 +71,7 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
def test_save_safetensors_weights(self):
|
def test_save_safetensors_weights(self):
|
||||||
def make_model():
|
def make_model():
|
||||||
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2), nn.ReLU())
|
||||||
|
|
||||||
m = make_model()
|
m = make_model()
|
||||||
tdir = tempfile.TemporaryDirectory()
|
tdir = tempfile.TemporaryDirectory()
|
||||||
|
Loading…
Reference in New Issue
Block a user