Added support for pytree types that inherit from tuple and typing.namedtuple (#2845)

This commit is contained in:
romanoneg
2025-12-04 11:06:45 -08:00
committed by GitHub
parent 50d3914c67
commit 9abb0b8123
6 changed files with 196 additions and 3 deletions

View File

@@ -798,6 +798,55 @@ class TestAutograd(mlx_tests.MLXTestCase):
grad_fn(model)
self.assertEqual(model[1].item(), 2.0)
def test_autograd_types(self):
from typing import NamedTuple
class Vector(tuple):
pass
class State(NamedTuple):
a: mx.array
b: mx.array
def transform(x: State):
return State(x.a + 10, x.b * 10)
def transform_tuple(t):
return (t[0] + 10, t[1] * 10)
def transform_vector(t):
return Vector([t[0] + 10, t[1] * 10])
def loss_fn(x):
out = transform(x)
return out.a.sum() + out.b.sum()
def loss_fn_tuple(x):
out = transform_tuple(x)
return out[0].sum() + out[1].sum()
def loss_fn_vector(x):
out = transform_vector(x)
return out[0].sum() + out[1].sum()
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
grads = mx.grad(loss_fn)(x_batch)
self.assertTrue(isinstance(grads, State))
self.assertTrue(mx.array_equal(grads.a, mx.ones(3)))
self.assertTrue(mx.array_equal(grads.b, mx.ones(3) * 10))
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
grads = mx.grad(loss_fn_tuple)(x_batch_tuple)
self.assertTrue(isinstance(grads, tuple))
self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
grads = mx.grad(loss_fn_vector)(x_batch_vector)
self.assertTrue(isinstance(grads, Vector))
self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))
def test_reduce_jvp(self):
a = mx.arange(4)
b = mx.array([3, 2, 1, 0])

View File

@@ -1179,6 +1179,50 @@ class TestCompile(mlx_tests.MLXTestCase):
expected = fun(False)
self.assertTrue(mx.allclose(out, expected))
def test_compile_types(self):
from typing import NamedTuple
class Vector(tuple):
pass
class State(NamedTuple):
a: mx.array
b: mx.array
def transform(x: State):
return State(x.a + 10, x.b * 10)
def transform_tuple(t):
return (t[0] + 10, t[1] * 10)
def transform_vector(t):
return Vector([t[0] + 10, t[1] * 10])
x = State(mx.array(1), mx.array(2))
compiled_transform = mx.compile(transform)
compiled_transform_tuple = mx.compile(transform_tuple)
compiled_transform_vector = mx.compile(transform_vector)
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out1 = compiled_transform_tuple(x_batch_tuple)
self.assertTrue(isinstance(out1, tuple))
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out2 = compiled_transform(x_batch)
self.assertTrue(isinstance(out2, State))
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
out3 = compiled_transform_vector(x_batch_vector)
self.assertTrue(isinstance(out3, Vector))
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()

View File

@@ -46,6 +46,51 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
self.assertEqual(k1, k2)
self.assertTrue(mx.array_equal(v1, v2))
def test_supported_trees(self):
from typing import NamedTuple
class Vector(tuple):
pass
class Params(NamedTuple):
m: mx.array
b: mx.array
list1 = [mx.array([0, 1]), mx.array(2)]
tuple1 = (mx.array([0, 1]), mx.array(2))
vector1 = Vector([mx.array([0, 1]), mx.array(2)])
params1 = Params(m=mx.array([0, 1]), b=mx.array(2))
dict1 = {"m": mx.array([0, 1]), "b": mx.array(2)}
add_one = lambda x: x + 1
list2 = mlx.utils.tree_map(add_one, list1)
tuple2 = mlx.utils.tree_map(add_one, tuple1)
vector2 = mlx.utils.tree_map(add_one, vector1)
params2 = mlx.utils.tree_map(add_one, params1)
dict2 = mlx.utils.tree_map(add_one, dict1)
self.assertTrue(isinstance(list2, list))
self.assertTrue(mx.array_equal(list2[0], mx.array([1, 2])))
self.assertTrue(mx.array_equal(list2[1], mx.array(3)))
self.assertTrue(isinstance(tuple2, tuple))
self.assertTrue(mx.array_equal(tuple2[0], mx.array([1, 2])))
self.assertTrue(mx.array_equal(tuple2[1], mx.array(3)))
self.assertTrue(isinstance(vector2, Vector))
self.assertTrue(mx.array_equal(vector2[0], mx.array([1, 2])))
self.assertTrue(mx.array_equal(vector2[1], mx.array(3)))
self.assertTrue(isinstance(dict2, dict))
self.assertTrue(mx.array_equal(dict2["m"], mx.array([1, 2])))
self.assertTrue(mx.array_equal(dict2["b"], mx.array(3)))
self.assertTrue(isinstance(params2, Params))
self.assertTrue(mx.array_equal(params2.m, mx.array([1, 2])))
self.assertTrue(mx.array_equal(params2.b, mx.array(3)))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()

View File

@@ -723,6 +723,52 @@ class TestVmap(mlx_tests.MLXTestCase):
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
def test_vmap_types(self):
from typing import NamedTuple
class Vector(tuple):
pass
class State(NamedTuple):
a: mx.array
b: mx.array
def transform(x: State):
return State(x.a + 10, x.b * 10)
def transform_tuple(t):
return (t[0] + 10, t[1] * 10)
def transform_vector(t):
return Vector([t[0] + 10, t[1] * 10])
x = State(mx.array(1), mx.array(2))
print(f"{transform(x)=}")
vmap_transform = mx.vmap(transform)
vmap_transform_tuple = mx.vmap(transform_tuple)
vmap_transform_vector = mx.vmap(transform_vector)
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out1 = vmap_transform_tuple(x_batch_tuple)
self.assertTrue(isinstance(out1, tuple))
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out2 = vmap_transform(x_batch)
self.assertTrue(isinstance(out2, State))
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
out3 = vmap_transform_vector(x_batch_vector)
self.assertTrue(isinstance(out3, Vector))
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
def test_vmap_masked_scatter(self):
def scatter_fn(x, m, src):
x[m] = src