From 2acf2e003eadb3c1da9a3b707e9aa359369582fd Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 14 May 2025 21:53:01 -0700 Subject: [PATCH] Comments --- mlx/primitives.cpp | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index c36a61686..e1924e66c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1488,16 +1488,16 @@ std::vector Divide::vjp( const std::vector& argnums, const std::vector&) { std::vector vjps; + array denominator_bar = conjugate(primals[1], stream()); for (auto arg : argnums) { if (arg == 0) { - vjps.push_back( - divide(cotangents[0], conjugate(primals[1], stream()), stream())); + vjps.push_back(divide(cotangents[0], denominator_bar, stream())); } else { vjps.push_back(negative( divide( multiply( cotangents[0], conjugate(primals[0], stream()), stream()), - square(conjugate(primals[1], stream()), stream()), + square(denominator_bar, stream()), stream()), stream())); } @@ -1972,8 +1972,14 @@ std::vector FFT::vjp( array last(N - 1 + odd, indices.dtype()); array one(1 / n_elements, in.dtype()); array two(2 / n_elements, in.dtype()); - array mask = - where((first < indices) & (indices < last), two, one, stream()); + array mask = where( + logical_and( + greater(indices, first, stream()), + less(indices, last, stream()), + stream()), + two, + one, + stream()); return { multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())}; } else if (real_) { @@ -1992,8 +1998,14 @@ std::vector FFT::vjp( array last(N - 1 + odd, indices.dtype()); array one(1, complex64); array half(0.5, complex64); - array mask = - where((first < indices) & (indices < last), half, one, stream()); + array mask = where( + logical_and( + greater(indices, first, stream()), + less(indices, last, stream()), + stream()), + half, + one, + stream()); return {multiply( fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()), array(n_elements, in.dtype()),