mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-28 22:28:11 +08:00
Fix some complex vjps (#2178)
This commit is contained in:
committed by
GitHub
parent
130df35e1b
commit
cf6c939e86
@@ -1133,26 +1133,48 @@ TEST_CASE("test complex gradients") {
|
||||
}
|
||||
|
||||
{
|
||||
auto multiply_fn =
|
||||
[](const std::vector<array>& inputs) -> std::vector<array> {
|
||||
return {multiply(inputs[0], inputs[1])};
|
||||
};
|
||||
|
||||
// Compute jvp
|
||||
auto x = array(complex64_t{2.0, 4.0});
|
||||
auto y = array(3.0f);
|
||||
|
||||
auto x_tan = array(complex64_t{1.0, 2.0});
|
||||
auto y_tan = array(2.0f);
|
||||
auto jvp_out = jvp(multiply_fn, {x, y}, {x_tan, y_tan}).second;
|
||||
CHECK_EQ(jvp_out[0].item<complex64_t>(), complex64_t{7.0, 14.0});
|
||||
|
||||
auto out = jvp([x](array a) { return multiply(a, x); }, y, y_tan).second;
|
||||
CHECK_EQ(out.item<complex64_t>(), complex64_t{4.0, 8.0});
|
||||
|
||||
out = jvp([y](array a) { return multiply(a, y); }, x, x_tan).second;
|
||||
CHECK_EQ(out.item<complex64_t>(), complex64_t{3.0, 6.0});
|
||||
|
||||
// Compute vjp
|
||||
auto cotan = array(complex64_t{2.0, 3.0});
|
||||
out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second;
|
||||
CHECK_EQ(out.dtype(), float32);
|
||||
CHECK_EQ(out.item<float>(), -8.0);
|
||||
auto vjp_out = vjp(multiply_fn, {x, y}, {cotan}).second;
|
||||
CHECK_EQ(vjp_out[0].dtype(), complex64);
|
||||
CHECK_EQ(vjp_out[0].item<complex64_t>(), complex64_t{6.0, 9.0});
|
||||
CHECK_EQ(vjp_out[1].dtype(), float32);
|
||||
CHECK_EQ(vjp_out[1].item<float>(), 16);
|
||||
}
|
||||
|
||||
out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second;
|
||||
CHECK_EQ(out.item<complex64_t>(), complex64_t{6.0, 9.0});
|
||||
{
|
||||
auto divide_fn =
|
||||
[](const std::vector<array>& inputs) -> std::vector<array> {
|
||||
return {divide(inputs[0], inputs[1])};
|
||||
};
|
||||
|
||||
// Compute jvp
|
||||
auto x = array(complex64_t{2.0, 3.0});
|
||||
auto y = array(complex64_t{1.0, 2.0});
|
||||
auto x_tan = array(complex64_t{3.0, 4.0});
|
||||
auto y_tan = array(complex64_t{4.0, -2.0});
|
||||
auto jvp_out = jvp(divide_fn, {x, y}, {x_tan, y_tan}).second;
|
||||
CHECK_EQ(
|
||||
jvp_out[0].item<complex64_t>(), doctest::Approx(complex64_t{2.6, 2.8}));
|
||||
|
||||
// Compute vjp
|
||||
auto cotan = array(complex64_t{2.0, -4.0});
|
||||
auto vjp_out = vjp(divide_fn, {x, y}, {cotan}).second;
|
||||
CHECK_EQ(vjp_out[0].item<complex64_t>(), complex64_t{2.0, 0.0});
|
||||
CHECK_EQ(vjp_out[1].item<complex64_t>(), complex64_t{-3.2, -0.4});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -243,7 +243,7 @@ TEST_CASE("test fft grads") {
|
||||
auto fft_fn = [](array x) { return fft::fft(x); };
|
||||
auto cotangent = astype(arange(10), complex64);
|
||||
auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second;
|
||||
CHECK(array_equal(fft::fft(cotangent), vjp_out).item<bool>());
|
||||
CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item<bool>());
|
||||
|
||||
auto tangent = astype(arange(10), complex64);
|
||||
auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second;
|
||||
@@ -252,7 +252,7 @@ TEST_CASE("test fft grads") {
|
||||
// Inverse
|
||||
auto ifft_fn = [](array x) { return fft::ifft(x); };
|
||||
vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second;
|
||||
CHECK(array_equal(fft::ifft(cotangent), vjp_out).item<bool>());
|
||||
CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item<bool>());
|
||||
|
||||
jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second;
|
||||
CHECK(array_equal(fft::ifft(tangent), jvp_out).item<bool>());
|
||||
@@ -261,7 +261,8 @@ TEST_CASE("test fft grads") {
|
||||
auto rfft_fn = [](array x) { return fft::rfft(x); };
|
||||
cotangent = astype(arange(6), complex64);
|
||||
vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second;
|
||||
auto expected = astype(fft::fft(cotangent, 10, 0), float32);
|
||||
array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64);
|
||||
auto expected = fft::irfft(cotangent * mask, 10, 0) * 10;
|
||||
CHECK(array_equal(expected, vjp_out).item<bool>());
|
||||
|
||||
tangent = astype(arange(10), float32);
|
||||
@@ -272,12 +273,9 @@ TEST_CASE("test fft grads") {
|
||||
auto irfft_fn = [](array x) { return fft::irfft(x); };
|
||||
cotangent = astype(arange(10), float32);
|
||||
vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second;
|
||||
expected = fft::fft(cotangent, 10, 0);
|
||||
auto o_splits = split(vjp_out, {1, 5});
|
||||
auto e_splits = split(expected, {1, 5, 6});
|
||||
CHECK_EQ(e_splits[0].item<complex64_t>(), o_splits[0].item<complex64_t>());
|
||||
CHECK(array_equal(2 * e_splits[1], o_splits[1]).item<bool>());
|
||||
CHECK_EQ(e_splits[2].item<complex64_t>(), o_splits[2].item<complex64_t>());
|
||||
mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32);
|
||||
expected = fft::rfft(cotangent) * mask;
|
||||
CHECK(array_equal(expected, vjp_out).item<bool>());
|
||||
|
||||
tangent = astype(arange(10), complex64);
|
||||
jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;
|
||||
|
||||
Reference in New Issue
Block a user