Put along axis + fixe for partition grad (#1430)

* put along axis, fixes for partition grad

* zeros for arg reduce
This commit is contained in:
Awni Hannun
2024-09-23 10:03:38 -07:00
committed by GitHub
parent 2b878e9dd7
commit 195b429d99
9 changed files with 220 additions and 9 deletions

View File

@@ -2767,6 +2767,53 @@ array take_along_axis(
return reshape(out, out_shape, s);
}
array put_along_axis(
const array& a,
const array& indices,
const array& values,
int axis,
StreamOrDevice s /* = {} */) {
if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
std::ostringstream msg;
msg << "[put_along_axis] Received invalid axis " << " for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (indices.ndim() != a.ndim()) {
std::ostringstream msg;
msg << "[put_along_axis] Indices of dimension " << indices.ndim()
<< " does not match array of dimension " << a.ndim() << ".";
throw std::invalid_argument(msg.str());
}
// Allow negative axis
axis = axis < 0 ? a.ndim() + axis : axis;
std::vector<array> nd_indices;
std::vector<int> index_shape(a.ndim(), 1);
for (int i = 0; i < a.ndim(); ++i) {
if (i == axis) {
nd_indices.push_back(indices);
} else {
// Reshape so they can be broadcast
index_shape[i] = a.shape(i);
nd_indices.push_back(reshape(arange(a.shape(i), s), index_shape, s));
index_shape[i] = 1;
}
}
auto update = astype(broadcast_to(values, indices.shape(), s), a.dtype(), s);
{
auto update_shape = update.shape();
update_shape.resize(update_shape.size() + a.ndim(), 1);
update = reshape(update, std::move(update_shape), s);
}
std::vector<int> dims(a.ndim());
std::iota(dims.begin(), dims.end(), 0);
return scatter(a, nd_indices, update, dims, s);
}
/** Scatter updates to given indices */
array scatter(
const array& a,
@@ -2853,7 +2900,6 @@ array scatter(
}
inputs.insert(inputs.begin(), a);
// TODO promote or cast?
inputs.push_back(astype(updates, a.dtype(), s));
return array(

View File

@@ -947,6 +947,14 @@ array take_along_axis(
int axis,
StreamOrDevice s = {});
/** Put the values into the array at the given indices along the axis */
array put_along_axis(
const array& a,
const array& indices,
const array& values,
int axis,
StreamOrDevice s = {});
/** Scatter updates to the given indices.
*
* The parameters ``indices`` and ``axes`` determine the locations of ``a``

View File

@@ -471,6 +471,21 @@ std::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(
return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};
}
std::vector<array> ArgPartition::vjp(
const std::vector<array>& primals,
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
return {zeros_like(primals[0], stream())};
}
std::vector<array> ArgPartition::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {zeros_like(tangents[0], stream())};
}
bool ArgPartition::is_equivalent(const Primitive& other) const {
const ArgPartition& r_other = static_cast<const ArgPartition&>(other);
return axis_ == r_other.axis_ && kth_ == r_other.kth_;
@@ -495,6 +510,21 @@ std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
return {out, axes};
}
std::vector<array> ArgReduce::vjp(
const std::vector<array>& primals,
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
return {zeros_like(primals[0], stream())};
}
std::vector<array> ArgReduce::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {zeros_like(tangents[0], stream())};
}
std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -2336,7 +2366,13 @@ std::vector<array> Partition::vjp(
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
auto sort_idx = argpartition(primals[0], kth_, axis_, stream());
return {put_along_axis(
zeros_like(primals[0], stream()),
sort_idx,
cotangents[0],
axis_,
stream())};
}
std::vector<array> Partition::jvp(

View File

@@ -357,6 +357,7 @@ class ArgPartition : public UnaryPrimitive {
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(ArgPartition)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
@@ -382,6 +383,7 @@ class ArgReduce : public UnaryPrimitive {
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(ArgReduce)
bool is_equivalent(const Primitive& other) const override;
std::vector<std::vector<int>> output_shapes(