diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index 554726363..5e4100c75 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -238,11 +238,11 @@ TEST_CASE("test grad") { auto x = array(1.); auto expfn = [](array input) { return exp(input); }; auto dfdx = grad(expfn); - CHECK_EQ(dfdx(x).item(), std::exp(1.0f)); + CHECK_EQ(dfdx(x).item(), doctest::Approx(std::exp(1.0f))); auto d2fdx2 = grad(grad(expfn)); - CHECK_EQ(d2fdx2(x).item(), std::exp(1.0f)); + CHECK_EQ(d2fdx2(x).item(), doctest::Approx(std::exp(1.0f))); auto d3fdx3 = grad(grad(grad(expfn))); - CHECK_EQ(d3fdx3(x).item(), std::exp(1.0f)); + CHECK_EQ(d3fdx3(x).item(), doctest::Approx(std::exp(1.0f))); } { @@ -393,7 +393,7 @@ TEST_CASE("test op vjps") { // Test exp { auto out = vjp([](array in) { return exp(in); }, array(1.0f), array(2.0f)); - CHECK_EQ(out.second.item(), 2.0f * std::exp(1.0f)); + CHECK_EQ(out.second.item(), doctest::Approx(2.0f * std::exp(1.0f))); } // Test sin diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index d1cd9552b..0521d9c25 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -929,7 +929,7 @@ TEST_CASE("test arithmetic unary ops") { // Input is irregularly strided x = broadcast_to(array(1.0f), {2, 2, 2}); - CHECK(array_equal(exp(x), full({2, 2, 2}, std::exp(1.0f))).item()); + CHECK(allclose(exp(x), full({2, 2, 2}, std::exp(1.0f))).item()); x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0]; auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1}); @@ -2016,7 +2016,7 @@ TEST_CASE("test power") { CHECK_EQ((array(false) ^ array(true)).item(), false); auto x = array(2.0f); - CHECK_EQ((x ^ 0.5).item(), std::pow(2.0f, 0.5f)); + CHECK_EQ((x ^ 0.5).item(), doctest::Approx(std::pow(2.0f, 0.5f))); CHECK_EQ((x ^ 2.0f).item(), 4.0f); CHECK(std::isnan((array(-1.0f) ^ 0.5).item()));