mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-26 16:46:10 +08:00
Fix Split::vmap (#1845)
This commit is contained in:
parent
1c0c118f7c
commit
9eb7d7362f
@ -4273,7 +4273,9 @@ std::pair<std::vector<array>, std::vector<int>> Split::vmap(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
|
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
|
||||||
return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes};
|
auto output = split(inputs[0], indices_, axis_ + axis_left, stream());
|
||||||
|
std::vector<int> output_axes(output.size(), axes[0]);
|
||||||
|
return {std::move(output), std::move(output_axes)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Split::vjp(
|
std::vector<array> Split::vjp(
|
||||||
|
@ -596,6 +596,18 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd)
|
out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd)
|
||||||
self.assertEqual(out.shape, (4, 5, 1))
|
self.assertEqual(out.shape, (4, 5, 1))
|
||||||
|
|
||||||
|
def test_vmap_split_vmap(self):
|
||||||
|
def fun(x):
|
||||||
|
a, b = mx.split(x, 2, 1)
|
||||||
|
return mx.concatenate([b, a], 1)
|
||||||
|
|
||||||
|
x = mx.ones((5, 6, 7))
|
||||||
|
y = mx.ones((5, 4, 6, 7))
|
||||||
|
fx = fun(x)
|
||||||
|
fy = mx.vmap(fun, in_axes=1)(y)
|
||||||
|
self.assertEqual(fx.shape, (5, 6, 7))
|
||||||
|
self.assertEqual(fy.shape, (4, 5, 6, 7))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user