Fix complex multiplication derivatives

This commit is contained in:
Angelos Katharopoulos 2025-05-12 15:19:54 -07:00
parent 8f3d208dce
commit 23417cee8e

View File

@ -1490,12 +1490,14 @@ std::vector<array> Divide::vjp(
std::vector<array> vjps; std::vector<array> vjps;
for (auto arg : argnums) { for (auto arg : argnums) {
if (arg == 0) { if (arg == 0) {
vjps.push_back(divide(cotangents[0], primals[1], stream())); vjps.push_back(
divide(cotangents[0], conjugate(primals[1], stream()), stream()));
} else { } else {
vjps.push_back(negative( vjps.push_back(negative(
divide( divide(
multiply(cotangents[0], primals[0], stream()), multiply(
square(primals[1], stream()), cotangents[0], conjugate(primals[0], stream()), stream()),
square(conjugate(primals[1], stream()), stream()),
stream()), stream()),
stream())); stream()));
} }
@ -2776,7 +2778,8 @@ std::vector<array> Multiply::vjp(
const std::vector<array>&) { const std::vector<array>&) {
std::vector<array> vjps; std::vector<array> vjps;
for (auto arg : argnums) { for (auto arg : argnums) {
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream())); vjps.push_back(multiply(
conjugate(primals[1 - arg], stream()), cotangents[0], stream()));
} }
return vjps; return vjps;
} }