Fix complex multiplication derivatives

This commit is contained in:
Angelos Katharopoulos 2025-05-12 15:19:54 -07:00
parent 8f3d208dce
commit 23417cee8e

View File

@ -1490,12 +1490,14 @@ std::vector<array> Divide::vjp(
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(divide(cotangents[0], primals[1], stream()));
vjps.push_back(
divide(cotangents[0], conjugate(primals[1], stream()), stream()));
} else {
vjps.push_back(negative(
divide(
multiply(cotangents[0], primals[0], stream()),
square(primals[1], stream()),
multiply(
cotangents[0], conjugate(primals[0], stream()), stream()),
square(conjugate(primals[1], stream()), stream()),
stream()),
stream()));
}
@ -2776,7 +2778,8 @@ std::vector<array> Multiply::vjp(
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream()));
vjps.push_back(multiply(
conjugate(primals[1 - arg], stream()), cotangents[0], stream()));
}
return vjps;
}