mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
implemented isposinf and isneginf in one PR (#470)
* ran precommit * updated docs
This commit is contained in:
parent
a2ffea683a
commit
e72458a3fa
@ -52,6 +52,8 @@ Operations
|
|||||||
identity
|
identity
|
||||||
inner
|
inner
|
||||||
isnan
|
isnan
|
||||||
|
isposinf
|
||||||
|
isneginf
|
||||||
isinf
|
isinf
|
||||||
less
|
less
|
||||||
less_equal
|
less_equal
|
||||||
|
@ -1088,6 +1088,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array isposinf(const array& a, StreamOrDevice s) {
|
||||||
|
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array isneginf(const array& a, StreamOrDevice s) {
|
||||||
|
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||||
|
}
|
||||||
|
|
||||||
array where(
|
array where(
|
||||||
const array& condition,
|
const array& condition,
|
||||||
const array& x,
|
const array& x,
|
||||||
|
@ -380,6 +380,10 @@ array isnan(const array& a, StreamOrDevice s = {});
|
|||||||
|
|
||||||
array isinf(const array& a, StreamOrDevice s = {});
|
array isinf(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
array isposinf(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
array isneginf(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Select from x or y depending on condition. */
|
/** Select from x or y depending on condition. */
|
||||||
array where(
|
array where(
|
||||||
const array& condition,
|
const array& condition,
|
||||||
|
@ -1856,6 +1856,44 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The boolean array indicating which elements are +/- infinity.
|
array: The boolean array indicating which elements are +/- infinity.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"isposinf",
|
||||||
|
&isposinf,
|
||||||
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Return a boolean array indicating which elements are positive infinity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
stream (Union[None, Stream, Device]): Optional stream or device.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The boolean array indicating which elements are positive infinity.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"isneginf",
|
||||||
|
&isneginf,
|
||||||
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Return a boolean array indicating which elements are negative infinity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
stream (Union[None, Stream, Device]): Optional stream or device.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The boolean array indicating which elements are negative infinity.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"moveaxis",
|
"moveaxis",
|
||||||
&moveaxis,
|
&moveaxis,
|
||||||
|
@ -401,6 +401,36 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.ceil(mx.array([22 + 3j, 19 + 98j]))
|
mx.ceil(mx.array([22 + 3j, 19 + 98j]))
|
||||||
|
|
||||||
|
def test_isposinf(self):
|
||||||
|
x = mx.array([0.0, float("-inf")])
|
||||||
|
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
|
||||||
|
|
||||||
|
x = mx.array([0.0, float("-inf")]).astype(mx.float16)
|
||||||
|
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
|
||||||
|
|
||||||
|
x = mx.array([0.0, float("-inf")]).astype(mx.bfloat16)
|
||||||
|
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
|
||||||
|
|
||||||
|
x = mx.array([0.0, float("-inf")]).astype(mx.complex64)
|
||||||
|
self.assertEqual(mx.isposinf(x).tolist(), [False, False])
|
||||||
|
|
||||||
|
self.assertEqual(mx.isposinf(0 * mx.array(float("inf"))).tolist(), False)
|
||||||
|
|
||||||
|
def test_isneginf(self):
|
||||||
|
x = mx.array([0.0, float("-inf")])
|
||||||
|
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
||||||
|
|
||||||
|
x = mx.array([0.0, float("-inf")]).astype(mx.float16)
|
||||||
|
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
||||||
|
|
||||||
|
x = mx.array([0.0, float("-inf")]).astype(mx.bfloat16)
|
||||||
|
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
||||||
|
|
||||||
|
x = mx.array([0.0, float("-inf")]).astype(mx.complex64)
|
||||||
|
self.assertEqual(mx.isneginf(x).tolist(), [False, True])
|
||||||
|
|
||||||
|
self.assertEqual(mx.isneginf(0 * mx.array(float("inf"))).tolist(), False)
|
||||||
|
|
||||||
def test_round(self):
|
def test_round(self):
|
||||||
# float
|
# float
|
||||||
x = mx.array(
|
x = mx.array(
|
||||||
|
@ -1880,6 +1880,52 @@ TEST_CASE("test scatter") {
|
|||||||
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
|
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test is positive infinity") {
|
||||||
|
array x(1.0f);
|
||||||
|
CHECK_FALSE(isposinf(x).item<bool>());
|
||||||
|
|
||||||
|
array y(std::numeric_limits<float>::infinity());
|
||||||
|
CHECK(isposinf(y).item<bool>());
|
||||||
|
|
||||||
|
array z = identity(7);
|
||||||
|
CHECK_FALSE(all(isposinf(z)).item<bool>());
|
||||||
|
|
||||||
|
array w = array({1.0f, std::numeric_limits<float>::infinity(), 2.0f});
|
||||||
|
CHECK_FALSE(all(isposinf(w)).item<bool>());
|
||||||
|
|
||||||
|
array a(1.0f, bfloat16);
|
||||||
|
CHECK_FALSE(isposinf(a).item<bool>());
|
||||||
|
|
||||||
|
array b(std::numeric_limits<float>::infinity(), float16);
|
||||||
|
CHECK(isposinf(b).item<bool>());
|
||||||
|
|
||||||
|
array c(std::numeric_limits<float>::infinity(), bfloat16);
|
||||||
|
CHECK(isposinf(c).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test is negative infinity") {
|
||||||
|
array x(1.0f);
|
||||||
|
CHECK_FALSE(isneginf(x).item<bool>());
|
||||||
|
|
||||||
|
array y(-std::numeric_limits<float>::infinity());
|
||||||
|
CHECK(isneginf(y).item<bool>());
|
||||||
|
|
||||||
|
array z = identity(7);
|
||||||
|
CHECK_FALSE(all(isneginf(z)).item<bool>());
|
||||||
|
|
||||||
|
array w = array({1.0f, -std::numeric_limits<float>::infinity(), 2.0f});
|
||||||
|
CHECK_FALSE(all(isneginf(w)).item<bool>());
|
||||||
|
|
||||||
|
array a(1.0f, bfloat16);
|
||||||
|
CHECK_FALSE(isneginf(a).item<bool>());
|
||||||
|
|
||||||
|
array b(-std::numeric_limits<float>::infinity(), float16);
|
||||||
|
CHECK(isneginf(b).item<bool>());
|
||||||
|
|
||||||
|
array c(-std::numeric_limits<float>::infinity(), bfloat16);
|
||||||
|
CHECK(isneginf(c).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("test scatter types") {
|
TEST_CASE("test scatter types") {
|
||||||
for (auto t : {bool_, uint8, uint16, int8, int16}) {
|
for (auto t : {bool_, uint8, uint16, int8, int16}) {
|
||||||
auto in = zeros({4, 4}, t);
|
auto in = zeros({4, 4}, t);
|
||||||
|
Loading…
Reference in New Issue
Block a user