Put along axis + fixe for partition grad (#1430)

* put along axis, fixes for partition grad

* zeros for arg reduce
This commit is contained in:
Awni Hannun
2024-09-23 10:03:38 -07:00
committed by GitHub
parent 2b878e9dd7
commit 195b429d99
9 changed files with 220 additions and 9 deletions

View File

@@ -471,6 +471,21 @@ std::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(
return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};
}
std::vector<array> ArgPartition::vjp(
const std::vector<array>& primals,
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
return {zeros_like(primals[0], stream())};
}
std::vector<array> ArgPartition::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {zeros_like(tangents[0], stream())};
}
bool ArgPartition::is_equivalent(const Primitive& other) const {
const ArgPartition& r_other = static_cast<const ArgPartition&>(other);
return axis_ == r_other.axis_ && kth_ == r_other.kth_;
@@ -495,6 +510,21 @@ std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
return {out, axes};
}
std::vector<array> ArgReduce::vjp(
const std::vector<array>& primals,
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
return {zeros_like(primals[0], stream())};
}
std::vector<array> ArgReduce::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {zeros_like(tangents[0], stream())};
}
std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -2336,7 +2366,13 @@ std::vector<array> Partition::vjp(
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
auto sort_idx = argpartition(primals[0], kth_, axis_, stream());
return {put_along_axis(
zeros_like(primals[0], stream()),
sort_idx,
cotangents[0],
axis_,
stream())};
}
std::vector<array> Partition::jvp(