mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Properly handle negative axes in python vmap (#944)
This commit is contained in:

committed by
GitHub

parent
741eb28443
commit
3fc993f82d
@@ -121,16 +121,13 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
expected = my_fun(tree)
|
||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)
|
||||
|
||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
|
||||
out = mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
|
||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||
|
||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
|
||||
out = mx.vmap(my_fun, in_axes={"a": 0, "b": (0, 0)}, out_axes=0)(tree)
|
||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||
|
||||
tree = {
|
||||
@@ -140,7 +137,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
mx.random.uniform(shape=(4, 2)),
|
||||
),
|
||||
}
|
||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
|
||||
out = mx.vmap(my_fun, in_axes={"a": 0, "b": (1, 1)}, out_axes=0)(tree)
|
||||
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
|
Reference in New Issue
Block a user