add zero for argsort vjp (#2345)

This commit is contained in:
Awni Hannun 2025-07-09 14:37:14 -07:00 committed by GitHub
parent 8b9a3f3cea
commit e14ee12491
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 2 deletions

View File

@ -620,10 +620,11 @@ std::vector<array> ArgReduce::vjp(
}
std::vector<array> ArgReduce::jvp(
const std::vector<array>& primals,
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {zeros_like(tangents[0], stream())};
auto shape = output_shapes(primals)[0];
return {zeros(shape, uint32, stream())};
}
std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
@ -647,6 +648,21 @@ bool ArgSort::is_equivalent(const Primitive& other) const {
return axis_ == r_other.axis_;
}
std::vector<array> ArgSort::vjp(
const std::vector<array>& primals,
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
return {zeros_like(primals[0], stream())};
}
std::vector<array> ArgSort::jvp(
const std::vector<array>& primals,
const std::vector<array>&,
const std::vector<int>&) {
return {zeros(primals[0].shape(), uint32, stream())};
}
std::vector<array> AsType::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -378,6 +378,7 @@ class ArgSort : public UnaryPrimitive {
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(ArgSort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;