Update cpp tests with allclose and doctest::Approx for numerical tolerance (#401)

This commit is contained in:
Jagrit Digani 2024-01-08 09:35:05 -08:00 committed by GitHub
parent 73321b8097
commit 432ee5650b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 6 deletions

View File

@ -238,11 +238,11 @@ TEST_CASE("test grad") {
auto x = array(1.); auto x = array(1.);
auto expfn = [](array input) { return exp(input); }; auto expfn = [](array input) { return exp(input); };
auto dfdx = grad(expfn); auto dfdx = grad(expfn);
CHECK_EQ(dfdx(x).item<float>(), std::exp(1.0f)); CHECK_EQ(dfdx(x).item<float>(), doctest::Approx(std::exp(1.0f)));
auto d2fdx2 = grad(grad(expfn)); auto d2fdx2 = grad(grad(expfn));
CHECK_EQ(d2fdx2(x).item<float>(), std::exp(1.0f)); CHECK_EQ(d2fdx2(x).item<float>(), doctest::Approx(std::exp(1.0f)));
auto d3fdx3 = grad(grad(grad(expfn))); auto d3fdx3 = grad(grad(grad(expfn)));
CHECK_EQ(d3fdx3(x).item<float>(), std::exp(1.0f)); CHECK_EQ(d3fdx3(x).item<float>(), doctest::Approx(std::exp(1.0f)));
} }
{ {
@ -393,7 +393,7 @@ TEST_CASE("test op vjps") {
// Test exp // Test exp
{ {
auto out = vjp([](array in) { return exp(in); }, array(1.0f), array(2.0f)); auto out = vjp([](array in) { return exp(in); }, array(1.0f), array(2.0f));
CHECK_EQ(out.second.item<float>(), 2.0f * std::exp(1.0f)); CHECK_EQ(out.second.item<float>(), doctest::Approx(2.0f * std::exp(1.0f)));
} }
// Test sin // Test sin

View File

@ -929,7 +929,7 @@ TEST_CASE("test arithmetic unary ops") {
// Input is irregularly strided // Input is irregularly strided
x = broadcast_to(array(1.0f), {2, 2, 2}); x = broadcast_to(array(1.0f), {2, 2, 2});
CHECK(array_equal(exp(x), full({2, 2, 2}, std::exp(1.0f))).item<bool>()); CHECK(allclose(exp(x), full({2, 2, 2}, std::exp(1.0f))).item<bool>());
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0]; x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1}); auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
@ -2016,7 +2016,7 @@ TEST_CASE("test power") {
CHECK_EQ((array(false) ^ array(true)).item<bool>(), false); CHECK_EQ((array(false) ^ array(true)).item<bool>(), false);
auto x = array(2.0f); auto x = array(2.0f);
CHECK_EQ((x ^ 0.5).item<float>(), std::pow(2.0f, 0.5f)); CHECK_EQ((x ^ 0.5).item<float>(), doctest::Approx(std::pow(2.0f, 0.5f)));
CHECK_EQ((x ^ 2.0f).item<float>(), 4.0f); CHECK_EQ((x ^ 2.0f).item<float>(), 4.0f);
CHECK(std::isnan((array(-1.0f) ^ 0.5).item<float>())); CHECK(std::isnan((array(-1.0f) ^ 0.5).item<float>()));