fix isinf for integer types (#494)

This commit is contained in:
Awni Hannun 2024-01-19 05:31:10 -08:00 committed by GitHub
parent 550d4bf7c0
commit c4ec836523
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 0 deletions

View File

@ -1105,18 +1105,30 @@ array array_equal(
}
array isnan(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
return full(a.shape(), false, bool_, s);
}
return not_equal(a, a, s);
}
array isinf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array isposinf(const array& a, StreamOrDevice s) {
if (is_integral(a.dtype())) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array isneginf(const array& a, StreamOrDevice s) {
if (is_integral(a.dtype())) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
}

View File

@ -351,6 +351,14 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.isinf(0 * mx.array(float("inf"))).tolist(), False)
x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)
result = mx.isinf(x)
self.assertEqual(result.tolist(), [False, False, False])
x = mx.array([-32768, 0, 32767], dtype=mx.int16)
result = mx.isinf(x)
self.assertEqual(result.tolist(), [False, False, False])
def test_tri(self):
for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]:
@ -416,6 +424,14 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.isposinf(0 * mx.array(float("inf"))).tolist(), False)
x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)
result = mx.isposinf(x)
self.assertEqual(result.tolist(), [False, False, False])
x = mx.array([-32768, 0, 32767], dtype=mx.int16)
result = mx.isposinf(x)
self.assertEqual(result.tolist(), [False, False, False])
def test_isneginf(self):
x = mx.array([0.0, float("-inf")])
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
@ -431,6 +447,14 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.isneginf(0 * mx.array(float("inf"))).tolist(), False)
x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)
result = mx.isneginf(x)
self.assertEqual(result.tolist(), [False, False, False])
x = mx.array([-32768, 0, 32767], dtype=mx.int16)
result = mx.isneginf(x)
self.assertEqual(result.tolist(), [False, False, False])
def test_round(self):
# float
x = mx.array(