mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
Fix some complex vjps (#2178)
This commit is contained in:

committed by
GitHub

parent
130df35e1b
commit
cf6c939e86
@@ -1488,14 +1488,16 @@ std::vector<array> Divide::vjp(
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
array denominator_bar = conjugate(primals[1], stream());
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
vjps.push_back(divide(cotangents[0], primals[1], stream()));
|
||||
vjps.push_back(divide(cotangents[0], denominator_bar, stream()));
|
||||
} else {
|
||||
vjps.push_back(negative(
|
||||
divide(
|
||||
multiply(cotangents[0], primals[0], stream()),
|
||||
square(primals[1], stream()),
|
||||
multiply(
|
||||
cotangents[0], conjugate(primals[0], stream()), stream()),
|
||||
square(denominator_bar, stream()),
|
||||
stream()),
|
||||
stream()));
|
||||
}
|
||||
@@ -1950,30 +1952,74 @@ std::vector<array> FFT::vjp(
|
||||
assert(argnums.size() == 1);
|
||||
auto& in = primals[0];
|
||||
std::vector<int> axes(axes_.begin(), axes_.end());
|
||||
|
||||
// TODO: Add it as an option to do an unnormalized or scaled fft so that this
|
||||
// isn't part of the graph.
|
||||
double n_elements = 1;
|
||||
for (auto ax : axes) {
|
||||
n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax);
|
||||
}
|
||||
|
||||
if (real_ && inverse_) {
|
||||
auto out = fft::fftn(cotangents[0], axes, stream());
|
||||
auto start = Shape(out.ndim(), 0);
|
||||
auto stop = in.shape();
|
||||
out = slice(out, start, stop, stream());
|
||||
auto mask_shape = out.shape();
|
||||
mask_shape[axes_.back()] -= 2;
|
||||
auto mask = full(mask_shape, 2.0f, stream());
|
||||
auto pad_shape = out.shape();
|
||||
pad_shape[axes_.back()] = 1;
|
||||
auto pad = full(pad_shape, 1.0f, stream());
|
||||
mask = concatenate({pad, mask, pad}, axes_.back(), stream());
|
||||
return {multiply(mask, out, stream())};
|
||||
// Make a mask to account for the double use in the forward pass.
|
||||
// Everything except the DC and Nyquist frequencies gets doubled.
|
||||
int N = in.shape(axes_.back());
|
||||
bool odd = cotangents[0].shape(axes_.back()) % 2;
|
||||
Shape c(in.ndim(), 1);
|
||||
c[axes_.back()] = N;
|
||||
array indices = reshape(arange(N, stream()), std::move(c), stream());
|
||||
array first(0, indices.dtype());
|
||||
array last(N - 1 + odd, indices.dtype());
|
||||
array one(1 / n_elements, in.dtype());
|
||||
array two(2 / n_elements, in.dtype());
|
||||
array mask = where(
|
||||
logical_and(
|
||||
greater(indices, first, stream()),
|
||||
less(indices, last, stream()),
|
||||
stream()),
|
||||
two,
|
||||
one,
|
||||
stream());
|
||||
return {
|
||||
multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
|
||||
} else if (real_) {
|
||||
Shape n;
|
||||
for (auto ax : axes_) {
|
||||
n.push_back(in.shape()[ax]);
|
||||
n.push_back(in.shape(ax));
|
||||
}
|
||||
return {astype(
|
||||
fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())};
|
||||
// Make a mask to account for the double use in the forward pass.
|
||||
// Everything except the DC and Nyquist frequencies gets halved.
|
||||
int N = cotangents[0].shape(axes_.back());
|
||||
bool odd = in.shape(axes_.back()) % 2;
|
||||
Shape c(in.ndim(), 1);
|
||||
c[axes_.back()] = N;
|
||||
array indices = reshape(arange(N, stream()), std::move(c), stream());
|
||||
array first(0, indices.dtype());
|
||||
array last(N - 1 + odd, indices.dtype());
|
||||
array one(1, complex64);
|
||||
array half(0.5, complex64);
|
||||
array mask = where(
|
||||
logical_and(
|
||||
greater(indices, first, stream()),
|
||||
less(indices, last, stream()),
|
||||
stream()),
|
||||
half,
|
||||
one,
|
||||
stream());
|
||||
return {multiply(
|
||||
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
|
||||
array(n_elements, in.dtype()),
|
||||
stream())};
|
||||
} else if (inverse_) {
|
||||
return {fft::ifftn(cotangents[0], axes, stream())};
|
||||
return {multiply(
|
||||
fft::fftn(cotangents[0], axes, stream()),
|
||||
array(1 / n_elements, complex64),
|
||||
stream())};
|
||||
} else {
|
||||
return {fft::fftn(cotangents[0], axes, stream())};
|
||||
return {multiply(
|
||||
fft::ifftn(cotangents[0], axes, stream()),
|
||||
array(n_elements, complex64),
|
||||
stream())};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2776,7 +2822,8 @@ std::vector<array> Multiply::vjp(
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream()));
|
||||
vjps.push_back(multiply(
|
||||
conjugate(primals[1 - arg], stream()), cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
Reference in New Issue
Block a user