add zero for argsort vjp (#2345)

This commit is contained in:
Awni Hannun
2025-07-09 14:37:14 -07:00
committed by GitHub
parent 8b9a3f3cea
commit e14ee12491
2 changed files with 19 additions and 2 deletions

View File

@@ -378,6 +378,7 @@ class ArgSort : public UnaryPrimitive {
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(ArgSort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;