Floor and Ceil (#150)

* Implements Floor and Ceil Ops
This commit is contained in:
Luca Arnaboldi
2023-12-14 19:00:23 +01:00
committed by GitHub
parent 1e0c78b970
commit b93c4cf378
14 changed files with 250 additions and 4 deletions

View File

@@ -1555,6 +1555,42 @@ void init_ops(py::module_& m) {
Returns:
array: The max of ``a`` and ``b``.
)pbdoc");
m.def(
"floor",
&mlx::core::floor,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
floor(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise floor.
Args:
a (array): Input array.
Returns:
array: The floor of ``a``.
)pbdoc");
m.def(
"ceil",
&mlx::core::ceil,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
ceil(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise ceil.
Args:
a (array): Input array.
Returns:
array: The ceil of ``a``.
)pbdoc");
m.def(
"transpose",
[](const array& a,

View File

@@ -334,6 +334,22 @@ class TestOps(mlx_tests.MLXTestCase):
expected = [1, -5, 10]
self.assertListEqual(mx.maximum(x, y).tolist(), expected)
def test_floor(self):
x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])
expected = [-23, 19, -27, 9, 0, -np.inf, np.inf]
self.assertListEqual(mx.floor(x).tolist(), expected)
with self.assertRaises(ValueError):
mx.floor(mx.array([22 + 3j, 19 + 98j]))
def test_ceil(self):
x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])
expected = [-22, 20, -27, 9, 0, -np.inf, np.inf]
self.assertListEqual(mx.ceil(x).tolist(), expected)
with self.assertRaises(ValueError):
mx.floor(mx.array([22 + 3j, 19 + 98j]))
def test_transpose_noargs(self):
x = mx.array([[0, 1, 1], [1, 0, 0]])