mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Added support for pytree types that inherit from tuple and typing.namedtuple (#2845)
This commit is contained in:
@@ -44,10 +44,11 @@ def tree_map(
|
||||
return fn(tree, *rest)
|
||||
elif isinstance(tree, (list, tuple)):
|
||||
TreeType = type(tree)
|
||||
return TreeType(
|
||||
subtrees = (
|
||||
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
|
||||
for i, child in enumerate(tree)
|
||||
)
|
||||
return TreeType(*subtrees) if hasattr(tree, "_fields") else TreeType(subtrees)
|
||||
elif isinstance(tree, dict):
|
||||
return {
|
||||
k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
|
||||
|
||||
@@ -41,6 +41,7 @@ nb::object tree_map(
|
||||
int len = nb::cast<nb::tuple>(subtrees[0]).size();
|
||||
nb::list l;
|
||||
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
|
||||
auto type = subtrees[0].type();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (nb::isinstance<nb::tuple>(subtrees[j])) {
|
||||
@@ -51,7 +52,10 @@ nb::object tree_map(
|
||||
}
|
||||
l.append(recurse(items));
|
||||
}
|
||||
return nb::cast<nb::object>(nb::tuple(l));
|
||||
if (PyTuple_CheckExact(subtrees[0].ptr())) {
|
||||
return nb::cast<nb::object>(nb::tuple(l));
|
||||
}
|
||||
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
|
||||
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
|
||||
@@ -178,11 +182,15 @@ void tree_visit_update(
|
||||
}
|
||||
return nb::cast<nb::object>(l);
|
||||
} else if (nb::isinstance<nb::tuple>(subtree)) {
|
||||
auto type = subtree.type();
|
||||
nb::list l(subtree);
|
||||
for (int i = 0; i < l.size(); ++i) {
|
||||
l[i] = recurse(l[i]);
|
||||
}
|
||||
return nb::cast<nb::object>(nb::tuple(l));
|
||||
if (PyTuple_CheckExact(subtree.ptr())) {
|
||||
return nb::cast<nb::object>(nb::tuple(l));
|
||||
}
|
||||
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
|
||||
} else if (nb::isinstance<nb::dict>(subtree)) {
|
||||
auto d = nb::cast<nb::dict>(subtree);
|
||||
for (auto item : d) {
|
||||
|
||||
@@ -798,6 +798,55 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
grad_fn(model)
|
||||
self.assertEqual(model[1].item(), 2.0)
|
||||
|
||||
def test_autograd_types(self):
|
||||
from typing import NamedTuple
|
||||
|
||||
class Vector(tuple):
|
||||
pass
|
||||
|
||||
class State(NamedTuple):
|
||||
a: mx.array
|
||||
b: mx.array
|
||||
|
||||
def transform(x: State):
|
||||
return State(x.a + 10, x.b * 10)
|
||||
|
||||
def transform_tuple(t):
|
||||
return (t[0] + 10, t[1] * 10)
|
||||
|
||||
def transform_vector(t):
|
||||
return Vector([t[0] + 10, t[1] * 10])
|
||||
|
||||
def loss_fn(x):
|
||||
out = transform(x)
|
||||
return out.a.sum() + out.b.sum()
|
||||
|
||||
def loss_fn_tuple(x):
|
||||
out = transform_tuple(x)
|
||||
return out[0].sum() + out[1].sum()
|
||||
|
||||
def loss_fn_vector(x):
|
||||
out = transform_vector(x)
|
||||
return out[0].sum() + out[1].sum()
|
||||
|
||||
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
grads = mx.grad(loss_fn)(x_batch)
|
||||
self.assertTrue(isinstance(grads, State))
|
||||
self.assertTrue(mx.array_equal(grads.a, mx.ones(3)))
|
||||
self.assertTrue(mx.array_equal(grads.b, mx.ones(3) * 10))
|
||||
|
||||
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
grads = mx.grad(loss_fn_tuple)(x_batch_tuple)
|
||||
self.assertTrue(isinstance(grads, tuple))
|
||||
self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
|
||||
self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))
|
||||
|
||||
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
|
||||
grads = mx.grad(loss_fn_vector)(x_batch_vector)
|
||||
self.assertTrue(isinstance(grads, Vector))
|
||||
self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
|
||||
self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))
|
||||
|
||||
def test_reduce_jvp(self):
|
||||
a = mx.arange(4)
|
||||
b = mx.array([3, 2, 1, 0])
|
||||
|
||||
@@ -1179,6 +1179,50 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
expected = fun(False)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_compile_types(self):
|
||||
from typing import NamedTuple
|
||||
|
||||
class Vector(tuple):
|
||||
pass
|
||||
|
||||
class State(NamedTuple):
|
||||
a: mx.array
|
||||
b: mx.array
|
||||
|
||||
def transform(x: State):
|
||||
return State(x.a + 10, x.b * 10)
|
||||
|
||||
def transform_tuple(t):
|
||||
return (t[0] + 10, t[1] * 10)
|
||||
|
||||
def transform_vector(t):
|
||||
return Vector([t[0] + 10, t[1] * 10])
|
||||
|
||||
x = State(mx.array(1), mx.array(2))
|
||||
|
||||
compiled_transform = mx.compile(transform)
|
||||
compiled_transform_tuple = mx.compile(transform_tuple)
|
||||
compiled_transform_vector = mx.compile(transform_vector)
|
||||
|
||||
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
out1 = compiled_transform_tuple(x_batch_tuple)
|
||||
|
||||
self.assertTrue(isinstance(out1, tuple))
|
||||
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))
|
||||
|
||||
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
out2 = compiled_transform(x_batch)
|
||||
self.assertTrue(isinstance(out2, State))
|
||||
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))
|
||||
|
||||
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
|
||||
out3 = compiled_transform_vector(x_batch_vector)
|
||||
self.assertTrue(isinstance(out3, Vector))
|
||||
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
@@ -46,6 +46,51 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(k1, k2)
|
||||
self.assertTrue(mx.array_equal(v1, v2))
|
||||
|
||||
def test_supported_trees(self):
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
class Vector(tuple):
|
||||
pass
|
||||
|
||||
class Params(NamedTuple):
|
||||
m: mx.array
|
||||
b: mx.array
|
||||
|
||||
list1 = [mx.array([0, 1]), mx.array(2)]
|
||||
tuple1 = (mx.array([0, 1]), mx.array(2))
|
||||
vector1 = Vector([mx.array([0, 1]), mx.array(2)])
|
||||
params1 = Params(m=mx.array([0, 1]), b=mx.array(2))
|
||||
dict1 = {"m": mx.array([0, 1]), "b": mx.array(2)}
|
||||
|
||||
add_one = lambda x: x + 1
|
||||
|
||||
list2 = mlx.utils.tree_map(add_one, list1)
|
||||
tuple2 = mlx.utils.tree_map(add_one, tuple1)
|
||||
vector2 = mlx.utils.tree_map(add_one, vector1)
|
||||
params2 = mlx.utils.tree_map(add_one, params1)
|
||||
dict2 = mlx.utils.tree_map(add_one, dict1)
|
||||
|
||||
self.assertTrue(isinstance(list2, list))
|
||||
self.assertTrue(mx.array_equal(list2[0], mx.array([1, 2])))
|
||||
self.assertTrue(mx.array_equal(list2[1], mx.array(3)))
|
||||
|
||||
self.assertTrue(isinstance(tuple2, tuple))
|
||||
self.assertTrue(mx.array_equal(tuple2[0], mx.array([1, 2])))
|
||||
self.assertTrue(mx.array_equal(tuple2[1], mx.array(3)))
|
||||
|
||||
self.assertTrue(isinstance(vector2, Vector))
|
||||
self.assertTrue(mx.array_equal(vector2[0], mx.array([1, 2])))
|
||||
self.assertTrue(mx.array_equal(vector2[1], mx.array(3)))
|
||||
|
||||
self.assertTrue(isinstance(dict2, dict))
|
||||
self.assertTrue(mx.array_equal(dict2["m"], mx.array([1, 2])))
|
||||
self.assertTrue(mx.array_equal(dict2["b"], mx.array(3)))
|
||||
|
||||
self.assertTrue(isinstance(params2, Params))
|
||||
self.assertTrue(mx.array_equal(params2.m, mx.array([1, 2])))
|
||||
self.assertTrue(mx.array_equal(params2.b, mx.array(3)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
@@ -723,6 +723,52 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_vmap_types(self):
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
class Vector(tuple):
|
||||
pass
|
||||
|
||||
class State(NamedTuple):
|
||||
a: mx.array
|
||||
b: mx.array
|
||||
|
||||
def transform(x: State):
|
||||
return State(x.a + 10, x.b * 10)
|
||||
|
||||
def transform_tuple(t):
|
||||
return (t[0] + 10, t[1] * 10)
|
||||
|
||||
def transform_vector(t):
|
||||
return Vector([t[0] + 10, t[1] * 10])
|
||||
|
||||
x = State(mx.array(1), mx.array(2))
|
||||
print(f"{transform(x)=}")
|
||||
|
||||
vmap_transform = mx.vmap(transform)
|
||||
vmap_transform_tuple = mx.vmap(transform_tuple)
|
||||
vmap_transform_vector = mx.vmap(transform_vector)
|
||||
|
||||
x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
out1 = vmap_transform_tuple(x_batch_tuple)
|
||||
|
||||
self.assertTrue(isinstance(out1, tuple))
|
||||
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))
|
||||
|
||||
x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
|
||||
out2 = vmap_transform(x_batch)
|
||||
self.assertTrue(isinstance(out2, State))
|
||||
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))
|
||||
|
||||
x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
|
||||
out3 = vmap_transform_vector(x_batch_vector)
|
||||
self.assertTrue(isinstance(out3, Vector))
|
||||
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
|
||||
|
||||
def test_vmap_masked_scatter(self):
|
||||
def scatter_fn(x, m, src):
|
||||
x[m] = src
|
||||
|
||||
Reference in New Issue
Block a user