Support destination arg in tree flatten/unflatten (#2450)

This commit is contained in:
Luca Vivona
2025-08-06 18:34:59 -04:00
committed by GitHub
parent db5c7efcf6
commit 728d4db582
5 changed files with 81 additions and 48 deletions

View File

@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
model = DictModule()
params = dict(tree_flatten(model.parameters()))
params = tree_flatten(model.parameters(), destination={})
self.assertEqual(len(params), 2)
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))