mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 14:59:22 +08:00
fix isinf for integer types (#494)
This commit is contained in:
parent
550d4bf7c0
commit
c4ec836523
12
mlx/ops.cpp
12
mlx/ops.cpp
@ -1105,18 +1105,30 @@ array array_equal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
if (is_integral(a.dtype())) {
|
||||||
|
return full(a.shape(), false, bool_, s);
|
||||||
|
}
|
||||||
return not_equal(a, a, s);
|
return not_equal(a, a, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
if (is_integral(a.dtype())) {
|
||||||
|
return full(a.shape(), false, bool_, s);
|
||||||
|
}
|
||||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array isposinf(const array& a, StreamOrDevice s) {
|
array isposinf(const array& a, StreamOrDevice s) {
|
||||||
|
if (is_integral(a.dtype())) {
|
||||||
|
return full(a.shape(), false, bool_, s);
|
||||||
|
}
|
||||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array isneginf(const array& a, StreamOrDevice s) {
|
array isneginf(const array& a, StreamOrDevice s) {
|
||||||
|
if (is_integral(a.dtype())) {
|
||||||
|
return full(a.shape(), false, bool_, s);
|
||||||
|
}
|
||||||
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
|
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -351,6 +351,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(mx.isinf(0 * mx.array(float("inf"))).tolist(), False)
|
self.assertEqual(mx.isinf(0 * mx.array(float("inf"))).tolist(), False)
|
||||||
|
|
||||||
|
x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)
|
||||||
|
result = mx.isinf(x)
|
||||||
|
self.assertEqual(result.tolist(), [False, False, False])
|
||||||
|
|
||||||
|
x = mx.array([-32768, 0, 32767], dtype=mx.int16)
|
||||||
|
result = mx.isinf(x)
|
||||||
|
self.assertEqual(result.tolist(), [False, False, False])
|
||||||
|
|
||||||
def test_tri(self):
|
def test_tri(self):
|
||||||
for shape in [[4], [4, 4], [2, 10]]:
|
for shape in [[4], [4, 4], [2, 10]]:
|
||||||
for diag in [-1, 0, 1, -2]:
|
for diag in [-1, 0, 1, -2]:
|
||||||
@ -416,6 +424,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(mx.isposinf(0 * mx.array(float("inf"))).tolist(), False)
|
self.assertEqual(mx.isposinf(0 * mx.array(float("inf"))).tolist(), False)
|
||||||
|
|
||||||
|
x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)
|
||||||
|
result = mx.isposinf(x)
|
||||||
|
self.assertEqual(result.tolist(), [False, False, False])
|
||||||
|
|
||||||
|
x = mx.array([-32768, 0, 32767], dtype=mx.int16)
|
||||||
|
result = mx.isposinf(x)
|
||||||
|
self.assertEqual(result.tolist(), [False, False, False])
|
||||||
|
|
||||||
def test_isneginf(self):
|
def test_isneginf(self):
|
||||||
x = mx.array([0.0, float("-inf")])
|
x = mx.array([0.0, float("-inf")])
|
||||||
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
||||||
@ -431,6 +447,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(mx.isneginf(0 * mx.array(float("inf"))).tolist(), False)
|
self.assertEqual(mx.isneginf(0 * mx.array(float("inf"))).tolist(), False)
|
||||||
|
|
||||||
|
x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)
|
||||||
|
result = mx.isneginf(x)
|
||||||
|
self.assertEqual(result.tolist(), [False, False, False])
|
||||||
|
|
||||||
|
x = mx.array([-32768, 0, 32767], dtype=mx.int16)
|
||||||
|
result = mx.isneginf(x)
|
||||||
|
self.assertEqual(result.tolist(), [False, False, False])
|
||||||
|
|
||||||
def test_round(self):
|
def test_round(self):
|
||||||
# float
|
# float
|
||||||
x = mx.array(
|
x = mx.array(
|
||||||
|
Loading…
Reference in New Issue
Block a user