mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fix modules with dict (#819)
This commit is contained in:
parent
8e5600022a
commit
366478c560
@ -20,7 +20,7 @@ def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
|
||||
|
||||
elif isinstance(value, dict):
|
||||
nd = {}
|
||||
for k, v in v.items():
|
||||
for k, v in value.items():
|
||||
tk = f"{value_key}.{k}"
|
||||
nd[k] = (
|
||||
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
|
||||
|
@ -54,6 +54,18 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
|
||||
m.apply_to_modules(assert_training)
|
||||
|
||||
def test_model_with_dict(self):
|
||||
class DictModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
||||
|
||||
model = DictModule()
|
||||
params = dict(tree_flatten(model.parameters()))
|
||||
self.assertEqual(len(params), 2)
|
||||
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
||||
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
||||
|
||||
def test_save_npz_weights(self):
|
||||
def make_model():
|
||||
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
||||
|
Loading…
Reference in New Issue
Block a user