mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 14:58:11 +08:00
@@ -1555,6 +1555,42 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The max of ``a`` and ``b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"floor",
|
||||
&mlx::core::floor,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
floor(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise floor.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The floor of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ceil",
|
||||
&mlx::core::ceil,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
ceil(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise ceil.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The ceil of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"transpose",
|
||||
[](const array& a,
|
||||
|
||||
@@ -334,6 +334,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = [1, -5, 10]
|
||||
self.assertListEqual(mx.maximum(x, y).tolist(), expected)
|
||||
|
||||
def test_floor(self):
|
||||
x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])
|
||||
expected = [-23, 19, -27, 9, 0, -np.inf, np.inf]
|
||||
self.assertListEqual(mx.floor(x).tolist(), expected)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.floor(mx.array([22 + 3j, 19 + 98j]))
|
||||
|
||||
def test_ceil(self):
|
||||
x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])
|
||||
expected = [-22, 20, -27, 9, 0, -np.inf, np.inf]
|
||||
self.assertListEqual(mx.ceil(x).tolist(), expected)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.floor(mx.array([22 + 3j, 19 + 98j]))
|
||||
|
||||
def test_transpose_noargs(self):
|
||||
x = mx.array([[0, 1, 1], [1, 0, 0]])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user