Power VJP fix for 0 (#505)

This commit is contained in:
Awni Hannun
2024-01-20 01:17:40 -08:00
committed by GitHub
parent 6bf779e72b
commit b207c2c86b
2 changed files with 19 additions and 1 deletions

View File

@@ -1852,7 +1852,12 @@ std::vector<array> Power::vjp(
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(multiply(
outputs[0], divide(primals[1], primals[0], stream()), stream()));
power(
primals[0],
subtract(primals[1], array(1, primals[0].dtype()), stream()),
stream()),
primals[1],
stream()));
} else {
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
}