mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Update cpp tests with allclose and doctest::Approx for numerical tolerance (#401)
This commit is contained in:
@@ -238,11 +238,11 @@ TEST_CASE("test grad") {
|
||||
auto x = array(1.);
|
||||
auto expfn = [](array input) { return exp(input); };
|
||||
auto dfdx = grad(expfn);
|
||||
CHECK_EQ(dfdx(x).item<float>(), std::exp(1.0f));
|
||||
CHECK_EQ(dfdx(x).item<float>(), doctest::Approx(std::exp(1.0f)));
|
||||
auto d2fdx2 = grad(grad(expfn));
|
||||
CHECK_EQ(d2fdx2(x).item<float>(), std::exp(1.0f));
|
||||
CHECK_EQ(d2fdx2(x).item<float>(), doctest::Approx(std::exp(1.0f)));
|
||||
auto d3fdx3 = grad(grad(grad(expfn)));
|
||||
CHECK_EQ(d3fdx3(x).item<float>(), std::exp(1.0f));
|
||||
CHECK_EQ(d3fdx3(x).item<float>(), doctest::Approx(std::exp(1.0f)));
|
||||
}
|
||||
|
||||
{
|
||||
@@ -393,7 +393,7 @@ TEST_CASE("test op vjps") {
|
||||
// Test exp
|
||||
{
|
||||
auto out = vjp([](array in) { return exp(in); }, array(1.0f), array(2.0f));
|
||||
CHECK_EQ(out.second.item<float>(), 2.0f * std::exp(1.0f));
|
||||
CHECK_EQ(out.second.item<float>(), doctest::Approx(2.0f * std::exp(1.0f)));
|
||||
}
|
||||
|
||||
// Test sin
|
||||
|
Reference in New Issue
Block a user