From 23417cee8e8495c9fcf222b80a75b323cf5a552e Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 12 May 2025 15:19:54 -0700 Subject: [PATCH] Fix complex multiplication derivatives --- mlx/primitives.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 03ca06bdd..906767c30 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1490,12 +1490,14 @@ std::vector Divide::vjp( std::vector vjps; for (auto arg : argnums) { if (arg == 0) { - vjps.push_back(divide(cotangents[0], primals[1], stream())); + vjps.push_back( + divide(cotangents[0], conjugate(primals[1], stream()), stream())); } else { vjps.push_back(negative( divide( - multiply(cotangents[0], primals[0], stream()), - square(primals[1], stream()), + multiply( + cotangents[0], conjugate(primals[0], stream()), stream()), + square(conjugate(primals[1], stream()), stream()), stream()), stream())); } @@ -2776,7 +2778,8 @@ std::vector Multiply::vjp( const std::vector&) { std::vector vjps; for (auto arg : argnums) { - vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream())); + vjps.push_back(multiply( + conjugate(primals[1 - arg], stream()), cotangents[0], stream())); } return vjps; }