Properly handle negative axes in python vmap (#944)

This commit is contained in:
Angelos Katharopoulos
2024-04-02 18:07:23 -07:00
committed by GitHub
parent 741eb28443
commit 3fc993f82d
4 changed files with 147 additions and 41 deletions

View File

@@ -121,16 +121,13 @@ class TestVmap(mlx_tests.MLXTestCase):
expected = my_fun(tree)
self.assertTrue(mx.array_equal(out, my_fun(tree)))
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
out = mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
self.assertTrue(mx.array_equal(out, my_fun(tree)))
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
out = mx.vmap(my_fun, in_axes={"a": 0, "b": (0, 0)}, out_axes=0)(tree)
self.assertTrue(mx.array_equal(out, my_fun(tree)))
tree = {
@@ -140,7 +137,7 @@ class TestVmap(mlx_tests.MLXTestCase):
mx.random.uniform(shape=(4, 2)),
),
}
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
out = mx.vmap(my_fun, in_axes={"a": 0, "b": (1, 1)}, out_axes=0)(tree)
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
self.assertTrue(mx.array_equal(out, expected))