Add move and swap axis, and vmap for slice, concat, and gather (#158)

* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
Awni Hannun 2023-12-14 12:59:12 -08:00 committed by GitHub
parent f55908bc48
commit e5851e52b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 399 additions and 7 deletions

View File

@ -61,6 +61,7 @@ Operations
mean mean
min min
minimum minimum
moveaxis
multiply multiply
negative negative
ones ones
@ -87,6 +88,7 @@ Operations
stop_gradient stop_gradient
subtract subtract
sum sum
swapaxes
take take
take_along_axis take_along_axis
tan tan

View File

@ -677,6 +677,53 @@ array pad(
s); s);
} }
/**
 * Move one axis of `a` to a new position.
 *
 * The axis at `source` ends up at `destination`; all other axes keep their
 * relative order. Negative values count from the last dimension. Throws
 * std::out_of_range if either value lies outside [-ndim, ndim).
 */
array moveaxis(
    const array& a,
    int source,
    int destination,
    StreamOrDevice s /* = {} */) {
  const int rank = static_cast<int>(a.ndim());

  // Validate an axis and map negative values into [0, rank).
  auto normalize = [rank](int axis) {
    if (axis < -rank || axis >= rank) {
      std::ostringstream msg;
      msg << "[moveaxis] Invalid axis " << axis << " for array with " << rank
          << " dimensions.";
      throw std::out_of_range(msg.str());
    }
    if (axis < 0) {
      return axis + rank;
    }
    return axis;
  };

  source = normalize(source);
  destination = normalize(destination);

  // Collect every axis except the one being moved, then splice the moved
  // axis back in at its destination.
  std::vector<int> permutation;
  permutation.reserve(rank);
  for (int axis = 0; axis < rank; ++axis) {
    if (axis != source) {
      permutation.push_back(axis);
    }
  }
  permutation.insert(permutation.begin() + destination, source);

  return transpose(a, permutation, s);
}
/**
 * Exchange two axes of `a`.
 *
 * Negative values count from the last dimension. Throws std::out_of_range if
 * either axis lies outside [-ndim, ndim).
 */
