mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-09 13:33:56 +08:00
[Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values * Added 'equal_nan' to match numpy * format * Add test * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Addressed CR comments * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * nits --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -514,6 +514,9 @@ TEST_CASE("test is inf") {
|
||||
array y(inf);
|
||||
CHECK(isinf(y).item<bool>());
|
||||
|
||||
auto neginf = -std::numeric_limits<float>::infinity();
|
||||
CHECK(isinf(array(neginf)).item<bool>());
|
||||
|
||||
array z = identity(7);
|
||||
CHECK_FALSE(any(isinf(z)).item<bool>());
|
||||
|
||||
@@ -545,6 +548,36 @@ TEST_CASE("test all close") {
|
||||
CHECK(allclose(x, y, 0.01, 0.1).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test is close") {
|
||||
{
|
||||
array a({1.0, std::numeric_limits<float>::infinity()});
|
||||
array b({1.0, std::numeric_limits<float>::infinity()});
|
||||
CHECK(array_equal(isclose(a, b), array({true, true})).item<bool>());
|
||||
}
|
||||
{
|
||||
array a({1.0, -std::numeric_limits<float>::infinity()});
|
||||
array b({1.0, -std::numeric_limits<float>::infinity()});
|
||||
CHECK(array_equal(isclose(a, b), array({true, true})).item<bool>());
|
||||
}
|
||||
{
|
||||
array a({1.0, std::numeric_limits<float>::infinity()});
|
||||
array b({1.0, -std::numeric_limits<float>::infinity()});
|
||||
CHECK(array_equal(isclose(a, b), array({true, false})).item<bool>());
|
||||
}
|
||||
{
|
||||
array a({1.0, std::nan("1"), std::nan("1")});
|
||||
array b({1.0, std::nan("1"), 2.0});
|
||||
CHECK(array_equal(isclose(a, b), array({true, false, false})).item<bool>());
|
||||
}
|
||||
{
|
||||
array a({1.0, std::nan("1"), std::nan("1")});
|
||||
array b({1.0, std::nan("1"), 2.0});
|
||||
CHECK(
|
||||
array_equal(isclose(a, b, 1e-5, 1e-8, true), array({true, true, false}))
|
||||
.item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test reduction ops") {
|
||||
// Check shapes and throws correctly
|
||||
{
|
||||
|
Reference in New Issue
Block a user