implemented isposinf and isneginf in one PR (#470)

* ran precommit

* updated docs
This commit is contained in:
Yashraj Singh 2024-01-16 20:18:07 +05:30 committed by GitHub
parent a2ffea683a
commit e72458a3fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 128 additions and 0 deletions

View File

@ -52,6 +52,8 @@ Operations
identity identity
inner inner
isnan isnan
isposinf
isneginf
isinf isinf
less less
less_equal less_equal

View File

@ -1088,6 +1088,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s); return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
} }
array isposinf(const array& a, StreamOrDevice s) {
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array isneginf(const array& a, StreamOrDevice s) {
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array where( array where(
const array& condition, const array& condition,
const array& x, const array& x,

View File

@ -380,6 +380,10 @@ array isnan(const array& a, StreamOrDevice s = {});
array isinf(const array& a, StreamOrDevice s = {}); array isinf(const array& a, StreamOrDevice s = {});
array isposinf(const array& a, StreamOrDevice s = {});
array isneginf(const array& a, StreamOrDevice s = {});
/** Select from x or y depending on condition. */ /** Select from x or y depending on condition. */
array where( array where(
const array& condition, const array& condition,

View File

@ -1856,6 +1856,44 @@ void init_ops(py::module_& m) {
Returns: Returns:
array: The boolean array indicating which elements are +/- infinity. array: The boolean array indicating which elements are +/- infinity.
)pbdoc"); )pbdoc");
m.def(
"isposinf",
&isposinf,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array
Return a boolean array indicating which elements are positive infinity.
Args:
a (array): Input array.
stream (Union[None, Stream, Device]): Optional stream or device.
Returns:
array: The boolean array indicating which elements are positive infinity.
)pbdoc");
m.def(
"isneginf",
&isneginf,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array
Return a boolean array indicating which elements are negative infinity.
Args:
a (array): Input array.
stream (Union[None, Stream, Device]): Optional stream or device.
Returns:
array: The boolean array indicating which elements are negative infinity.
)pbdoc");
m.def( m.def(
"moveaxis", "moveaxis",
&moveaxis, &moveaxis,

View File

@ -401,6 +401,36 @@ class TestOps(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.ceil(mx.array([22 + 3j, 19 + 98j])) mx.ceil(mx.array([22 + 3j, 19 + 98j]))
def test_isposinf(self):
x = mx.array([0.0, float("-inf")])
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
x = mx.array([0.0, float("-inf")]).astype(mx.float16)
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
x = mx.array([0.0, float("-inf")]).astype(mx.bfloat16)
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
x = mx.array([0.0, float("-inf")]).astype(mx.complex64)
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
self.assertEqual(mx.isposinf(0 * mx.array(float("inf"))).tolist(), False)
def test_isneginf(self):
x = mx.array([0.0, float("-inf")])
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
x = mx.array([0.0, float("-inf")]).astype(mx.float16)
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
x = mx.array([0.0, float("-inf")]).astype(mx.bfloat16)
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
x = mx.array([0.0, float("-inf")]).astype(mx.complex64)
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
self.assertEqual(mx.isneginf(0 * mx.array(float("inf"))).tolist(), False)
def test_round(self): def test_round(self):
# float # float
x = mx.array( x = mx.array(

View File

@ -1880,6 +1880,52 @@ TEST_CASE("test scatter") {
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>()); CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
} }
TEST_CASE("test is positive infinity") {
array x(1.0f);
CHECK_FALSE(isposinf(x).item<bool>());
array y(std::numeric_limits<float>::infinity());
CHECK(isposinf(y).item<bool>());
array z = identity(7);
CHECK_FALSE(all(isposinf(z)).item<bool>());
array w = array({1.0f, std::numeric_limits<float>::infinity(), 2.0f});
CHECK_FALSE(all(isposinf(w)).item<bool>());
array a(1.0f, bfloat16);
CHECK_FALSE(isposinf(a).item<bool>());
array b(std::numeric_limits<float>::infinity(), float16);
CHECK(isposinf(b).item<bool>());
array c(std::numeric_limits<float>::infinity(), bfloat16);
CHECK(isposinf(c).item<bool>());
}
TEST_CASE("test is negative infinity") {
array x(1.0f);
CHECK_FALSE(isneginf(x).item<bool>());
array y(-std::numeric_limits<float>::infinity());
CHECK(isneginf(y).item<bool>());
array z = identity(7);
CHECK_FALSE(all(isneginf(z)).item<bool>());
array w = array({1.0f, -std::numeric_limits<float>::infinity(), 2.0f});
CHECK_FALSE(all(isneginf(w)).item<bool>());
array a(1.0f, bfloat16);
CHECK_FALSE(isneginf(a).item<bool>());
array b(-std::numeric_limits<float>::infinity(), float16);
CHECK(isneginf(b).item<bool>());
array c(-std::numeric_limits<float>::infinity(), bfloat16);
CHECK(isneginf(c).item<bool>());
}
TEST_CASE("test scatter types") { TEST_CASE("test scatter types") {
for (auto t : {bool_, uint8, uint16, int8, int16}) { for (auto t : {bool_, uint8, uint16, int8, int16}) {
auto in = zeros({4, 4}, t); auto in = zeros({4, 4}, t);