mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Adds round op and primitive (#203)
This commit is contained in:

committed by
GitHub

parent
477397bc98
commit
4d4af12c6f
@@ -1148,5 +1148,15 @@ void init_array(py::module_& m) {
|
||||
"reverse"_a = false,
|
||||
"inclusive"_a = true,
|
||||
"stream"_a = none,
|
||||
"See :func:`cummin`.");
|
||||
"See :func:`cummin`.")
|
||||
.def(
|
||||
"round",
|
||||
[](const array& a, int decimals, StreamOrDevice s) {
|
||||
return round(a, decimals, s);
|
||||
},
|
||||
py::pos_only(),
|
||||
"decimals"_a = 0,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
"See :func:`round`.");
|
||||
}
|
||||
|
@@ -2922,4 +2922,33 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
result (array): The output containing elements selected from ``x`` and ``y``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"round",
|
||||
[](const array& a, int decimals, StreamOrDevice s) {
|
||||
return round(a, decimals, s);
|
||||
},
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
"decimals"_a = 0,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
round(a: array, /, decimals: int = 0, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Round to the given number of decimals.
|
||||
|
||||
Bascially performs:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
s = 10**decimals
|
||||
x = round(x * s) / s
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
decimals (int): Number of decimal places to round to. (default: 0)
|
||||
|
||||
Returns:
|
||||
result (array): An array of the same type as ``a`` rounded to the given number of decimals.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@@ -372,7 +372,35 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(mx.ceil(x).tolist(), expected)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.floor(mx.array([22 + 3j, 19 + 98j]))
|
||||
mx.ceil(mx.array([22 + 3j, 19 + 98j]))
|
||||
|
||||
def test_round(self):
|
||||
# float
|
||||
x = mx.array(
|
||||
[0.5, -0.5, 1.5, -1.5, -22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]
|
||||
)
|
||||
expected = [1, -1, 2, -2, -22, 20, -27, 9, 0, -np.inf, np.inf]
|
||||
self.assertListEqual(mx.round(x).tolist(), expected)
|
||||
|
||||
# complex
|
||||
y = mx.round(mx.array([22.2 + 3.6j, 19.5 + 98.2j]))
|
||||
self.assertListEqual(y.tolist(), [22 + 4j, 20 + 98j])
|
||||
|
||||
# decimals
|
||||
y0 = mx.round(mx.array([15, 122], mx.int32), decimals=0)
|
||||
y1 = mx.round(mx.array([15, 122], mx.int32), decimals=-1)
|
||||
y2 = mx.round(mx.array([15, 122], mx.int32), decimals=-2)
|
||||
self.assertEqual(y0.dtype, mx.int32)
|
||||
self.assertEqual(y1.dtype, mx.int32)
|
||||
self.assertEqual(y2.dtype, mx.int32)
|
||||
self.assertListEqual(y0.tolist(), [15, 122])
|
||||
self.assertListEqual(y1.tolist(), [20, 120])
|
||||
self.assertListEqual(y2.tolist(), [0, 100])
|
||||
|
||||
y1 = mx.round(mx.array([1.537, 1.471], mx.float32), decimals=1)
|
||||
y2 = mx.round(mx.array([1.537, 1.471], mx.float32), decimals=2)
|
||||
self.assertTrue(mx.allclose(y1, mx.array([1.5, 1.5])))
|
||||
self.assertTrue(mx.allclose(y2, mx.array([1.54, 1.47])))
|
||||
|
||||
def test_transpose_noargs(self):
|
||||
x = mx.array([[0, 1, 1], [1, 0, 0]])
|
||||
|
Reference in New Issue
Block a user