diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index c2bb59c05..5f2bfdda4 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3548,7 +3548,7 @@ std::vector Reduce::vjp( } else { - throw std::runtime_error("Reduce type VJP not yet implemented."); + return {zeros_like(in, stream())}; } }