mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Added support for pytree types that inherit from tuple and typing.namedtuple (#2845)
This commit is contained in:
@@ -723,6 +723,52 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_vmap_types(self):
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
class Vector(tuple):
|
||||
pass
|
||||
|
||||
class State(NamedTuple):
|
||||
a: mx.array
|
||||
b: mx.array
|
||||
|
||||
def transform(x: State):
|
||||
return State(x.a + 10, x.b * 10)
|
||||
|
||||
def transform_tuple(t):
|
||||
return (t[0] + 10, t[1] * 10)
|
||||
|
||||
def transform_vector(t):
|
||||
return Vector([t[0] + 10, t[1] * 10])
|
||||
|
||||
x = State(mx.array(1), mx.array(2))
|
||||
print(f"{transform(x)=}")
|
||||
|
||||
vmap_transform = mx.vmap(transform)
|
||||
vmap_transform_tuple = mx.vmap(transform_tuple)
|
||||
vmap_transform_vector = mx.vmap(transform_vector)
|
||||
|
||||
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
out1 = vmap_transform_tuple(x_batch_tuple)
|
||||
|
||||
self.assertTrue(isinstance(out1, tuple))
|
||||
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))
|
||||
|
||||
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
out2 = vmap_transform(x_batch)
|
||||
self.assertTrue(isinstance(out2, State))
|
||||
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))
|
||||
|
||||
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
|
||||
out3 = vmap_transform_vector(x_batch_vector)
|
||||
self.assertTrue(isinstance(out3, Vector))
|
||||
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
|
||||
|
||||
def test_vmap_masked_scatter(self):
|
||||
def scatter_fn(x, m, src):
|
||||
x[m] = src
|
||||
|
||||
Reference in New Issue
Block a user