From c4ec836523a9ad3e0dcb95f04824b70e543c2058 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 19 Jan 2024 05:31:10 -0800 Subject: [PATCH] fix isinf for integer types (#494) --- mlx/ops.cpp | 12 ++++++++++++ python/tests/test_ops.py | 24 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index a67e2f220..7cd79dae4 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1105,18 +1105,30 @@ array array_equal( } array isnan(const array& a, StreamOrDevice s /* = {} */) { + if (is_integral(a.dtype())) { + return full(a.shape(), false, bool_, s); + } return not_equal(a, a, s); } array isinf(const array& a, StreamOrDevice s /* = {} */) { + if (is_integral(a.dtype())) { + return full(a.shape(), false, bool_, s); + } return equal(a, array(std::numeric_limits::infinity(), a.dtype()), s); } array isposinf(const array& a, StreamOrDevice s) { + if (is_integral(a.dtype())) { + return full(a.shape(), false, bool_, s); + } return equal(a, array(std::numeric_limits::infinity(), a.dtype()), s); } array isneginf(const array& a, StreamOrDevice s) { + if (is_integral(a.dtype())) { + return full(a.shape(), false, bool_, s); + } return equal(a, array(-std::numeric_limits::infinity(), a.dtype()), s); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 3206f1dcb..542f1540e 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -351,6 +351,14 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(mx.isinf(0 * mx.array(float("inf"))).tolist(), False) + x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32) + result = mx.isinf(x) + self.assertEqual(result.tolist(), [False, False, False]) + + x = mx.array([-32768, 0, 32767], dtype=mx.int16) + result = mx.isinf(x) + self.assertEqual(result.tolist(), [False, False, False]) + def test_tri(self): for shape in [[4], [4, 4], [2, 10]]: for diag in [-1, 0, 1, -2]: @@ -416,6 +424,14 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(mx.isposinf(0 * mx.array(float("inf"))).tolist(), False) + x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32) + result = mx.isposinf(x) + self.assertEqual(result.tolist(), [False, False, False]) + + x = mx.array([-32768, 0, 32767], dtype=mx.int16) + result = mx.isposinf(x) + self.assertEqual(result.tolist(), [False, False, False]) + def test_isneginf(self): x = mx.array([0.0, float("-inf")]) self.assertEqual(mx.isneginf(x).tolist(), [False, True]) @@ -431,6 +447,14 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(mx.isneginf(0 * mx.array(float("inf"))).tolist(), False) + x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32) + result = mx.isneginf(x) + self.assertEqual(result.tolist(), [False, False, False]) + + x = mx.array([-32768, 0, 32767], dtype=mx.int16) + result = mx.isneginf(x) + self.assertEqual(result.tolist(), [False, False, False]) + def test_round(self): # float x = mx.array(