mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00
fix isinf for integer types (#494)
This commit is contained in:
parent
550d4bf7c0
commit
c4ec836523
12
mlx/ops.cpp
12
mlx/ops.cpp
@ -1105,18 +1105,30 @@ array array_equal(
|
||||
}
|
||||
|
||||
array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return not_equal(a, a, s);
|
||||
}
|
||||
|
||||
array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
}
|
||||
|
||||
array isposinf(const array& a, StreamOrDevice s) {
|
||||
if (is_integral(a.dtype())) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
}
|
||||
|
||||
array isneginf(const array& a, StreamOrDevice s) {
|
||||
if (is_integral(a.dtype())) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
}
|
||||
|
||||
|
@ -351,6 +351,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertEqual(mx.isinf(0 * mx.array(float("inf"))).tolist(), False)
|
||||
|
||||
x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)
|
||||
result = mx.isinf(x)
|
||||
self.assertEqual(result.tolist(), [False, False, False])
|
||||
|
||||
x = mx.array([-32768, 0, 32767], dtype=mx.int16)
|
||||
result = mx.isinf(x)
|
||||
self.assertEqual(result.tolist(), [False, False, False])
|
||||
|
||||
def test_tri(self):
|
||||
for shape in [[4], [4, 4], [2, 10]]:
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
@ -416,6 +424,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertEqual(mx.isposinf(0 * mx.array(float("inf"))).tolist(), False)
|
||||
|
||||
x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)
|
||||
result = mx.isposinf(x)
|
||||
self.assertEqual(result.tolist(), [False, False, False])
|
||||
|
||||
x = mx.array([-32768, 0, 32767], dtype=mx.int16)
|
||||
result = mx.isposinf(x)
|
||||
self.assertEqual(result.tolist(), [False, False, False])
|
||||
|
||||
def test_isneginf(self):
|
||||
x = mx.array([0.0, float("-inf")])
|
||||
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
||||
@ -431,6 +447,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertEqual(mx.isneginf(0 * mx.array(float("inf"))).tolist(), False)
|
||||
|
||||
x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)
|
||||
result = mx.isneginf(x)
|
||||
self.assertEqual(result.tolist(), [False, False, False])
|
||||
|
||||
x = mx.array([-32768, 0, 32767], dtype=mx.int16)
|
||||
result = mx.isneginf(x)
|
||||
self.assertEqual(result.tolist(), [False, False, False])
|
||||
|
||||
def test_round(self):
|
||||
# float
|
||||
x = mx.array(
|
||||
|
Loading…
Reference in New Issue
Block a user