Add more tests

This commit is contained in:
Angelos Katharopoulos 2025-05-14 23:13:07 -07:00
parent 2acf2e003e
commit cf8766e71d

View File

@ -1133,26 +1133,48 @@ TEST_CASE("test complex gradients") {
}
{
auto multiply_fn =
[](const std::vector<array>& inputs) -> std::vector<array> {
return {multiply(inputs[0], inputs[1])};
};
// Compute jvp
auto x = array(complex64_t{2.0, 4.0});
auto y = array(3.0f);
auto x_tan = array(complex64_t{1.0, 2.0});
auto y_tan = array(2.0f);
auto jvp_out = jvp(multiply_fn, {x, y}, {x_tan, y_tan}).second;
CHECK_EQ(jvp_out[0].item<complex64_t>(), complex64_t{7.0, 14.0});
auto out = jvp([x](array a) { return multiply(a, x); }, y, y_tan).second;
CHECK_EQ(out.item<complex64_t>(), complex64_t{4.0, 8.0});
out = jvp([y](array a) { return multiply(a, y); }, x, x_tan).second;
CHECK_EQ(out.item<complex64_t>(), complex64_t{3.0, 6.0});
// Compute vjp
auto cotan = array(complex64_t{2.0, 3.0});
out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second;
CHECK_EQ(out.dtype(), float32);
CHECK_EQ(out.item<float>(), 16.0);
auto vjp_out = vjp(multiply_fn, {x, y}, {cotan}).second;
CHECK_EQ(vjp_out[0].dtype(), complex64);
CHECK_EQ(vjp_out[0].item<complex64_t>(), complex64_t{6.0, 9.0});
CHECK_EQ(vjp_out[1].dtype(), float32);
CHECK_EQ(vjp_out[1].item<float>(), 16);
}
out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second;
CHECK_EQ(out.item<complex64_t>(), complex64_t{6.0, 9.0});
{
auto divide_fn =
[](const std::vector<array>& inputs) -> std::vector<array> {
return {divide(inputs[0], inputs[1])};
};
// Compute jvp
auto x = array(complex64_t{2.0, 3.0});
auto y = array(complex64_t{1.0, 2.0});
auto x_tan = array(complex64_t{3.0, 4.0});
auto y_tan = array(complex64_t{4.0, -2.0});
auto jvp_out = jvp(divide_fn, {x, y}, {x_tan, y_tan}).second;
CHECK_EQ(
jvp_out[0].item<complex64_t>(), doctest::Approx(complex64_t{2.6, 2.8}));
// Compute vjp
auto cotan = array(complex64_t{2.0, -4.0});
auto vjp_out = vjp(divide_fn, {x, y}, {cotan}).second;
CHECK_EQ(vjp_out[0].item<complex64_t>(), complex64_t{2.0, 0.0});
CHECK_EQ(vjp_out[1].item<complex64_t>(), complex64_t{-3.2, -0.4});
}
}