mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Power VJP fix for 0 (#505)
This commit is contained in:
@@ -1852,7 +1852,12 @@ std::vector<array> Power::vjp(
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
vjps.push_back(multiply(
|
||||
outputs[0], divide(primals[1], primals[0], stream()), stream()));
|
||||
power(
|
||||
primals[0],
|
||||
subtract(primals[1], array(1, primals[0].dtype()), stream()),
|
||||
stream()),
|
||||
primals[1],
|
||||
stream()));
|
||||
} else {
|
||||
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
|
||||
}
|
||||
|
Reference in New Issue
Block a user