mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-27 00:09:17 +08:00
Add move and swap axis, and vmap for slice, concat, and gather (#158)
* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
parent
f55908bc48
commit
e5851e52b1
@ -61,6 +61,7 @@ Operations
|
||||
mean
|
||||
min
|
||||
minimum
|
||||
moveaxis
|
||||
multiply
|
||||
negative
|
||||
ones
|
||||
@ -87,6 +88,7 @@ Operations
|
||||
stop_gradient
|
||||
subtract
|
||||
sum
|
||||
swapaxes
|
||||
take
|
||||
take_along_axis
|
||||
tan
|
||||
|
47
mlx/ops.cpp
47
mlx/ops.cpp
@ -677,6 +677,53 @@ array pad(
|
||||
s);
|
||||
}
|
||||
|
||||
array moveaxis(
|
||||
const array& a,
|
||||
int source,
|
||||
int destination,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto check_ax = [&a](int ax) {
|
||||
auto ndim = static_cast<int>(a.ndim());
|
||||
if (ax < -ndim || ax >= ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[moveaxis] Invalid axis " << ax << " for array with " << ndim
|
||||
<< " dimensions.";
|
||||
throw std::out_of_range(msg.str());
|
||||
}
|
||||
return ax < 0 ? ax + ndim : ax;
|
||||
};
|
||||
source = check_ax(source);
|
||||
destination = check_ax(destination);
|
||||
std::vector<int> reorder(a.ndim());
|
||||
std::iota(reorder.begin(), reorder.end(), 0);
|
||||
reorder.erase(reorder.begin() + source);
|
||||
reorder.insert(reorder.begin() + destination, source);
|
||||
return transpose(a, reorder, s);
|
||||
}
|
||||
|
||||
array swapaxes(
|
||||
const array& a,
|
||||
int axis1,
|
||||
int axis2,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto check_ax = [&a](int ax) {
|
||||
auto ndim = static_cast<int>(a.ndim());
|
||||
if (ax < -ndim || ax >= ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[swapaxes] Invalid axis " << ax << " for array with " << ndim
|
||||
<< " dimensions.";
|
||||
throw std::out_of_range(msg.str());
|
||||
}
|
||||
return ax < 0 ? ax + ndim : ax;
|
||||
};
|
||||
axis1 = check_ax(axis1);
|
||||
axis2 = check_ax(axis2);
|
||||
std::vector<int> reorder(a.ndim());
|
||||
std::iota(reorder.begin(), reorder.end(), 0);
|
||||
std::swap(reorder[axis1], reorder[axis2]);
|
||||
return transpose(a, reorder, s);
|
||||
}
|
||||
|
||||
array transpose(
|
||||
const array& a,
|
||||
std::vector<int> axes,
|
||||
|
10
mlx/ops.h
10
mlx/ops.h
@ -183,6 +183,16 @@ inline array transpose(
|
||||
return transpose(a, std::vector<int>(axes), s);
|
||||
}
|
||||
|
||||
/** Swap two axes of an array. */
|
||||
array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
|
||||
|
||||
/** Move an axis of an array. */
|
||||
array moveaxis(
|
||||
const array& a,
|
||||
int source,
|
||||
int destination,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Pad an array with a constant value */
|
||||
array pad(
|
||||
const array& a,
|
||||
|
@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@ -512,7 +511,26 @@ array Concatenate::jvp(
|
||||
std::pair<array, int> Concatenate::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Concatenate vmap is NYI.");
|
||||
std::vector<array> t_inputs;
|
||||
// Find the first vmapped input
|
||||
int i = 0;
|
||||
for (; i < axes.size(); i++) {
|
||||
t_inputs.push_back(inputs[i]);
|
||||
if (axes[i] >= 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto out_ax = axes[i++];
|
||||
// Move vmap axes to the same spot.
|
||||
for (; i < axes.size(); ++i) {
|
||||
if (out_ax != axes[i] && axes[i] >= 0) {
|
||||
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
|
||||
} else {
|
||||
t_inputs.push_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
auto axis = axis_ + (axis_ >= out_ax);
|
||||
return {concatenate(t_inputs, axis, stream()), out_ax};
|
||||
}
|
||||
|
||||
bool Concatenate::is_equivalent(const Primitive& other) const {
|
||||
@ -1054,7 +1072,53 @@ std::pair<array, int> Full::vmap(
|
||||
std::pair<array, int> Gather::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Gather vmap is NYI, please change slices instead");
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> indices(inputs.begin() + 1, inputs.end());
|
||||
auto gather_axes = axes_;
|
||||
auto slice_sizes = slice_sizes_;
|
||||
auto src_vmapped = axes[0] >= 0;
|
||||
auto indices_vmapped =
|
||||
std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
|
||||
auto out_ax =
|
||||
*std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; });
|
||||
|
||||
// Reorder all the index arrays so the vmap axis is in the same spot.
|
||||
for (int i = 1; i < axes.size(); ++i) {
|
||||
if (out_ax != axes[i] && axes[i] >= 0) {
|
||||
indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
|
||||
}
|
||||
}
|
||||
|
||||
if (src_vmapped) {
|
||||
int max_dims = 0;
|
||||
for (auto& idx : indices) {
|
||||
max_dims = std::max(static_cast<int>(idx.ndim()), max_dims);
|
||||
}
|
||||
auto new_ax_loc =
|
||||
std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) {
|
||||
return a >= out_ax;
|
||||
});
|
||||
for (; new_ax_loc < gather_axes.end(); new_ax_loc++) {
|
||||
(*new_ax_loc)++;
|
||||
}
|
||||
if (indices_vmapped) {
|
||||
// Make a new index array for the vmapped dimension
|
||||
// Reshape it so it broadcasts with other index arrays
|
||||
// Update gather axes and slice sizes accordingly
|
||||
auto shape = std::vector<int>(max_dims - out_ax, 1);
|
||||
auto vmap_inds = arange(0, src.shape(out_ax), stream());
|
||||
shape[0] = vmap_inds.shape(0);
|
||||
vmap_inds = reshape(vmap_inds, shape, stream());
|
||||
slice_sizes.insert(slice_sizes.begin() + out_ax, 1);
|
||||
auto new_ax_idx = new_ax_loc - gather_axes.begin();
|
||||
gather_axes.insert(new_ax_loc, out_ax);
|
||||
indices.insert(indices.begin() + new_ax_idx, vmap_inds);
|
||||
} else {
|
||||
slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0]));
|
||||
out_ax = max_dims + axes[0];
|
||||
}
|
||||
}
|
||||
return {gather(src, indices, gather_axes, slice_sizes, stream()), out_ax};
|
||||
}
|
||||
|
||||
std::vector<array> Gather::vjp(
|
||||
@ -1997,8 +2061,15 @@ std::pair<array, int> Sinh::vmap(
|
||||
std::pair<array, int> Slice::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
// TODO implement
|
||||
return {array(1.0f), axes[0]};
|
||||
auto start = start_indices_;
|
||||
auto stop = end_indices_;
|
||||
auto strides = strides_;
|
||||
auto ax = axes[0];
|
||||
auto& input = inputs[0];
|
||||
start.insert(start.begin() + ax, 0);
|
||||
stop.insert(stop.begin() + ax, input.shape(ax));
|
||||
strides.insert(strides.begin() + ax, 1);
|
||||
return {slice(input, start, stop, strides, stream()), ax};
|
||||
}
|
||||
|
||||
std::vector<array> Slice::vjp(
|
||||
|
@ -862,6 +862,22 @@ void init_array(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
"See :func:`any`.")
|
||||
.def(
|
||||
"moveaxis",
|
||||
&moveaxis,
|
||||
"source"_a,
|
||||
"destination"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
"See :func:`moveaxis`.")
|
||||
.def(
|
||||
"swapaxes",
|
||||
&swapaxes,
|
||||
"axis1"_a,
|
||||
"axis2"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
"See :func:`moveaxis`.")
|
||||
.def(
|
||||
"transpose",
|
||||
[](const array& a, py::args axes, StreamOrDevice s) {
|
||||
|
@ -1591,6 +1591,50 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The ceil of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"moveaxis",
|
||||
&moveaxis,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
"source"_a,
|
||||
"destiantion"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Move an axis to a new position.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
source (int): Specifies the source axis.
|
||||
destination (int): Specifies the destination axis.
|
||||
|
||||
Returns:
|
||||
array: The array with the axis moved.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"swapaxes",
|
||||
&swapaxes,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
"axis1"_a,
|
||||
"axis2"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
swapaxes(a: array, /, axis1 : int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Swap two axes of an array.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
axis1 (int): Specifies the first axis.
|
||||
axis2 (int): Specifies the second axis.
|
||||
|
||||
Returns:
|
||||
array: The array with swapped axes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"transpose",
|
||||
[](const array& a,
|
||||
|
@ -375,6 +375,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertListEqual(mx.transpose(x, axes=(0, 2, 1)).tolist(), expected)
|
||||
|
||||
def test_move_swap_axes(self):
|
||||
x = mx.zeros((2, 3, 4))
|
||||
self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2])
|
||||
self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2])
|
||||
self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2])
|
||||
self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2])
|
||||
|
||||
def test_sum(self):
|
||||
x = mx.array(
|
||||
[
|
||||
|
@ -163,6 +163,61 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
|
||||
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
|
||||
|
||||
def test_vmap_indexing(self):
|
||||
x = mx.arange(16).reshape(2, 2, 2, 2)
|
||||
inds = mx.array([[0, 1, 0], [1, 1, 0]])
|
||||
|
||||
out = mx.vmap(lambda x, y: x[y], in_axes=(0, 0))(x, inds)
|
||||
expected = mx.array(
|
||||
[
|
||||
[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
|
||||
[[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
|
||||
]
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
out = mx.vmap(lambda x, y: x[y], in_axes=(0, None))(x, inds)
|
||||
expected = mx.array(
|
||||
[
|
||||
[
|
||||
[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
|
||||
[[[4, 5], [6, 7]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
|
||||
],
|
||||
[
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
|
||||
[[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
|
||||
],
|
||||
]
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
out = mx.vmap(lambda x, y: x[y], in_axes=(None, 0))(x, inds)
|
||||
expected = mx.array(
|
||||
[
|
||||
[
|
||||
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
|
||||
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
|
||||
],
|
||||
[
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
|
||||
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
|
||||
],
|
||||
]
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
inds2 = mx.array([[0, 1, 0], [0, 1, 0]])
|
||||
out = mx.vmap(lambda x, y, z: x[y, z], in_axes=(None, 0, 0))(x, inds, inds2)
|
||||
expected = mx.array(
|
||||
[
|
||||
[[[0, 1], [2, 3]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]],
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]],
|
||||
]
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
@ -205,6 +204,46 @@ TEST_CASE("test split") {
|
||||
CHECK(array_equal(out[3], array({2, 3, 4})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test swap and move axes") {
|
||||
// Test swapaxes
|
||||
array a(0.0);
|
||||
CHECK_THROWS(swapaxes(a, 0, 0));
|
||||
|
||||
a = zeros({2});
|
||||
CHECK_THROWS(swapaxes(a, 0, 1));
|
||||
CHECK_EQ(swapaxes(a, 0, 0).shape(), std::vector<int>{2});
|
||||
CHECK_EQ(swapaxes(a, -1, -1).shape(), std::vector<int>{2});
|
||||
|
||||
a = zeros({2, 3, 4});
|
||||
CHECK_THROWS(swapaxes(a, 0, -4));
|
||||
CHECK_THROWS(swapaxes(a, 0, 3));
|
||||
CHECK_THROWS(swapaxes(a, 3, 0));
|
||||
CHECK_THROWS(swapaxes(a, -4, 0));
|
||||
CHECK_EQ(swapaxes(a, 0, 2).shape(), std::vector<int>{4, 3, 2});
|
||||
CHECK_EQ(swapaxes(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
|
||||
CHECK_EQ(swapaxes(a, 0, -1).shape(), std::vector<int>{4, 3, 2});
|
||||
CHECK_EQ(swapaxes(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
|
||||
|
||||
// Test moveaxis
|
||||
a = array(0.0);
|
||||
CHECK_THROWS(moveaxis(a, 0, 0));
|
||||
|
||||
a = zeros({2});
|
||||
CHECK_THROWS(moveaxis(a, 0, 1));
|
||||
CHECK_EQ(moveaxis(a, 0, 0).shape(), std::vector<int>{2});
|
||||
CHECK_EQ(moveaxis(a, -1, -1).shape(), std::vector<int>{2});
|
||||
|
||||
a = zeros({2, 3, 4});
|
||||
CHECK_THROWS(moveaxis(a, 0, -4));
|
||||
CHECK_THROWS(moveaxis(a, 0, 3));
|
||||
CHECK_THROWS(moveaxis(a, 3, 0));
|
||||
CHECK_THROWS(moveaxis(a, -4, 0));
|
||||
CHECK_EQ(moveaxis(a, 0, 2).shape(), std::vector<int>{3, 4, 2});
|
||||
CHECK_EQ(moveaxis(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
|
||||
CHECK_EQ(moveaxis(a, 0, -1).shape(), std::vector<int>{3, 4, 2});
|
||||
CHECK_EQ(moveaxis(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
|
||||
}
|
||||
|
||||
TEST_CASE("test transpose") {
|
||||
array x(1);
|
||||
auto y = transpose(x);
|
||||
@ -2003,4 +2042,4 @@ TEST_CASE("test eye with negative k offset") {
|
||||
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
|
||||
{4, 3});
|
||||
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
|
||||
}
|
||||
}
|
||||
|
@ -248,3 +248,104 @@ TEST_CASE("test vmap creation ops") {
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap slice") {
|
||||
{
|
||||
auto fun = [](array in) { return slice(in, {4}, {8}, {2}); };
|
||||
auto x = reshape(arange(16), {2, 8});
|
||||
auto out = vmap(fun)(x);
|
||||
auto expected = reshape(array({4, 6, 12, 14}), {2, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto fun = [](array in) { return slice(in, {0, 1}, {2, 3}); };
|
||||
auto x = reshape(arange(12), {2, 2, 3});
|
||||
auto out = vmap(fun, 1, 0)(x);
|
||||
auto expected = reshape(array({1, 2, 7, 8, 4, 5, 10, 11}), {2, 2, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap concatenate") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return std::vector<array>{concatenate(inputs, 0)};
|
||||
};
|
||||
auto x = reshape(arange(4), {2, 2});
|
||||
auto y = reshape(arange(4), {2, 2});
|
||||
auto out = vmap(fun)({x, y})[0];
|
||||
auto expected = reshape(array({0, 1, 0, 1, 2, 3, 2, 3}), {2, 4});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
out = vmap(fun, {1, 1})({x, y})[0];
|
||||
expected = reshape(array({0, 2, 0, 2, 1, 3, 1, 3}), {2, 4});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
out = vmap(fun, {0, 1})({x, y})[0];
|
||||
expected = reshape(array({0, 1, 0, 2, 2, 3, 1, 3}), {2, 4});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap gather") {
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = inputs[1];
|
||||
std::vector<int> slice_sizes = {1, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto out = vmap(fun, {0, -1})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 2, 3, 2, 2});
|
||||
out = vmap(fun, {0, -1}, {3})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2, 2});
|
||||
}
|
||||
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = inputs[1];
|
||||
std::vector<int> slice_sizes = {1, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto out = vmap(fun, {0, 0})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||
}
|
||||
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = inputs[1];
|
||||
std::vector<int> slice_sizes = {1, 2, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
|
||||
auto out = vmap(fun, {-1, 0})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2, 2});
|
||||
}
|
||||
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = std::vector<array>(inputs.begin() + 1, inputs.end());
|
||||
std::vector<int> slice_sizes = {1, 1, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto z = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto out = vmap(fun, {-1, 0, 0})({x, y, z})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||
|
||||
z = array({0, 1, 0, 0, 1, 0}, {3, 2});
|
||||
out = vmap(fun, {-1, 0, 1})({x, y, z})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user