mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 06:38:38 +08:00
implemented isposinf and isneginf in one PR (#470)
* ran precommit * updated docs
This commit is contained in:
@@ -1856,6 +1856,44 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are +/- infinity.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"isposinf",
|
||||
&isposinf,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return a boolean array indicating which elements are positive infinity.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
stream (Union[None, Stream, Device]): Optional stream or device.
|
||||
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are positive infinity.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"isneginf",
|
||||
&isneginf,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return a boolean array indicating which elements are negative infinity.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
stream (Union[None, Stream, Device]): Optional stream or device.
|
||||
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are negative infinity.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"moveaxis",
|
||||
&moveaxis,
|
||||
|
@@ -401,6 +401,36 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
mx.ceil(mx.array([22 + 3j, 19 + 98j]))
|
||||
|
||||
def test_isposinf(self):
|
||||
x = mx.array([0.0, float("-inf")])
|
||||
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
|
||||
|
||||
x = mx.array([0.0, float("-inf")]).astype(mx.float16)
|
||||
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
|
||||
|
||||
x = mx.array([0.0, float("-inf")]).astype(mx.bfloat16)
|
||||
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
|
||||
|
||||
x = mx.array([0.0, float("-inf")]).astype(mx.complex64)
|
||||
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
|
||||
|
||||
self.assertEqual(mx.isposinf(0 * mx.array(float("inf"))).tolist(), False)
|
||||
|
||||
def test_isneginf(self):
|
||||
x = mx.array([0.0, float("-inf")])
|
||||
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
||||
|
||||
x = mx.array([0.0, float("-inf")]).astype(mx.float16)
|
||||
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
||||
|
||||
x = mx.array([0.0, float("-inf")]).astype(mx.bfloat16)
|
||||
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
||||
|
||||
x = mx.array([0.0, float("-inf")]).astype(mx.complex64)
|
||||
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
||||
|
||||
self.assertEqual(mx.isneginf(0 * mx.array(float("inf"))).tolist(), False)
|
||||
|
||||
def test_round(self):
|
||||
# float
|
||||
x = mx.array(
|
||||
|
Reference in New Issue
Block a user