implemented isposinf and isneginf in one PR (#470)

* ran precommit

* updated docs
This commit is contained in:
Yashraj Singh
2024-01-16 20:18:07 +05:30
committed by GitHub
parent a2ffea683a
commit e72458a3fa
6 changed files with 128 additions and 0 deletions

View File

@@ -1856,6 +1856,44 @@ void init_ops(py::module_& m) {
Returns:
array: The boolean array indicating which elements are +/- infinity.
)pbdoc");
m.def(
"isposinf",
&isposinf,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array
Return a boolean array indicating which elements are positive infinity.
Args:
a (array): Input array.
stream (Union[None, Stream, Device]): Optional stream or device.
Returns:
array: The boolean array indicating which elements are positive infinity.
)pbdoc");
m.def(
"isneginf",
&isneginf,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array
Return a boolean array indicating which elements are negative infinity.
Args:
a (array): Input array.
stream (Union[None, Stream, Device]): Optional stream or device.
Returns:
array: The boolean array indicating which elements are negative infinity.
)pbdoc");
m.def(
"moveaxis",
&moveaxis,

View File

@@ -401,6 +401,36 @@ class TestOps(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
mx.ceil(mx.array([22 + 3j, 19 + 98j]))
def test_isposinf(self):
x = mx.array([0.0, float("-inf")])
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
x = mx.array([0.0, float("-inf")]).astype(mx.float16)
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
x = mx.array([0.0, float("-inf")]).astype(mx.bfloat16)
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
x = mx.array([0.0, float("-inf")]).astype(mx.complex64)
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
self.assertEqual(mx.isposinf(0 * mx.array(float("inf"))).tolist(), False)
def test_isneginf(self):
x = mx.array([0.0, float("-inf")])
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
x = mx.array([0.0, float("-inf")]).astype(mx.float16)
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
x = mx.array([0.0, float("-inf")]).astype(mx.bfloat16)
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
x = mx.array([0.0, float("-inf")]).astype(mx.complex64)
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
self.assertEqual(mx.isneginf(0 * mx.array(float("inf"))).tolist(), False)
def test_round(self):
# float
x = mx.array(