diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index b2b7306dd..eb5d9d6b3 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -620,10 +620,11 @@ std::vector ArgReduce::vjp( } std::vector ArgReduce::jvp( + const std::vector& primals, const std::vector&, - const std::vector& tangents, const std::vector&) { - return {zeros_like(tangents[0], stream())}; + auto shape = output_shapes(primals)[0]; + return {zeros(shape, uint32, stream())}; } std::pair, std::vector> ArgSort::vmap( @@ -647,6 +648,21 @@ bool ArgSort::is_equivalent(const Primitive& other) const { return axis_ == r_other.axis_; } +std::vector ArgSort::vjp( + const std::vector& primals, + const std::vector&, + const std::vector&, + const std::vector&) { + return {zeros_like(primals[0], stream())}; +} + +std::vector ArgSort::jvp( + const std::vector& primals, + const std::vector&, + const std::vector&) { + return {zeros(primals[0].shape(), uint32, stream())}; +} + std::vector AsType::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index f4f157298..3d3202aaa 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -378,6 +378,7 @@ class ArgSort : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() + DEFINE_GRADS() DEFINE_PRINT(ArgSort) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override;