diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index 8b7126714..5b3454bfc 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -1133,26 +1133,48 @@ TEST_CASE("test complex gradients") { } { + auto multiply_fn = + [](const std::vector& inputs) -> std::vector { + return {multiply(inputs[0], inputs[1])}; + }; + // Compute jvp auto x = array(complex64_t{2.0, 4.0}); auto y = array(3.0f); - auto x_tan = array(complex64_t{1.0, 2.0}); auto y_tan = array(2.0f); + auto jvp_out = jvp(multiply_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ(jvp_out[0].item(), complex64_t{7.0, 14.0}); - auto out = jvp([x](array a) { return multiply(a, x); }, y, y_tan).second; - CHECK_EQ(out.item(), complex64_t{4.0, 8.0}); - - out = jvp([y](array a) { return multiply(a, y); }, x, x_tan).second; - CHECK_EQ(out.item(), complex64_t{3.0, 6.0}); - + // Compute vjp auto cotan = array(complex64_t{2.0, 3.0}); - out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second; - CHECK_EQ(out.dtype(), float32); - CHECK_EQ(out.item(), 16.0); + auto vjp_out = vjp(multiply_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].dtype(), complex64); + CHECK_EQ(vjp_out[0].item(), complex64_t{6.0, 9.0}); + CHECK_EQ(vjp_out[1].dtype(), float32); + CHECK_EQ(vjp_out[1].item(), 16); + } - out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second; - CHECK_EQ(out.item(), complex64_t{6.0, 9.0}); + { + auto divide_fn = + [](const std::vector& inputs) -> std::vector { + return {divide(inputs[0], inputs[1])}; + }; + + // Compute jvp + auto x = array(complex64_t{2.0, 3.0}); + auto y = array(complex64_t{1.0, 2.0}); + auto x_tan = array(complex64_t{3.0, 4.0}); + auto y_tan = array(complex64_t{4.0, -2.0}); + auto jvp_out = jvp(divide_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ( + jvp_out[0].item(), doctest::Approx(complex64_t{2.6, 2.8})); + + // Compute vjp + auto cotan = array(complex64_t{2.0, -4.0}); + auto vjp_out = vjp(divide_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].item(), complex64_t{2.0, 0.0}); + CHECK_EQ(vjp_out[1].item(), complex64_t{-3.2, -0.4}); } }