[Fix] mx.allclose bug with infinite values (#539)

* Added isclose op and fixed comparison with inf values

* Added 'equal_nan' to match numpy

* format

* Add test

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Addressed CR comments

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Rifur13
2024-01-25 23:47:06 -05:00
committed by GitHub
parent 87b7fa9ba2
commit 2463496471
5 changed files with 143 additions and 11 deletions

View File

@@ -514,6 +514,9 @@ TEST_CASE("test is inf") {
array y(inf);
CHECK(isinf(y).item<bool>());
auto neginf = -std::numeric_limits<float>::infinity();
CHECK(isinf(array(neginf)).item<bool>());
array z = identity(7);
CHECK_FALSE(any(isinf(z)).item<bool>());
@@ -545,6 +548,36 @@ TEST_CASE("test all close") {
CHECK(allclose(x, y, 0.01, 0.1).item<bool>());
}
TEST_CASE("test is close") {
{
array a({1.0, std::numeric_limits<float>::infinity()});
array b({1.0, std::numeric_limits<float>::infinity()});
CHECK(array_equal(isclose(a, b), array({true, true})).item<bool>());
}
{
array a({1.0, -std::numeric_limits<float>::infinity()});
array b({1.0, -std::numeric_limits<float>::infinity()});
CHECK(array_equal(isclose(a, b), array({true, true})).item<bool>());
}
{
array a({1.0, std::numeric_limits<float>::infinity()});
array b({1.0, -std::numeric_limits<float>::infinity()});
CHECK(array_equal(isclose(a, b), array({true, false})).item<bool>());
}
{
array a({1.0, std::nan("1"), std::nan("1")});
array b({1.0, std::nan("1"), 2.0});
CHECK(array_equal(isclose(a, b), array({true, false, false})).item<bool>());
}
{
array a({1.0, std::nan("1"), std::nan("1")});
array b({1.0, std::nan("1"), 2.0});
CHECK(
array_equal(isclose(a, b, 1e-5, 1e-8, true), array({true, true, false}))
.item<bool>());
}
}
TEST_CASE("test reduction ops") {
// Check shapes and throws correctly
{