make python array SupportsAbs conform (like numpy) (#624)

This commit is contained in:
Daniel Strobusch 2024-02-04 18:31:02 +01:00 committed by GitHub
parent 9852af1a19
commit 4fd2fb84a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 0 deletions

View File

@ -1110,6 +1110,7 @@ void init_array(py::module_& m) {
py::kw_only(),
"stream"_a = none,
"See :func:`abs`.")
.def("__abs__", &mlx::core::abs, "See :func:`abs`.")
.def(
"square",
&square,

View File

@ -696,6 +696,8 @@ class TestOps(mlx_tests.MLXTestCase):
expected = np.abs(a, dtype=np.float32)
self.assertTrue(np.allclose(result, expected))
self.assertTrue(np.allclose(a.abs(), abs(a)))
def test_negative(self):
a = mx.array([-1.0, 1.0, -2.0, 3.0])
result = mx.negative(a)