diff --git a/python/src/array.cpp b/python/src/array.cpp index 7a1083c7f..ddf27e68f 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1110,6 +1110,7 @@ void init_array(py::module_& m) { py::kw_only(), "stream"_a = none, "See :func:`abs`.") + .def("__abs__", &mlx::core::abs, "See :func:`abs`.") .def( "square", &square, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 64152b537..2d5366a6d 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -696,6 +696,8 @@ class TestOps(mlx_tests.MLXTestCase): expected = np.abs(a, dtype=np.float32) self.assertTrue(np.allclose(result, expected)) + self.assertTrue(np.allclose(a.abs(), abs(a))) + def test_negative(self): a = mx.array([-1.0, 1.0, -2.0, 3.0]) result = mx.negative(a)