Added support for pytree types that inherit from tuple and typing.namedtuple (#2845)

This commit is contained in:
romanoneg
2025-12-04 11:06:45 -08:00
committed by GitHub
parent 50d3914c67
commit 9abb0b8123
6 changed files with 196 additions and 3 deletions

View File

@@ -723,6 +723,52 @@ class TestVmap(mlx_tests.MLXTestCase):
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
def test_vmap_types(self):
from typing import NamedTuple
class Vector(tuple):
pass
class State(NamedTuple):
a: mx.array
b: mx.array
def transform(x: State):
return State(x.a + 10, x.b * 10)
def transform_tuple(t):
return (t[0] + 10, t[1] * 10)
def transform_vector(t):
return Vector([t[0] + 10, t[1] * 10])
x = State(mx.array(1), mx.array(2))
print(f"{transform(x)=}")
vmap_transform = mx.vmap(transform)
vmap_transform_tuple = mx.vmap(transform_tuple)
vmap_transform_vector = mx.vmap(transform_vector)
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out1 = vmap_transform_tuple(x_batch_tuple)
self.assertTrue(isinstance(out1, tuple))
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out2 = vmap_transform(x_batch)
self.assertTrue(isinstance(out2, State))
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
out3 = vmap_transform_vector(x_batch_vector)
self.assertTrue(isinstance(out3, Vector))
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
def test_vmap_masked_scatter(self):
def scatter_fn(x, m, src):
x[m] = src