mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
Support destination arg in tree flatten/unflatten (#2450)
This commit is contained in:
@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
||||
|
||||
model = DictModule()
|
||||
params = dict(tree_flatten(model.parameters()))
|
||||
params = tree_flatten(model.parameters(), destination={})
|
||||
self.assertEqual(len(params), 2)
|
||||
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
||||
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
||||
|
Reference in New Issue
Block a user