mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 23:24:41 +08:00
Update cpp tests with allclose and doctest::Approx for numerical tolerance (#401)
This commit is contained in:
@@ -929,7 +929,7 @@ TEST_CASE("test arithmetic unary ops") {
|
||||
|
||||
// Input is irregularly strided
|
||||
x = broadcast_to(array(1.0f), {2, 2, 2});
|
||||
CHECK(array_equal(exp(x), full({2, 2, 2}, std::exp(1.0f))).item<bool>());
|
||||
CHECK(allclose(exp(x), full({2, 2, 2}, std::exp(1.0f))).item<bool>());
|
||||
|
||||
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
|
||||
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
|
||||
@@ -2016,7 +2016,7 @@ TEST_CASE("test power") {
|
||||
CHECK_EQ((array(false) ^ array(true)).item<bool>(), false);
|
||||
|
||||
auto x = array(2.0f);
|
||||
CHECK_EQ((x ^ 0.5).item<float>(), std::pow(2.0f, 0.5f));
|
||||
CHECK_EQ((x ^ 0.5).item<float>(), doctest::Approx(std::pow(2.0f, 0.5f)));
|
||||
CHECK_EQ((x ^ 2.0f).item<float>(), 4.0f);
|
||||
|
||||
CHECK(std::isnan((array(-1.0f) ^ 0.5).item<float>()));
|
||||
|
Reference in New Issue
Block a user