Add move and swap axis, and vmap for slice, concat, and gather (#158)

* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
Awni Hannun
2023-12-14 12:59:12 -08:00
committed by GitHub
parent f55908bc48
commit e5851e52b1
10 changed files with 399 additions and 7 deletions

View File

@@ -677,6 +677,53 @@ array pad(
s);
}
/**
 * Move a single axis of `a` to a new position.
 *
 * `source` and `destination` may be negative, in which case they are offset
 * by the number of dimensions. Throws std::out_of_range when either one is
 * outside [-ndim, ndim). The remaining axes keep their relative order and the
 * result is produced by `transpose` on stream/device `s`.
 */
array moveaxis(
    const array& a,
    int source,
    int destination,
    StreamOrDevice s /* = {} */) {
  const int rank = static_cast<int>(a.ndim());

  // Bounds-check an axis and map a negative value onto [0, rank).
  const auto normalize = [rank](int axis) -> int {
    if (axis >= rank || axis < -rank) {
      std::ostringstream msg;
      msg << "[moveaxis] Invalid axis " << axis << " for array with " << rank
          << " dimensions.";
      throw std::out_of_range(msg.str());
    }
    if (axis < 0) {
      return axis + rank;
    }
    return axis;
  };

  // Source is validated first so its error takes precedence.
  const int from = normalize(source);
  const int to = normalize(destination);

  // Every axis except `from`, in ascending order ...
  std::vector<int> perm;
  perm.reserve(rank);
  for (int ax = 0; ax < rank; ++ax) {
    if (ax != from) {
      perm.push_back(ax);
    }
  }
  // ... with `from` spliced back in at its destination slot.
  perm.insert(perm.begin() + to, from);

  return transpose(a, perm, s);
}
/**
 * Exchange two axes of `a`.
 *
 * `axis1` and `axis2` may be negative, in which case they are offset by the
 * number of dimensions. Throws std::out_of_range when either one is outside
 * [-ndim, ndim). All other axes stay where they are and the result is
 * produced by `transpose` on stream/device `s`.
 */
