From 9abb0b8123992718995348ee6876562a720e64c7 Mon Sep 17 00:00:00 2001 From: romanoneg <43445765+romanoneg@users.noreply.github.com> Date: Thu, 4 Dec 2025 11:06:45 -0800 Subject: [PATCH] Added support for pytree types that inherit from tuple and typing.namedtuple (#2845) --- python/mlx/utils.py | 3 ++- python/src/trees.cpp | 12 +++++++-- python/tests/test_autograd.py | 49 +++++++++++++++++++++++++++++++++++ python/tests/test_compile.py | 44 +++++++++++++++++++++++++++++++ python/tests/test_tree.py | 45 ++++++++++++++++++++++++++++++++ python/tests/test_vmap.py | 46 ++++++++++++++++++++++++++++++++ 6 files changed, 196 insertions(+), 3 deletions(-) diff --git a/python/mlx/utils.py b/python/mlx/utils.py index fa9884c10..f4aafe1e3 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -44,10 +44,11 @@ def tree_map( return fn(tree, *rest) elif isinstance(tree, (list, tuple)): TreeType = type(tree) - return TreeType( + subtrees = ( tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf) for i, child in enumerate(tree) ) + return TreeType(*subtrees) if hasattr(tree, "_fields") else TreeType(subtrees) elif isinstance(tree, dict): return { k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf) diff --git a/python/src/trees.cpp b/python/src/trees.cpp index b75d1187c..4b9ca9e12 100644 --- a/python/src/trees.cpp +++ b/python/src/trees.cpp @@ -41,6 +41,7 @@ nb::object tree_map( int len = nb::cast(subtrees[0]).size(); nb::list l; validate_subtrees(subtrees); + auto type = subtrees[0].type(); for (int i = 0; i < len; ++i) { for (int j = 0; j < subtrees.size(); ++j) { if (nb::isinstance(subtrees[j])) { @@ -51,7 +52,10 @@ nb::object tree_map( } l.append(recurse(items)); } - return nb::cast(nb::tuple(l)); + if (PyTuple_CheckExact(subtrees[0].ptr())) { + return nb::cast(nb::tuple(l)); + } + return nb::hasattr(type, "_fields") ? type(*l) : type(l); } else if (nb::isinstance(subtrees[0])) { std::vector items(subtrees.size()); validate_subtrees(subtrees); @@ -178,11 +182,15 @@ void tree_visit_update( } return nb::cast(l); } else if (nb::isinstance(subtree)) { + auto type = subtree.type(); nb::list l(subtree); for (int i = 0; i < l.size(); ++i) { l[i] = recurse(l[i]); } - return nb::cast(nb::tuple(l)); + if (PyTuple_CheckExact(subtree.ptr())) { + return nb::cast(nb::tuple(l)); + } + return nb::hasattr(type, "_fields") ? type(*l) : type(l); } else if (nb::isinstance(subtree)) { auto d = nb::cast(subtree); for (auto item : d) { diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 9eadbd2ef..218ea3ce1 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -798,6 +798,55 @@ class TestAutograd(mlx_tests.MLXTestCase): grad_fn(model) self.assertEqual(model[1].item(), 2.0) + def test_autograd_types(self): + from typing import NamedTuple + + class Vector(tuple): + pass + + class State(NamedTuple): + a: mx.array + b: mx.array + + def transform(x: State): + return State(x.a + 10, x.b * 10) + + def transform_tuple(t): + return (t[0] + 10, t[1] * 10) + + def transform_vector(t): + return Vector([t[0] + 10, t[1] * 10]) + + def loss_fn(x): + out = transform(x) + return out.a.sum() + out.b.sum() + + def loss_fn_tuple(x): + out = transform_tuple(x) + return out[0].sum() + out[1].sum() + + def loss_fn_vector(x): + out = transform_vector(x) + return out[0].sum() + out[1].sum() + + x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6])) + grads = mx.grad(loss_fn)(x_batch) + self.assertTrue(isinstance(grads, State)) + self.assertTrue(mx.array_equal(grads.a, mx.ones(3))) + self.assertTrue(mx.array_equal(grads.b, mx.ones(3) * 10)) + + x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6])) + grads = mx.grad(loss_fn_tuple)(x_batch_tuple) + self.assertTrue(isinstance(grads, tuple)) + self.assertTrue(mx.array_equal(grads[0], mx.ones(3))) + self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10)) + + x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])]) + grads = mx.grad(loss_fn_vector)(x_batch_vector) + self.assertTrue(isinstance(grads, Vector)) + self.assertTrue(mx.array_equal(grads[0], mx.ones(3))) + self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10)) + def test_reduce_jvp(self): a = mx.arange(4) b = mx.array([3, 2, 1, 0]) diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 1ed5b7819..bc3bf80f3 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1179,6 +1179,50 @@ class TestCompile(mlx_tests.MLXTestCase): expected = fun(False) self.assertTrue(mx.allclose(out, expected)) + def test_compile_types(self): + from typing import NamedTuple + + class Vector(tuple): + pass + + class State(NamedTuple): + a: mx.array + b: mx.array + + def transform(x: State): + return State(x.a + 10, x.b * 10) + + def transform_tuple(t): + return (t[0] + 10, t[1] * 10) + + def transform_vector(t): + return Vector([t[0] + 10, t[1] * 10]) + + x = State(mx.array(1), mx.array(2)) + + compiled_transform = mx.compile(transform) + compiled_transform_tuple = mx.compile(transform_tuple) + compiled_transform_vector = mx.compile(transform_vector) + + x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6])) + out1 = compiled_transform_tuple(x_batch_tuple) + + self.assertTrue(isinstance(out1, tuple)) + self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13]))) + self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60]))) + + x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6])) + out2 = compiled_transform(x_batch) + self.assertTrue(isinstance(out2, State)) + self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13]))) + self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60]))) + + x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])]) + out3 = compiled_transform_vector(x_batch_vector) + self.assertTrue(isinstance(out3, Vector)) + self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13]))) + self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60]))) + if __name__ == "__main__": mlx_tests.MLXTestRunner() diff --git a/python/tests/test_tree.py b/python/tests/test_tree.py index bacf6e71d..1ad207bac 100644 --- a/python/tests/test_tree.py +++ b/python/tests/test_tree.py @@ -46,6 +46,51 @@ class TestTreeUtils(mlx_tests.MLXTestCase): self.assertEqual(k1, k2) self.assertTrue(mx.array_equal(v1, v2)) + def test_supported_trees(self): + + from typing import NamedTuple + + class Vector(tuple): + pass + + class Params(NamedTuple): + m: mx.array + b: mx.array + + list1 = [mx.array([0, 1]), mx.array(2)] + tuple1 = (mx.array([0, 1]), mx.array(2)) + vector1 = Vector([mx.array([0, 1]), mx.array(2)]) + params1 = Params(m=mx.array([0, 1]), b=mx.array(2)) + dict1 = {"m": mx.array([0, 1]), "b": mx.array(2)} + + add_one = lambda x: x + 1 + + list2 = mlx.utils.tree_map(add_one, list1) + tuple2 = mlx.utils.tree_map(add_one, tuple1) + vector2 = mlx.utils.tree_map(add_one, vector1) + params2 = mlx.utils.tree_map(add_one, params1) + dict2 = mlx.utils.tree_map(add_one, dict1) + + self.assertTrue(isinstance(list2, list)) + self.assertTrue(mx.array_equal(list2[0], mx.array([1, 2]))) + self.assertTrue(mx.array_equal(list2[1], mx.array(3))) + + self.assertTrue(isinstance(tuple2, tuple)) + self.assertTrue(mx.array_equal(tuple2[0], mx.array([1, 2]))) + self.assertTrue(mx.array_equal(tuple2[1], mx.array(3))) + + self.assertTrue(isinstance(vector2, Vector)) + self.assertTrue(mx.array_equal(vector2[0], mx.array([1, 2]))) + self.assertTrue(mx.array_equal(vector2[1], mx.array(3))) + + self.assertTrue(isinstance(dict2, dict)) + self.assertTrue(mx.array_equal(dict2["m"], mx.array([1, 2]))) + self.assertTrue(mx.array_equal(dict2["b"], mx.array(3))) + + self.assertTrue(isinstance(params2, Params)) + self.assertTrue(mx.array_equal(params2.m, mx.array([1, 2]))) + self.assertTrue(mx.array_equal(params2.b, mx.array(3))) + if __name__ == "__main__": mlx_tests.MLXTestRunner() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 07025be82..8d9ee7051 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -723,6 +723,52 @@ class TestVmap(mlx_tests.MLXTestCase): out = mx.vmap(gconv, in_axes=(0, 0))(x, w) self.assertTrue(mx.allclose(expected, out)) + def test_vmap_types(self): + + from typing import NamedTuple + + class Vector(tuple): + pass + + class State(NamedTuple): + a: mx.array + b: mx.array + + def transform(x: State): + return State(x.a + 10, x.b * 10) + + def transform_tuple(t): + return (t[0] + 10, t[1] * 10) + + def transform_vector(t): + return Vector([t[0] + 10, t[1] * 10]) + + x = State(mx.array(1), mx.array(2)) + print(f"{transform(x)=}") + + vmap_transform = mx.vmap(transform) + vmap_transform_tuple = mx.vmap(transform_tuple) + vmap_transform_vector = mx.vmap(transform_vector) + + x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6])) + out1 = vmap_transform_tuple(x_batch_tuple) + + self.assertTrue(isinstance(out1, tuple)) + self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13]))) + self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60]))) + + x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6])) + out2 = vmap_transform(x_batch) + self.assertTrue(isinstance(out2, State)) + self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13]))) + self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60]))) + + x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])]) + out3 = vmap_transform_vector(x_batch_vector) + self.assertTrue(isinstance(out3, Vector)) + self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13]))) + self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60]))) + def test_vmap_masked_scatter(self): def scatter_fn(x, m, src): x[m] = src