Added support for pytree types that inherit from tuple and typing.namedtuple (#2845)

This commit is contained in:
romanoneg
2025-12-04 11:06:45 -08:00
committed by GitHub
parent 50d3914c67
commit 9abb0b8123
6 changed files with 196 additions and 3 deletions

View File

@@ -44,10 +44,11 @@ def tree_map(
return fn(tree, *rest)
elif isinstance(tree, (list, tuple)):
TreeType = type(tree)
return TreeType(
subtrees = (
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
for i, child in enumerate(tree)
)
return TreeType(*subtrees) if hasattr(tree, "_fields") else TreeType(subtrees)
elif isinstance(tree, dict):
return {
k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)

View File

@@ -41,6 +41,7 @@ nb::object tree_map(
int len = nb::cast<nb::tuple>(subtrees[0]).size();
nb::list l;
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
auto type = subtrees[0].type();
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (nb::isinstance<nb::tuple>(subtrees[j])) {
@@ -51,7 +52,10 @@ nb::object tree_map(
}
l.append(recurse(items));
}
return nb::cast<nb::object>(nb::tuple(l));
if (PyTuple_CheckExact(subtrees[0].ptr())) {
return nb::cast<nb::object>(nb::tuple(l));
}
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
@@ -178,11 +182,15 @@ void tree_visit_update(
}
return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtree)) {
auto type = subtree.type();
nb::list l(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return nb::cast<nb::object>(nb::tuple(l));
if (PyTuple_CheckExact(subtree.ptr())) {
return nb::cast<nb::object>(nb::tuple(l));
}
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
} else if (nb::isinstance<nb::dict>(subtree)) {
auto d = nb::cast<nb::dict>(subtree);
for (auto item : d) {

View File

@@ -798,6 +798,55 @@ class TestAutograd(mlx_tests.MLXTestCase):
grad_fn(model)
self.assertEqual(model[1].item(), 2.0)
def test_autograd_types(self):
from typing import NamedTuple
class Vector(tuple):
pass
class State(NamedTuple):
a: mx.array
b: mx.array
def transform(x: State):
return State(x.a + 10, x.b * 10)
def transform_tuple(t):
return (t[0] + 10, t[1] * 10)
def transform_vector(t):
return Vector([t[0] + 10, t[1] * 10])
def loss_fn(x):
out = transform(x)
return out.a.sum() + out.b.sum()
def loss_fn_tuple(x):
out = transform_tuple(x)
return out[0].sum() + out[1].sum()
def loss_fn_vector(x):
out = transform_vector(x)
return out[0].sum() + out[1].sum()
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
grads = mx.grad(loss_fn)(x_batch)
self.assertTrue(isinstance(grads, State))
self.assertTrue(mx.array_equal(grads.a, mx.ones(3)))
self.assertTrue(mx.array_equal(grads.b, mx.ones(3) * 10))
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
grads = mx.grad(loss_fn_tuple)(x_batch_tuple)
self.assertTrue(isinstance(grads, tuple))
self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
grads = mx.grad(loss_fn_vector)(x_batch_vector)
self.assertTrue(isinstance(grads, Vector))
self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))
def test_reduce_jvp(self):
a = mx.arange(4)
b = mx.array([3, 2, 1, 0])

View File

@@ -1179,6 +1179,50 @@ class TestCompile(mlx_tests.MLXTestCase):
expected = fun(False)
self.assertTrue(mx.allclose(out, expected))
def test_compile_types(self):
from typing import NamedTuple
class Vector(tuple):
pass
class State(NamedTuple):
a: mx.array
b: mx.array
def transform(x: State):
return State(x.a + 10, x.b * 10)
def transform_tuple(t):
return (t[0] + 10, t[1] * 10)
def transform_vector(t):
return Vector([t[0] + 10, t[1] * 10])
x = State(mx.array(1), mx.array(2))
compiled_transform = mx.compile(transform)
compiled_transform_tuple = mx.compile(transform_tuple)
compiled_transform_vector = mx.compile(transform_vector)
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out1 = compiled_transform_tuple(x_batch_tuple)
self.assertTrue(isinstance(out1, tuple))
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out2 = compiled_transform(x_batch)
self.assertTrue(isinstance(out2, State))
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
out3 = compiled_transform_vector(x_batch_vector)
self.assertTrue(isinstance(out3, Vector))
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()

View File

@@ -46,6 +46,51 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
self.assertEqual(k1, k2)
self.assertTrue(mx.array_equal(v1, v2))
def test_supported_trees(self):
from typing import NamedTuple
class Vector(tuple):
pass
class Params(NamedTuple):
m: mx.array
b: mx.array
list1 = [mx.array([0, 1]), mx.array(2)]
tuple1 = (mx.array([0, 1]), mx.array(2))
vector1 = Vector([mx.array([0, 1]), mx.array(2)])
params1 = Params(m=mx.array([0, 1]), b=mx.array(2))
dict1 = {"m": mx.array([0, 1]), "b": mx.array(2)}
add_one = lambda x: x + 1
list2 = mlx.utils.tree_map(add_one, list1)
tuple2 = mlx.utils.tree_map(add_one, tuple1)
vector2 = mlx.utils.tree_map(add_one, vector1)
params2 = mlx.utils.tree_map(add_one, params1)
dict2 = mlx.utils.tree_map(add_one, dict1)
self.assertTrue(isinstance(list2, list))
self.assertTrue(mx.array_equal(list2[0], mx.array([1, 2])))
self.assertTrue(mx.array_equal(list2[1], mx.array(3)))
self.assertTrue(isinstance(tuple2, tuple))
self.assertTrue(mx.array_equal(tuple2[0], mx.array([1, 2])))
self.assertTrue(mx.array_equal(tuple2[1], mx.array(3)))
self.assertTrue(isinstance(vector2, Vector))
self.assertTrue(mx.array_equal(vector2[0], mx.array([1, 2])))
self.assertTrue(mx.array_equal(vector2[1], mx.array(3)))
self.assertTrue(isinstance(dict2, dict))
self.assertTrue(mx.array_equal(dict2["m"], mx.array([1, 2])))
self.assertTrue(mx.array_equal(dict2["b"], mx.array(3)))
self.assertTrue(isinstance(params2, Params))
self.assertTrue(mx.array_equal(params2.m, mx.array([1, 2])))
self.assertTrue(mx.array_equal(params2.b, mx.array(3)))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()

View File

@@ -723,6 +723,52 @@ class TestVmap(mlx_tests.MLXTestCase):
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
def test_vmap_types(self):
from typing import NamedTuple
class Vector(tuple):
pass
class State(NamedTuple):
a: mx.array
b: mx.array
def transform(x: State):
return State(x.a + 10, x.b * 10)
def transform_tuple(t):
return (t[0] + 10, t[1] * 10)
def transform_vector(t):
return Vector([t[0] + 10, t[1] * 10])
x = State(mx.array(1), mx.array(2))
print(f"{transform(x)=}")
vmap_transform = mx.vmap(transform)
vmap_transform_tuple = mx.vmap(transform_tuple)
vmap_transform_vector = mx.vmap(transform_vector)
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out1 = vmap_transform_tuple(x_batch_tuple)
self.assertTrue(isinstance(out1, tuple))
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out2 = vmap_transform(x_batch)
self.assertTrue(isinstance(out2, State))
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
out3 = vmap_transform_vector(x_batch_vector)
self.assertTrue(isinstance(out3, Vector))
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
def test_vmap_masked_scatter(self):
def scatter_fn(x, m, src):
x[m] = src