mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix complex multiplication derivatives
This commit is contained in:
parent
8f3d208dce
commit
23417cee8e
@ -1490,12 +1490,14 @@ std::vector<array> Divide::vjp(
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
vjps.push_back(divide(cotangents[0], primals[1], stream()));
|
||||
vjps.push_back(
|
||||
divide(cotangents[0], conjugate(primals[1], stream()), stream()));
|
||||
} else {
|
||||
vjps.push_back(negative(
|
||||
divide(
|
||||
multiply(cotangents[0], primals[0], stream()),
|
||||
square(primals[1], stream()),
|
||||
multiply(
|
||||
cotangents[0], conjugate(primals[0], stream()), stream()),
|
||||
square(conjugate(primals[1], stream()), stream()),
|
||||
stream()),
|
||||
stream()));
|
||||
}
|
||||
@ -2776,7 +2778,8 @@ std::vector<array> Multiply::vjp(
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream()));
|
||||
vjps.push_back(multiply(
|
||||
conjugate(primals[1 - arg], stream()), cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user