array swapaxes(
    const array& a,
    int axis1,
    int axis2,
    StreamOrDevice s /* = {} */) {
  const int rank = static_cast<int>(a.ndim());

  // Validate an axis and map negative values into [0, rank).
  auto normalize = [rank](int axis) {
    if (axis < -rank || axis >= rank) {
      std::ostringstream msg;
      msg << "[swapaxes] Invalid axis " << axis << " for array with " << rank
          << " dimensions.";
      throw std::out_of_range(msg.str());
    }
    if (axis < 0) {
      return axis + rank;
    }
    return axis;
  };

  axis1 = normalize(axis1);
  axis2 = normalize(axis2);

  // Start from the identity permutation and exchange the two entries.
  std::vector<int> permutation(rank);
  for (int axis = 0; axis < rank; ++axis) {
    permutation[axis] = axis;
  }
  permutation[axis1] = axis2;
  permutation[axis2] = axis1;

  return transpose(a, permutation, s);
}
array transpose( array transpose(
const array& a, const array& a,
std::vector<int> axes, std::vector<int> axes,

View File

@ -183,6 +183,16 @@ inline array transpose(
return transpose(a, std::vector<int>(axes), s); return transpose(a, std::vector<int>(axes), s);
} }
/**
 * Swap two axes of an array.
 *
 * Negative axes count from the last dimension. Throws std::out_of_range when
 * either axis is outside [-ndim, ndim).
 */
array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
/**
 * Move an axis of an array.
 *
 * The axis at `source` is placed at position `destination`; the remaining
 * axes keep their relative order. Negative values count from the last
 * dimension. Throws std::out_of_range when either value is outside
 * [-ndim, ndim).
 */
array moveaxis(
const array& a,
int source,
int destination,
StreamOrDevice s = {});
/** Pad an array with a constant value */ /** Pad an array with a constant value */
array pad( array pad(
const array& a, const array& a,

View File

@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
@ -512,7 +511,26 @@ array Concatenate::jvp(
// Vmap rule for Concatenate. inputs[i] has its mapped axis at axes[i]; the
// checks below treat a negative entry as "this input is not vmapped".
// Returns the concatenated array and the position of the mapped axis in it.
std::pair<array, int> Concatenate::vmap( std::pair<array, int> Concatenate::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
throw std::runtime_error("Concatenate vmap is NYI."); std::vector<array> t_inputs;
// Find the first vmapped input
// Every input up to and including that one is forwarded unchanged.
int i = 0;
for (; i < axes.size(); i++) {
t_inputs.push_back(inputs[i]);
if (axes[i] >= 0) {
break;
}
}
// The first vmapped input fixes the output's mapped axis; i then moves past it.
// NOTE(review): if no entry of axes is >= 0 the loop above ends with
// i == axes.size() and this read is out of range -- presumably callers always
// pass at least one vmapped input; confirm.
auto out_ax = axes[i++];
// Move vmap axes to the same spot.
// Inputs that are not vmapped, or already have their mapped axis at out_ax,
// are forwarded unchanged.
// NOTE(review): non-vmapped inputs are forwarded without a batch dimension --
// confirm that concatenate accepts them alongside the vmapped inputs.
for (; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
} else {
t_inputs.push_back(inputs[i]);
}
}
// The concatenation axis shifts by one when it sits at or after the mapped axis.
auto axis = axis_ + (axis_ >= out_ax);
return {concatenate(t_inputs, axis, stream()), out_ax};
} }
bool Concatenate::is_equivalent(const Primitive& other) const { bool Concatenate::is_equivalent(const Primitive& other) const {
@ -1054,7 +1072,53 @@ std::pair<array, int> Full::vmap(
// Vmap rule for Gather. inputs[0] is the source and inputs[1..] are the index
// arrays; axes[i] is the mapped axis of inputs[i], and the checks below treat
// a negative entry as "not vmapped". Returns the gathered array and the
// position of the mapped axis in it.
std::pair<array, int> Gather::vmap( std::pair<array, int> Gather::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
throw std::runtime_error("Gather vmap is NYI, please change slices instead"); auto& src = inputs[0];
std::vector<array> indices(inputs.begin() + 1, inputs.end());
// Work on copies so the primitive's own axes_ / slice_sizes_ are not modified.
auto gather_axes = axes_;
auto slice_sizes = slice_sizes_;
auto src_vmapped = axes[0] >= 0;
auto indices_vmapped =
std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
// out_ax starts as the first non-negative entry of axes (axes[0] whenever the
// source is vmapped).
// NOTE(review): if no entry is >= 0 this dereferences axes.end() --
// presumably callers always pass at least one vmapped input; confirm.
auto out_ax =
*std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; });
// Reorder all the index arrays so the vmap axis is in the same spot.
for (int i = 1; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
}
}
if (src_vmapped) {
// Largest rank among the index arrays.
int max_dims = 0;
for (auto& idx : indices) {
max_dims = std::max(static_cast<int>(idx.ndim()), max_dims);
}
// Starting at the first gather axis that is >= out_ax, increment that entry
// and every entry after it.
// NOTE(review): entries after the first match are incremented without being
// compared to out_ax -- presumably gather_axes is sorted ascending; confirm.
auto new_ax_loc =
std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) {
return a >= out_ax;
});
for (; new_ax_loc < gather_axes.end(); new_ax_loc++) {
(*new_ax_loc)++;
}
// The loop above leaves new_ax_loc equal to gather_axes.end().
if (indices_vmapped) {
// Make a new index array for the vmapped dimension
// Reshape it so it broadcasts with other index arrays
// Update gather axes and slice sizes accordingly
// NOTE(review): shape has max_dims - out_ax entries; writing shape[0] needs
// that to be at least 1 -- confirm for out_ax >= max_dims.
auto shape = std::vector<int>(max_dims - out_ax, 1);
auto vmap_inds = arange(0, src.shape(out_ax), stream());
shape[0] = vmap_inds.shape(0);
vmap_inds = reshape(vmap_inds, shape, stream());
slice_sizes.insert(slice_sizes.begin() + out_ax, 1);
// Because new_ax_loc is end(), the mapped axis and its index array are
// appended after the existing gather axes / index arrays.
auto new_ax_idx = new_ax_loc - gather_axes.begin();
gather_axes.insert(new_ax_loc, out_ax);
indices.insert(indices.begin() + new_ax_idx, vmap_inds);
} else {
// Only the source is vmapped: slice the whole mapped dimension of the source
// and report the mapped axis at max_dims + axes[0] in the output.
slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0]));
out_ax = max_dims + axes[0];
}
}
return {gather(src, indices, gather_axes, slice_sizes, stream()), out_ax};
} }
std::vector<array> Gather::vjp( std::vector<array> Gather::vjp(
@ -1997,8 +2061,15 @@ std::pair<array, int> Sinh::vmap(
// Vmap rule for Slice. Only inputs[0] and axes[0] are used. The mapped axis is
// given the full range [0, input.shape(ax)) with stride 1, inserted at
// position ax of the slice parameters. Returns the sliced array and ax.
// NOTE(review): a negative axes[0] would make the begin() + ax inserts
// invalid -- presumably the input is always vmapped here; confirm.
std::pair<array, int> Slice::vmap( std::pair<array, int> Slice::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
// TODO implement auto start = start_indices_;
return {array(1.0f), axes[0]}; auto stop = end_indices_;
auto strides = strides_;
auto ax = axes[0];
auto& input = inputs[0];
// Insert a full-range entry for the mapped axis into each slice parameter.
start.insert(start.begin() + ax, 0);
stop.insert(stop.begin() + ax, input.shape(ax));
strides.insert(strides.begin() + ax, 1);
return {slice(input, start, stop, strides, stream()), ax};
} }
std::vector<array> Slice::vjp( std::vector<array> Slice::vjp(

View File

@ -862,6 +862,22 @@ void init_array(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
"See :func:`any`.") "See :func:`any`.")
// Method binding array.moveaxis(source, destination, *, stream=None); it
// forwards to the free function moveaxis.
.def(
"moveaxis",
&moveaxis,
"source"_a,
"destination"_a,
py::kw_only(),
"stream"_a = none,
"See :func:`moveaxis`.")
// Method binding array.swapaxes(axis1, axis2, *, stream=None); it forwards to
// the free function swapaxes. The docstring must reference swapaxes, not
// moveaxis, so the generated docs link to the right function.
.def(
"swapaxes",
&swapaxes,
"axis1"_a,
"axis2"_a,
py::kw_only(),
"stream"_a = none,
"See :func:`swapaxes`.")
.def( .def(
"transpose", "transpose",
[](const array& a, py::args axes, StreamOrDevice s) { [](const array& a, py::args axes, StreamOrDevice s) {

View File

@ -1591,6 +1591,50 @@ void init_ops(py::module_& m) {
Returns: Returns:
array: The ceil of ``a``. array: The ceil of ``a``.
)pbdoc"); )pbdoc");
// Module-level binding for moveaxis. `a` is positional-only and `stream` is
// keyword-only. The keyword names must match the documented signature so that
// moveaxis(a, source=..., destination=...) works from Python.
m.def(
"moveaxis",
&moveaxis,
"a"_a,
py::pos_only(),
"source"_a,
"destination"_a,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array
Move an axis to a new position.
Args:
a (array): Input array.
source (int): Specifies the source axis.
destination (int): Specifies the destination axis.
Returns:
array: The array with the axis moved.
)pbdoc");
// Module-level binding for swapaxes. `a` is positional-only, axis1/axis2 may
// be passed positionally or by keyword, and `stream` is keyword-only.
m.def(
"swapaxes",
&swapaxes,
"a"_a,
py::pos_only(),
"axis1"_a,
"axis2"_a,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
swapaxes(a: array, /, axis1 : int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array
Swap two axes of an array.
Args:
a (array): Input array.
axis1 (int): Specifies the first axis.
axis2 (int): Specifies the second axis.
Returns:
array: The array with swapped axes.
)pbdoc");
m.def( m.def(
"transpose", "transpose",
[](const array& a, [](const array& a,

View File

@ -375,6 +375,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(mx.transpose(x, axes=(0, 2, 1)).tolist(), expected) self.assertListEqual(mx.transpose(x, axes=(0, 2, 1)).tolist(), expected)
def test_move_swap_axes(self):
    # The free functions and the array methods must produce the same shapes.
    arr = mx.zeros((2, 3, 4))
    moved_shape = [3, 4, 2]
    swapped_shape = [4, 3, 2]

    self.assertEqual(mx.moveaxis(arr, 0, 2).shape, moved_shape)
    self.assertEqual(arr.moveaxis(0, 2).shape, moved_shape)

    self.assertEqual(mx.swapaxes(arr, 0, 2).shape, swapped_shape)
    self.assertEqual(arr.swapaxes(0, 2).shape, swapped_shape)
def test_sum(self): def test_sum(self):
x = mx.array( x = mx.array(
[ [

View File

@ -163,6 +163,61 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertTrue(mx.array_equal(out["a"].T, expected["a"])) self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
self.assertTrue(mx.array_equal(out["b"], expected["b"])) self.assertTrue(mx.array_equal(out["b"], expected["b"]))
def test_vmap_indexing(self):
    # Check vmap over integer-array indexing for each combination of mapped
    # and unmapped source / index arguments.
    src = mx.arange(16).reshape(2, 2, 2, 2)
    idx = mx.array([[0, 1, 0], [1, 1, 0]])

    def take(x, y):
        return x[y]

    # Source and indices both mapped over axis 0.
    want = mx.array(
        [
            [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
            [[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
        ]
    )
    got = mx.vmap(take, in_axes=(0, 0))(src, idx)
    self.assertTrue(mx.array_equal(got, want))

    # Only the source is mapped.
    want = mx.array(
        [
            [
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
                [[[4, 5], [6, 7]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
            ],
            [
                [[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
                [[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
            ],
        ]
    )
    got = mx.vmap(take, in_axes=(0, None))(src, idx)
    self.assertTrue(mx.array_equal(got, want))

    # Only the indices are mapped.
    want = mx.array(
        [
            [
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
                [[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
            ],
            [
                [[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
                [[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
            ],
        ]
    )
    got = mx.vmap(take, in_axes=(None, 0))(src, idx)
    self.assertTrue(mx.array_equal(got, want))

    # Two mapped index arrays with an unmapped source.
    idx2 = mx.array([[0, 1, 0], [0, 1, 0]])

    def take2(x, y, z):
        return x[y, z]

    want = mx.array(
        [
            [[[0, 1], [2, 3]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]],
            [[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]],
        ]
    )
    got = mx.vmap(take2, in_axes=(None, 0, 0))(src, idx, idx2)
    self.assertTrue(mx.array_equal(got, want))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <cmath> #include <cmath>
#include <numeric> #include <numeric>
@ -205,6 +204,46 @@ TEST_CASE("test split") {
CHECK(array_equal(out[3], array({2, 3, 4})).item<bool>()); CHECK(array_equal(out[3], array({2, 3, 4})).item<bool>());
} }
TEST_CASE("test swap and move axes") {
  using Shape = std::vector<int>;

  // Inputs of rank 0, 1 and 3 shared by both groups of checks.
  const array scalar(0.0);
  const array vec = zeros({2});
  const array cube = zeros({2, 3, 4});

  // swapaxes: out-of-range axes throw, valid axes permute the shape.
  {
    CHECK_THROWS(swapaxes(scalar, 0, 0));

    CHECK_THROWS(swapaxes(vec, 0, 1));
    CHECK_EQ(swapaxes(vec, 0, 0).shape(), Shape{2});
    CHECK_EQ(swapaxes(vec, -1, -1).shape(), Shape{2});

    CHECK_THROWS(swapaxes(cube, 0, -4));
    CHECK_THROWS(swapaxes(cube, 0, 3));
    CHECK_THROWS(swapaxes(cube, 3, 0));
    CHECK_THROWS(swapaxes(cube, -4, 0));
    CHECK_EQ(swapaxes(cube, 0, 2).shape(), Shape{4, 3, 2});
    CHECK_EQ(swapaxes(cube, 0, 1).shape(), Shape{3, 2, 4});
    CHECK_EQ(swapaxes(cube, 0, -1).shape(), Shape{4, 3, 2});
    CHECK_EQ(swapaxes(cube, -2, 2).shape(), Shape{2, 4, 3});
  }

  // moveaxis: same validation, but the moved axis is re-inserted rather than
  // exchanged.
  {
    CHECK_THROWS(moveaxis(scalar, 0, 0));

    CHECK_THROWS(moveaxis(vec, 0, 1));
    CHECK_EQ(moveaxis(vec, 0, 0).shape(), Shape{2});
    CHECK_EQ(moveaxis(vec, -1, -1).shape(), Shape{2});

    CHECK_THROWS(moveaxis(cube, 0, -4));
    CHECK_THROWS(moveaxis(cube, 0, 3));
    CHECK_THROWS(moveaxis(cube, 3, 0));
    CHECK_THROWS(moveaxis(cube, -4, 0));
    CHECK_EQ(moveaxis(cube, 0, 2).shape(), Shape{3, 4, 2});
    CHECK_EQ(moveaxis(cube, 0, 1).shape(), Shape{3, 2, 4});
    CHECK_EQ(moveaxis(cube, 0, -1).shape(), Shape{3, 4, 2});
    CHECK_EQ(moveaxis(cube, -2, 2).shape(), Shape{2, 4, 3});
  }
}
TEST_CASE("test transpose") { TEST_CASE("test transpose") {
array x(1); array x(1);
auto y = transpose(x); auto y = transpose(x);

View File

@ -248,3 +248,104 @@ TEST_CASE("test vmap creation ops") {
CHECK(array_equal(out, expected).item<bool>()); CHECK(array_equal(out, expected).item<bool>());
} }
} }
TEST_CASE("test vmap slice") {
  // Strided 1-D slice applied to each row of a 2x8 input.
  {
    auto strided = [](array in) { return slice(in, {4}, {8}, {2}); };
    auto input = reshape(arange(16), {2, 8});
    auto want = reshape(array({4, 6, 12, 14}), {2, 2});
    auto got = vmap(strided)(input);
    CHECK(array_equal(got, want).item<bool>());
  }

  // 2-D slice mapped over input axis 1, with the mapped axis placed at
  // output axis 0.
  {
    auto window = [](array in) { return slice(in, {0, 1}, {2, 3}); };
    auto input = reshape(arange(12), {2, 2, 3});
    auto want = reshape(array({1, 2, 7, 8, 4, 5, 10, 11}), {2, 2, 2});
    auto got = vmap(window, 1, 0)(input);
    CHECK(array_equal(got, want).item<bool>());
  }
}
TEST_CASE("test vmap concatenate") {
  // Concatenate all inputs along axis 0 of the unbatched arrays.
  auto cat0 = [](std::vector<array> arrs) {
    return std::vector<array>{concatenate(arrs, 0)};
  };
  auto lhs = reshape(arange(4), {2, 2});
  auto rhs = reshape(arange(4), {2, 2});

  // Default in_axes.
  {
    auto got = vmap(cat0)({lhs, rhs})[0];
    auto want = reshape(array({0, 1, 0, 1, 2, 3, 2, 3}), {2, 4});
    CHECK(array_equal(got, want).item<bool>());
  }

  // Both inputs mapped over axis 1.
  {
    auto got = vmap(cat0, {1, 1})({lhs, rhs})[0];
    auto want = reshape(array({0, 2, 0, 2, 1, 3, 1, 3}), {2, 4});
    CHECK(array_equal(got, want).item<bool>());
  }

  // Inputs mapped over different axes.
  {
    auto got = vmap(cat0, {0, 1})({lhs, rhs})[0];
    auto want = reshape(array({0, 1, 0, 2, 2, 3, 1, 3}), {2, 4});
    CHECK(array_equal(got, want).item<bool>());
  }
}
TEST_CASE("test vmap gather") {
  using Shape = std::vector<int>;

  // Only the source is mapped (in_axes {0, -1}).
  {
    auto take = [](std::vector<array> args) {
      auto source = args[0];
      auto idx = args[1];
      std::vector<int> sizes = {1, 2, 2};
      auto gathered = squeeze(gather(source, idx, 0, sizes), 2);
      return std::vector<array>{gathered};
    };
    auto src = zeros({2, 2, 2, 2});
    auto idx = array({0, 1, 0, 0, 1, 0}, {2, 3});

    auto res = vmap(take, {0, -1})({src, idx})[0];
    CHECK_EQ(res.shape(), Shape{2, 2, 3, 2, 2});

    res = vmap(take, {0, -1}, {3})({src, idx})[0];
    CHECK_EQ(res.shape(), Shape{2, 3, 2, 2, 2});
  }

  // Source and indices both mapped over axis 0.
  {
    auto take = [](std::vector<array> args) {
      auto source = args[0];
      auto idx = args[1];
      std::vector<int> sizes = {1, 2, 2};
      auto gathered = squeeze(gather(source, idx, 0, sizes), 1);
      return std::vector<array>{gathered};
    };
    auto src = zeros({2, 2, 2, 2});
    auto idx = array({0, 1, 0, 0, 1, 0}, {2, 3});

    auto res = vmap(take, {0, 0})({src, idx})[0];
    CHECK_EQ(res.shape(), Shape{2, 3, 2, 2});
  }

  // Only the indices are mapped (in_axes {-1, 0}).
  {
    auto take = [](std::vector<array> args) {
      auto source = args[0];
      auto idx = args[1];
      std::vector<int> sizes = {1, 2, 2, 2};
      auto gathered = squeeze(gather(source, idx, 0, sizes), 1);
      return std::vector<array>{gathered};
    };
    auto src = zeros({2, 2, 2, 2});
    auto idx = array({0, 1, 0, 0, 1, 0}, {2, 3});

    auto res = vmap(take, {-1, 0})({src, idx})[0];
    CHECK_EQ(res.shape(), Shape{2, 3, 2, 2, 2});
  }

  // Two index arrays, mapped over the same axis and then over different axes.
  {
    auto take = [](std::vector<array> args) {
      auto source = args[0];
      auto idxs = std::vector<array>(args.begin() + 1, args.end());
      std::vector<int> sizes = {1, 1, 2, 2};
      auto gathered = squeeze(gather(source, idxs, {0, 1}, sizes), {1, 2});
      return std::vector<array>{gathered};
    };
    auto src = zeros({2, 2, 2, 2});
    auto idx_a = array({0, 1, 0, 0, 1, 0}, {2, 3});
    auto idx_b = array({0, 1, 0, 0, 1, 0}, {2, 3});

    auto res = vmap(take, {-1, 0, 0})({src, idx_a, idx_b})[0];
    CHECK_EQ(res.shape(), Shape{2, 3, 2, 2});

    idx_b = array({0, 1, 0, 0, 1, 0}, {3, 2});
    res = vmap(take, {-1, 0, 1})({src, idx_a, idx_b})[0];
    CHECK_EQ(res.shape(), Shape{2, 3, 2, 2});
  }
}