mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Add move and swap axis, and vmap for slice, concat, and gather (#158)
* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
47
mlx/ops.cpp
47
mlx/ops.cpp
@@ -677,6 +677,53 @@ array pad(
|
||||
s);
|
||||
}
|
||||
|
||||
array moveaxis(
|
||||
const array& a,
|
||||
int source,
|
||||
int destination,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto check_ax = [&a](int ax) {
|
||||
auto ndim = static_cast<int>(a.ndim());
|
||||
if (ax < -ndim || ax >= ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[moveaxis] Invalid axis " << ax << " for array with " << ndim
|
||||
<< " dimensions.";
|
||||
throw std::out_of_range(msg.str());
|
||||
}
|
||||
return ax < 0 ? ax + ndim : ax;
|
||||
};
|
||||
source = check_ax(source);
|
||||
destination = check_ax(destination);
|
||||
std::vector<int> reorder(a.ndim());
|
||||
std::iota(reorder.begin(), reorder.end(), 0);
|
||||
reorder.erase(reorder.begin() + source);
|
||||
reorder.insert(reorder.begin() + destination, source);
|
||||
return transpose(a, reorder, s);
|
||||
}
|
||||
|
||||
array swapaxes(
|
||||
const array& a,
|
||||
int axis1,
|
||||
int axis2,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto check_ax = [&a](int ax) {
|
||||
auto ndim = static_cast<int>(a.ndim());
|
||||
if (ax < -ndim || ax >= ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[swapaxes] Invalid axis " << ax << " for array with " << ndim
|
||||
<< " dimensions.";
|
||||
throw std::out_of_range(msg.str());
|
||||
}
|
||||
return ax < 0 ? ax + ndim : ax;
|
||||
};
|
||||
axis1 = check_ax(axis1);
|
||||
axis2 = check_ax(axis2);
|
||||
std::vector<int> reorder(a.ndim());
|
||||
std::iota(reorder.begin(), reorder.end(), 0);
|
||||
std::swap(reorder[axis1], reorder[axis2]);
|
||||
return transpose(a, reorder, s);
|
||||
}
|
||||
|
||||
array transpose(
|
||||
const array& a,
|
||||
std::vector<int> axes,
|
||||
|
10
mlx/ops.h
10
mlx/ops.h
@@ -183,6 +183,16 @@ inline array transpose(
|
||||
return transpose(a, std::vector<int>(axes), s);
|
||||
}
|
||||
|
||||
/** Swap two axes of an array. */
|
||||
array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
|
||||
|
||||
/** Move an axis of an array. */
|
||||
array moveaxis(
|
||||
const array& a,
|
||||
int source,
|
||||
int destination,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Pad an array with a constant value */
|
||||
array pad(
|
||||
const array& a,
|
||||
|
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@@ -512,7 +511,26 @@ array Concatenate::jvp(
|
||||
std::pair<array, int> Concatenate::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Concatenate vmap is NYI.");
|
||||
std::vector<array> t_inputs;
|
||||
// Find the first vmapped input
|
||||
int i = 0;
|
||||
for (; i < axes.size(); i++) {
|
||||
t_inputs.push_back(inputs[i]);
|
||||
if (axes[i] >= 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto out_ax = axes[i++];
|
||||
// Move vmap axes to the same spot.
|
||||
for (; i < axes.size(); ++i) {
|
||||
if (out_ax != axes[i] && axes[i] >= 0) {
|
||||
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
|
||||
} else {
|
||||
t_inputs.push_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
auto axis = axis_ + (axis_ >= out_ax);
|
||||
return {concatenate(t_inputs, axis, stream()), out_ax};
|
||||
}
|
||||
|
||||
bool Concatenate::is_equivalent(const Primitive& other) const {
|
||||
@@ -1054,7 +1072,53 @@ std::pair<array, int> Full::vmap(
|
||||
std::pair<array, int> Gather::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Gather vmap is NYI, please change slices instead");
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> indices(inputs.begin() + 1, inputs.end());
|
||||
auto gather_axes = axes_;
|
||||
auto slice_sizes = slice_sizes_;
|
||||
auto src_vmapped = axes[0] >= 0;
|
||||
auto indices_vmapped =
|
||||
std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
|
||||
auto out_ax =
|
||||
*std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; });
|
||||
|
||||
// Reorder all the index arrays so the vmap axis is in the same spot.
|
||||
for (int i = 1; i < axes.size(); ++i) {
|
||||
if (out_ax != axes[i] && axes[i] >= 0) {
|
||||
indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
|
||||
}
|
||||
}
|
||||
|
||||
if (src_vmapped) {
|
||||
int max_dims = 0;
|
||||
for (auto& idx : indices) {
|
||||
max_dims = std::max(static_cast<int>(idx.ndim()), max_dims);
|
||||
}
|
||||
auto new_ax_loc =
|
||||
std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) {
|
||||
return a >= out_ax;
|
||||
});
|
||||
for (; new_ax_loc < gather_axes.end(); new_ax_loc++) {
|
||||
(*new_ax_loc)++;
|
||||
}
|
||||
if (indices_vmapped) {
|
||||
// Make a new index array for the vmapped dimension
|
||||
// Reshape it so it broadcasts with other index arrays
|
||||
// Update gather axes and slice sizes accordingly
|
||||
auto shape = std::vector<int>(max_dims - out_ax, 1);
|
||||
auto vmap_inds = arange(0, src.shape(out_ax), stream());
|
||||
shape[0] = vmap_inds.shape(0);
|
||||
vmap_inds = reshape(vmap_inds, shape, stream());
|
||||
slice_sizes.insert(slice_sizes.begin() + out_ax, 1);
|
||||
auto new_ax_idx = new_ax_loc - gather_axes.begin();
|
||||
gather_axes.insert(new_ax_loc, out_ax);
|
||||
indices.insert(indices.begin() + new_ax_idx, vmap_inds);
|
||||
} else {
|
||||
slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0]));
|
||||
out_ax = max_dims + axes[0];
|
||||
}
|
||||
}
|
||||
return {gather(src, indices, gather_axes, slice_sizes, stream()), out_ax};
|
||||
}
|
||||
|
||||
std::vector<array> Gather::vjp(
|
||||
@@ -1997,8 +2061,15 @@ std::pair<array, int> Sinh::vmap(
|
||||
std::pair<array, int> Slice::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
// TODO implement
|
||||
return {array(1.0f), axes[0]};
|
||||
auto start = start_indices_;
|
||||
auto stop = end_indices_;
|
||||
auto strides = strides_;
|
||||
auto ax = axes[0];
|
||||
auto& input = inputs[0];
|
||||
start.insert(start.begin() + ax, 0);
|
||||
stop.insert(stop.begin() + ax, input.shape(ax));
|
||||
strides.insert(strides.begin() + ax, 1);
|
||||
return {slice(input, start, stop, strides, stream()), ax};
|
||||
}
|
||||
|
||||
std::vector<array> Slice::vjp(
|
||||
|
Reference in New Issue
Block a user