mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Comments
This commit is contained in:
@@ -1488,16 +1488,16 @@ std::vector<array> Divide::vjp(
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>&) {
|
const std::vector<array>&) {
|
||||||
std::vector<array> vjps;
|
std::vector<array> vjps;
|
||||||
|
array denominator_bar = conjugate(primals[1], stream());
|
||||||
for (auto arg : argnums) {
|
for (auto arg : argnums) {
|
||||||
if (arg == 0) {
|
if (arg == 0) {
|
||||||
vjps.push_back(
|
vjps.push_back(divide(cotangents[0], denominator_bar, stream()));
|
||||||
divide(cotangents[0], conjugate(primals[1], stream()), stream()));
|
|
||||||
} else {
|
} else {
|
||||||
vjps.push_back(negative(
|
vjps.push_back(negative(
|
||||||
divide(
|
divide(
|
||||||
multiply(
|
multiply(
|
||||||
cotangents[0], conjugate(primals[0], stream()), stream()),
|
cotangents[0], conjugate(primals[0], stream()), stream()),
|
||||||
square(conjugate(primals[1], stream()), stream()),
|
square(denominator_bar, stream()),
|
||||||
stream()),
|
stream()),
|
||||||
stream()));
|
stream()));
|
||||||
}
|
}
|
||||||
@@ -1972,8 +1972,14 @@ std::vector<array> FFT::vjp(
|
|||||||
array last(N - 1 + odd, indices.dtype());
|
array last(N - 1 + odd, indices.dtype());
|
||||||
array one(1 / n_elements, in.dtype());
|
array one(1 / n_elements, in.dtype());
|
||||||
array two(2 / n_elements, in.dtype());
|
array two(2 / n_elements, in.dtype());
|
||||||
array mask =
|
array mask = where(
|
||||||
where((first < indices) & (indices < last), two, one, stream());
|
logical_and(
|
||||||
|
greater(indices, first, stream()),
|
||||||
|
less(indices, last, stream()),
|
||||||
|
stream()),
|
||||||
|
two,
|
||||||
|
one,
|
||||||
|
stream());
|
||||||
return {
|
return {
|
||||||
multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
|
multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
|
||||||
} else if (real_) {
|
} else if (real_) {
|
||||||
@@ -1992,8 +1998,14 @@ std::vector<array> FFT::vjp(
|
|||||||
array last(N - 1 + odd, indices.dtype());
|
array last(N - 1 + odd, indices.dtype());
|
||||||
array one(1, complex64);
|
array one(1, complex64);
|
||||||
array half(0.5, complex64);
|
array half(0.5, complex64);
|
||||||
array mask =
|
array mask = where(
|
||||||
where((first < indices) & (indices < last), half, one, stream());
|
logical_and(
|
||||||
|
greater(indices, first, stream()),
|
||||||
|
less(indices, last, stream()),
|
||||||
|
stream()),
|
||||||
|
half,
|
||||||
|
one,
|
||||||
|
stream());
|
||||||
return {multiply(
|
return {multiply(
|
||||||
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
|
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
|
||||||
array(n_elements, in.dtype()),
|
array(n_elements, in.dtype()),
|
||||||
|
|||||||
Reference in New Issue
Block a user