mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
chore: change function with a destination dictonary object
This commit is contained in:
@@ -30,15 +30,16 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(len(flat_children), 3)
|
||||
|
||||
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
|
||||
self.assertEqual(len(leaves), 4)
|
||||
self.assertEqual(leaves[0][0], "layers.0.layers.0")
|
||||
self.assertEqual(leaves[1][0], "layers.1.layers.0")
|
||||
self.assertEqual(leaves[2][0], "layers.1.layers.1")
|
||||
self.assertEqual(leaves[3][0], "layers.2")
|
||||
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
|
||||
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
|
||||
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
|
||||
self.assertTrue(leaves[3][1] is m.layers[2])
|
||||
if isinstance(leaves, list):
|
||||
self.assertEqual(len(leaves), 4)
|
||||
self.assertEqual(leaves[0][0], "layers.0.layers.0")
|
||||
self.assertEqual(leaves[1][0], "layers.1.layers.0")
|
||||
self.assertEqual(leaves[2][0], "layers.1.layers.1")
|
||||
self.assertEqual(leaves[3][0], "layers.2")
|
||||
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
|
||||
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
|
||||
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
|
||||
self.assertTrue(leaves[3][1] is m.layers[2])
|
||||
|
||||
m.eval()
|
||||
|
||||
@@ -80,7 +81,7 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
||||
|
||||
model = DictModule()
|
||||
params = dict(tree_flatten(model.parameters()))
|
||||
params = tree_flatten(model.parameters(), destination={})
|
||||
self.assertEqual(len(params), 2)
|
||||
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
||||
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
||||
|
||||
Reference in New Issue
Block a user