mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-15 15:38:10 +08:00
Comments
This commit is contained in:
@@ -1488,16 +1488,16 @@ std::vector<array> Divide::vjp(
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
array denominator_bar = conjugate(primals[1], stream());
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
vjps.push_back(
|
||||
divide(cotangents[0], conjugate(primals[1], stream()), stream()));
|
||||
vjps.push_back(divide(cotangents[0], denominator_bar, stream()));
|
||||
} else {
|
||||
vjps.push_back(negative(
|
||||
divide(
|
||||
multiply(
|
||||
cotangents[0], conjugate(primals[0], stream()), stream()),
|
||||
square(conjugate(primals[1], stream()), stream()),
|
||||
square(denominator_bar, stream()),
|
||||
stream()),
|
||||
stream()));
|
||||
}
|
||||
@@ -1972,8 +1972,14 @@ std::vector<array> FFT::vjp(
|
||||
array last(N - 1 + odd, indices.dtype());
|
||||
array one(1 / n_elements, in.dtype());
|
||||
array two(2 / n_elements, in.dtype());
|
||||
array mask =
|
||||
where((first < indices) & (indices < last), two, one, stream());
|
||||
array mask = where(
|
||||
logical_and(
|
||||
greater(indices, first, stream()),
|
||||
less(indices, last, stream()),
|
||||
stream()),
|
||||
two,
|
||||
one,
|
||||
stream());
|
||||
return {
|
||||
multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
|
||||
} else if (real_) {
|
||||
@@ -1992,8 +1998,14 @@ std::vector<array> FFT::vjp(
|
||||
array last(N - 1 + odd, indices.dtype());
|
||||
array one(1, complex64);
|
||||
array half(0.5, complex64);
|
||||
array mask =
|
||||
where((first < indices) & (indices < last), half, one, stream());
|
||||
array mask = where(
|
||||
logical_and(
|
||||
greater(indices, first, stream()),
|
||||
less(indices, last, stream()),
|
||||
stream()),
|
||||
half,
|
||||
one,
|
||||
stream());
|
||||
return {multiply(
|
||||
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
|
||||
array(n_elements, in.dtype()),
|
||||
|
Reference in New Issue
Block a user