mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Added support for pytree types that inherit from tuple and typing.namedtuple (#2845)
This commit is contained in:
@@ -798,6 +798,55 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
grad_fn(model)
|
||||
self.assertEqual(model[1].item(), 2.0)
|
||||
|
||||
def test_autograd_types(self):
|
||||
from typing import NamedTuple
|
||||
|
||||
class Vector(tuple):
|
||||
pass
|
||||
|
||||
class State(NamedTuple):
|
||||
a: mx.array
|
||||
b: mx.array
|
||||
|
||||
def transform(x: State):
|
||||
return State(x.a + 10, x.b * 10)
|
||||
|
||||
def transform_tuple(t):
|
||||
return (t[0] + 10, t[1] * 10)
|
||||
|
||||
def transform_vector(t):
|
||||
return Vector([t[0] + 10, t[1] * 10])
|
||||
|
||||
def loss_fn(x):
|
||||
out = transform(x)
|
||||
return out.a.sum() + out.b.sum()
|
||||
|
||||
def loss_fn_tuple(x):
|
||||
out = transform_tuple(x)
|
||||
return out[0].sum() + out[1].sum()
|
||||
|
||||
def loss_fn_vector(x):
|
||||
out = transform_vector(x)
|
||||
return out[0].sum() + out[1].sum()
|
||||
|
||||
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
grads = mx.grad(loss_fn)(x_batch)
|
||||
self.assertTrue(isinstance(grads, State))
|
||||
self.assertTrue(mx.array_equal(grads.a, mx.ones(3)))
|
||||
self.assertTrue(mx.array_equal(grads.b, mx.ones(3) * 10))
|
||||
|
||||
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
grads = mx.grad(loss_fn_tuple)(x_batch_tuple)
|
||||
self.assertTrue(isinstance(grads, tuple))
|
||||
self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
|
||||
self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))
|
||||
|
||||
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
|
||||
grads = mx.grad(loss_fn_vector)(x_batch_vector)
|
||||
self.assertTrue(isinstance(grads, Vector))
|
||||
self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
|
||||
self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))
|
||||
|
||||
def test_reduce_jvp(self):
|
||||
a = mx.arange(4)
|
||||
b = mx.array([3, 2, 1, 0])
|
||||
|
||||
@@ -1179,6 +1179,50 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
expected = fun(False)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_compile_types(self):
|
||||
from typing import NamedTuple
|
||||
|
||||
class Vector(tuple):
|
||||
pass
|
||||
|
||||
class State(NamedTuple):
|
||||
a: mx.array
|
||||
b: mx.array
|
||||
|
||||
def transform(x: State):
|
||||
return State(x.a + 10, x.b * 10)
|
||||
|
||||
def transform_tuple(t):
|
||||
return (t[0] + 10, t[1] * 10)
|
||||
|
||||
def transform_vector(t):
|
||||
return Vector([t[0] + 10, t[1] * 10])
|
||||
|
||||
x = State(mx.array(1), mx.array(2))
|
||||
|
||||
compiled_transform = mx.compile(transform)
|
||||
compiled_transform_tuple = mx.compile(transform_tuple)
|
||||
compiled_transform_vector = mx.compile(transform_vector)
|
||||
|
||||
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
out1 = compiled_transform_tuple(x_batch_tuple)
|
||||
|
||||
self.assertTrue(isinstance(out1, tuple))
|
||||
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))
|
||||
|
||||
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
out2 = compiled_transform(x_batch)
|
||||
self.assertTrue(isinstance(out2, State))
|
||||
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))
|
||||
|
||||
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
|
||||
out3 = compiled_transform_vector(x_batch_vector)
|
||||
self.assertTrue(isinstance(out3, Vector))
|
||||
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
@@ -46,6 +46,51 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(k1, k2)
|
||||
self.assertTrue(mx.array_equal(v1, v2))
|
||||
|
||||
def test_supported_trees(self):
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
class Vector(tuple):
|
||||
pass
|
||||
|
||||
class Params(NamedTuple):
|
||||
m: mx.array
|
||||
b: mx.array
|
||||
|
||||
list1 = [mx.array([0, 1]), mx.array(2)]
|
||||
tuple1 = (mx.array([0, 1]), mx.array(2))
|
||||
vector1 = Vector([mx.array([0, 1]), mx.array(2)])
|
||||
params1 = Params(m=mx.array([0, 1]), b=mx.array(2))
|
||||
dict1 = {"m": mx.array([0, 1]), "b": mx.array(2)}
|
||||
|
||||
add_one = lambda x: x + 1
|
||||
|
||||
list2 = mlx.utils.tree_map(add_one, list1)
|
||||
tuple2 = mlx.utils.tree_map(add_one, tuple1)
|
||||
vector2 = mlx.utils.tree_map(add_one, vector1)
|
||||
params2 = mlx.utils.tree_map(add_one, params1)
|
||||
dict2 = mlx.utils.tree_map(add_one, dict1)
|
||||
|
||||
self.assertTrue(isinstance(list2, list))
|
||||
self.assertTrue(mx.array_equal(list2[0], mx.array([1, 2])))
|
||||
self.assertTrue(mx.array_equal(list2[1], mx.array(3)))
|
||||
|
||||
self.assertTrue(isinstance(tuple2, tuple))
|
||||
self.assertTrue(mx.array_equal(tuple2[0], mx.array([1, 2])))
|
||||
self.assertTrue(mx.array_equal(tuple2[1], mx.array(3)))
|
||||
|
||||
self.assertTrue(isinstance(vector2, Vector))
|
||||
self.assertTrue(mx.array_equal(vector2[0], mx.array([1, 2])))
|
||||
self.assertTrue(mx.array_equal(vector2[1], mx.array(3)))
|
||||
|
||||
self.assertTrue(isinstance(dict2, dict))
|
||||
self.assertTrue(mx.array_equal(dict2["m"], mx.array([1, 2])))
|
||||
self.assertTrue(mx.array_equal(dict2["b"], mx.array(3)))
|
||||
|
||||
self.assertTrue(isinstance(params2, Params))
|
||||
self.assertTrue(mx.array_equal(params2.m, mx.array([1, 2])))
|
||||
self.assertTrue(mx.array_equal(params2.b, mx.array(3)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
@@ -723,6 +723,52 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_vmap_types(self):
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
class Vector(tuple):
|
||||
pass
|
||||
|
||||
class State(NamedTuple):
|
||||
a: mx.array
|
||||
b: mx.array
|
||||
|
||||
def transform(x: State):
|
||||
return State(x.a + 10, x.b * 10)
|
||||
|
||||
def transform_tuple(t):
|
||||
return (t[0] + 10, t[1] * 10)
|
||||
|
||||
def transform_vector(t):
|
||||
return Vector([t[0] + 10, t[1] * 10])
|
||||
|
||||
x = State(mx.array(1), mx.array(2))
|
||||
print(f"{transform(x)=}")
|
||||
|
||||
vmap_transform = mx.vmap(transform)
|
||||
vmap_transform_tuple = mx.vmap(transform_tuple)
|
||||
vmap_transform_vector = mx.vmap(transform_vector)
|
||||
|
||||
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
out1 = vmap_transform_tuple(x_batch_tuple)
|
||||
|
||||
self.assertTrue(isinstance(out1, tuple))
|
||||
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))
|
||||
|
||||
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
out2 = vmap_transform(x_batch)
|
||||
self.assertTrue(isinstance(out2, State))
|
||||
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))
|
||||
|
||||
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
|
||||
out3 = vmap_transform_vector(x_batch_vector)
|
||||
self.assertTrue(isinstance(out3, Vector))
|
||||
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
|
||||
|
||||
def test_vmap_masked_scatter(self):
|
||||
def scatter_fn(x, m, src):
|
||||
x[m] = src
|
||||
|
||||
Reference in New Issue
Block a user