mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Fix complex multiplication derivatives
This commit is contained in:
parent
8f3d208dce
commit
23417cee8e
@ -1490,12 +1490,14 @@ std::vector<array> Divide::vjp(
|
|||||||
std::vector<array> vjps;
|
std::vector<array> vjps;
|
||||||
for (auto arg : argnums) {
|
for (auto arg : argnums) {
|
||||||
if (arg == 0) {
|
if (arg == 0) {
|
||||||
vjps.push_back(divide(cotangents[0], primals[1], stream()));
|
vjps.push_back(
|
||||||
|
divide(cotangents[0], conjugate(primals[1], stream()), stream()));
|
||||||
} else {
|
} else {
|
||||||
vjps.push_back(negative(
|
vjps.push_back(negative(
|
||||||
divide(
|
divide(
|
||||||
multiply(cotangents[0], primals[0], stream()),
|
multiply(
|
||||||
square(primals[1], stream()),
|
cotangents[0], conjugate(primals[0], stream()), stream()),
|
||||||
|
square(conjugate(primals[1], stream()), stream()),
|
||||||
stream()),
|
stream()),
|
||||||
stream()));
|
stream()));
|
||||||
}
|
}
|
||||||
@ -2776,7 +2778,8 @@ std::vector<array> Multiply::vjp(
|
|||||||
const std::vector<array>&) {
|
const std::vector<array>&) {
|
||||||
std::vector<array> vjps;
|
std::vector<array> vjps;
|
||||||
for (auto arg : argnums) {
|
for (auto arg : argnums) {
|
||||||
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream()));
|
vjps.push_back(multiply(
|
||||||
|
conjugate(primals[1 - arg], stream()), cotangents[0], stream()));
|
||||||
}
|
}
|
||||||
return vjps;
|
return vjps;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user