array swapaxes(
    const array& a,
    int axis1,
    int axis2,
    StreamOrDevice s /* = {} */) {
  const int rank = static_cast<int>(a.ndim());

  // Bounds-check an axis and map a negative value onto [0, rank).
  const auto normalize = [rank](int axis) -> int {
    if (axis >= rank || axis < -rank) {
      std::ostringstream msg;
      msg << "[swapaxes] Invalid axis " << axis << " for array with " << rank
          << " dimensions.";
      throw std::out_of_range(msg.str());
    }
    if (axis < 0) {
      return axis + rank;
    }
    return axis;
  };

  // axis1 is validated first so its error takes precedence.
  const int first = normalize(axis1);
  const int second = normalize(axis2);

  // Identity permutation with the two requested slots exchanged.
  std::vector<int> perm(rank);
  for (int ax = 0; ax < rank; ++ax) {
    perm[ax] = ax;
  }
  perm[first] = second;
  perm[second] = first;

  return transpose(a, perm, s);
}
array transpose(
const array& a,
std::vector<int> axes,

View File

@@ -183,6 +183,16 @@ inline array transpose(
return transpose(a, std::vector<int>(axes), s);
}
/**
 * Swap two axes of an array.
 *
 * `axis1` and `axis2` may be negative, in which case the number of dimensions
 * of `a` is added to them. Throws std::out_of_range if either axis is outside
 * [-ndim, ndim).
 */
array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
/**
 * Move an axis of an array.
 *
 * The axis at `source` is moved to position `destination`; the other axes
 * keep their relative order. Both arguments may be negative, in which case
 * the number of dimensions of `a` is added to them. Throws std::out_of_range
 * if either axis is outside [-ndim, ndim).
 */
array moveaxis(
const array& a,
int source,
int destination,
StreamOrDevice s = {});
/** Pad an array with a constant value */
array pad(
const array& a,

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -512,7 +511,26 @@ array Concatenate::jvp(
std::pair<array, int> Concatenate::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Concatenate vmap is NYI.");
std::vector<array> t_inputs;
// Find the first vmapped input
int i = 0;
for (; i < axes.size(); i++) {
t_inputs.push_back(inputs[i]);
if (axes[i] >= 0) {
break;
}
}
auto out_ax = axes[i++];
// Move vmap axes to the same spot.
for (; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
} else {
t_inputs.push_back(inputs[i]);
}
}
auto axis = axis_ + (axis_ >= out_ax);
return {concatenate(t_inputs, axis, stream()), out_ax};
}
bool Concatenate::is_equivalent(const Primitive& other) const {
@@ -1054,7 +1072,53 @@ std::pair<array, int> Full::vmap(
/**
 * Batching rule for Gather.
 *
 * `inputs[0]` is the source array and `inputs[1:]` are the index arrays;
 * `axes[i]` is the vmapped axis of `inputs[i]`, negative when that input is
 * not vmapped. Returns the gathered array and the position of the vmapped
 * axis in it (-1 when no input is vmapped).
 */
std::pair<array, int> Gather::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto& src = inputs[0];
  std::vector<array> indices(inputs.begin() + 1, inputs.end());
  auto gather_axes = axes_;
  auto slice_sizes = slice_sizes_;
  auto src_vmapped = axes[0] >= 0;
  auto indices_vmapped =
      std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });

  // The first vmapped input decides the working vmap axis. Guard the lookup:
  // dereferencing the end iterator would be undefined behaviour.
  auto first_vmapped =
      std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; });
  if (first_vmapped == axes.end()) {
    // Nothing is vmapped: run the gather unchanged and report no batch axis.
    return {gather(src, indices, gather_axes, slice_sizes, stream()), -1};
  }
  auto out_ax = *first_vmapped;

  // Reorder all the index arrays so the vmap axis is in the same spot.
  for (int i = 1; i < axes.size(); ++i) {
    if (out_ax != axes[i] && axes[i] >= 0) {
      indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
    }
  }

  if (src_vmapped) {
    // Largest rank among the index arrays.
    int max_dims = 0;
    for (auto& idx : indices) {
      max_dims = std::max(static_cast<int>(idx.ndim()), max_dims);
    }

    // Starting at the first gather axis that is >= out_ax, bump every
    // remaining entry by one to account for the extra source dimension.
    // When the loop finishes new_ax_loc is gather_axes.end().
    auto new_ax_loc =
        std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) {
          return a >= out_ax;
        });
    for (; new_ax_loc < gather_axes.end(); new_ax_loc++) {
      (*new_ax_loc)++;
    }

    if (indices_vmapped) {
      // Make a new index array for the vmapped dimension
      // Reshape it so it broadcasts with other index arrays
      // Update gather axes and slice sizes accordingly
      // NOTE(review): assumes max_dims > out_ax so that `shape` is non-empty
      // -- confirm against callers.
      auto shape = std::vector<int>(max_dims - out_ax, 1);
      auto vmap_inds = arange(0, src.shape(out_ax), stream());
      shape[0] = vmap_inds.shape(0);
      vmap_inds = reshape(vmap_inds, shape, stream());
      slice_sizes.insert(slice_sizes.begin() + out_ax, 1);
      // new_ax_loc is end() here, so the new axis and its index array are
      // appended; they stay paired because both use the same position.
      auto new_ax_idx = new_ax_loc - gather_axes.begin();
      gather_axes.insert(new_ax_loc, out_ax);
      indices.insert(indices.begin() + new_ax_idx, vmap_inds);
    } else {
      // Only the source is vmapped: take the whole vmapped dimension as part
      // of each slice. The output's vmapped axis is then offset by max_dims.
      slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0]));
      out_ax = max_dims + axes[0];
    }
  }
  return {gather(src, indices, gather_axes, slice_sizes, stream()), out_ax};
}
std::vector<array> Gather::vjp(
@@ -1997,8 +2061,15 @@ std::pair<array, int> Sinh::vmap(
/**
 * Batching rule for Slice.
 *
 * `axes[0]` is the vmapped axis of `inputs[0]`. The slice is extended with a
 * full-range entry (start 0, stop = size of that axis, stride 1) at the
 * vmapped axis so the batch dimension is carried through untouched.
 *
 * Returns the sliced array together with its vmapped axis, which is the same
 * as the input's.
 */
std::pair<array, int> Slice::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto start = start_indices_;
  auto stop = end_indices_;
  auto strides = strides_;
  const auto ax = axes[0];
  const auto& input = inputs[0];
  // A negative axis means the input is not vmapped; there is no dimension to
  // insert, and `begin() + ax` would point before the start of each vector.
  if (ax >= 0) {
    start.insert(start.begin() + ax, 0);
    stop.insert(stop.begin() + ax, input.shape(ax));
    strides.insert(strides.begin() + ax, 1);
  }
  return {slice(input, start, stop, strides, stream()), ax};
}
std::vector<array> Slice::vjp(