From 1d90a76d638bb11daf525e4680b38898176706e1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 9 Jan 2024 16:05:38 -0800 Subject: [PATCH] in place ops behave in place, fix some overloads (#411) --- python/src/array.cpp | 162 ++++++++++++++++++++++++++++++++----- python/tests/test_array.py | 59 ++++++++++++++ 2 files changed, 202 insertions(+), 19 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index bf2b09a3c..7a30612f0 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -711,6 +711,13 @@ void init_array(py::module_& m) { return add(a, to_array(v, a.dtype())); }, "other"_a) + .def( + "__iadd__", + [](array& a, const ScalarOrArray v) { + a.overwrite_descriptor(add(a, to_array(v, a.dtype()))); + return a; + }, + "other"_a) .def( "__radd__", [](const array& a, const ScalarOrArray v) { @@ -723,6 +730,13 @@ void init_array(py::module_& m) { return subtract(a, to_array(v, a.dtype())); }, "other"_a) + .def( + "__isub__", + [](array& a, const ScalarOrArray v) { + a.overwrite_descriptor(subtract(a, to_array(v, a.dtype()))); + return a; + }, + "other"_a) .def( "__rsub__", [](const array& a, const ScalarOrArray v) { @@ -735,6 +749,13 @@ void init_array(py::module_& m) { return multiply(a, to_array(v, a.dtype())); }, "other"_a) + .def( + "__imul__", + [](array& a, const ScalarOrArray v) { + a.overwrite_descriptor(multiply(a, to_array(v, a.dtype()))); + return a; + }, + "other"_a) .def( "__rmul__", [](const array& a, const ScalarOrArray v) { @@ -748,16 +769,14 @@ void init_array(py::module_& m) { }, "other"_a) .def( - "__div__", - [](const array& a, const ScalarOrArray v) { - return divide(a, to_array(v, a.dtype())); - }, - "other"_a) - .def( - "__floordiv__", - [](const array& a, const ScalarOrArray v) { - auto b = to_array(v, a.dtype()); - return floor_divide(a, b); + "__itruediv__", + [](array& a, const ScalarOrArray v) { + if (!is_floating_point(a.dtype())) { + throw std::invalid_argument( + "In place division cannot cast to non-floating point type."); + } + a.overwrite_descriptor(divide(a, to_array(v, a.dtype()))); + return a; }, "other"_a) .def( @@ -766,6 +785,31 @@ void init_array(py::module_& m) { return divide(to_array(v, a.dtype()), a); }, "other"_a) + .def( + "__div__", + [](const array& a, const ScalarOrArray v) { + return divide(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__rdiv__", + [](const array& a, const ScalarOrArray v) { + return divide(to_array(v, a.dtype()), a); + }, + "other"_a) + .def( + "__floordiv__", + [](const array& a, const ScalarOrArray v) { + return floor_divide(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__ifloordiv__", + [](array& a, const ScalarOrArray v) { + a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype()))); + return a; + }, + "other"_a) .def( "__rfloordiv__", [](const array& a, const ScalarOrArray v) { @@ -773,18 +817,19 @@ void init_array(py::module_& m) { return floor_divide(b, a); }, "other"_a) - .def( - "__rdiv__", - [](const array& a, const ScalarOrArray v) { - return divide(to_array(v, a.dtype()), a); - }, - "other"_a) .def( "__mod__", [](const array& a, const ScalarOrArray v) { return remainder(a, to_array(v, a.dtype())); }, "other"_a) + .def( + "__imod__", + [](array& a, const ScalarOrArray v) { + a.overwrite_descriptor(remainder(a, to_array(v, a.dtype()))); + return a; + }, + "other"_a) .def( "__rmod__", [](const array& a, const ScalarOrArray v) { @@ -840,23 +885,102 @@ void init_array(py::module_& m) { return os.str(); }) .def( - "__matmul__", [](array& a, array& other) { return matmul(a, other); }) + "__matmul__", + [](const array& a, array& other) { return matmul(a, other); }, + "other"_a) + .def( + "__imatmul__", + [](array& a, array& other) { + a.overwrite_descriptor(matmul(a, other)); + return a; + }, + "other"_a) .def( "__pow__", [](const array& a, const ScalarOrArray v) { return power(a, to_array(v, a.dtype())); }, "other"_a) + .def( + "__ipow__", + [](array& a, const ScalarOrArray v) { + a.overwrite_descriptor(power(a, to_array(v, a.dtype()))); + return a; + }, + "other"_a) + .def( + "__invert__", + [](const array& a) { + if (is_floating_point(a.dtype())) { + throw std::invalid_argument( + "Floating point types not allowed with or bitwise inversion."); + } + if (a.dtype() != bool_) { + throw std::invalid_argument( + "Bitwise inversion not yet supported for integer types."); + } + return logical_not(a); + }) .def( "__and__", [](const array& a, const ScalarOrArray v) { - return logical_and(a, to_array(v, a.dtype())); + auto b = to_array(v, a.dtype()); + if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { + throw std::invalid_argument( + "Floating point types not allowed with bitwise and."); + } + if (a.dtype() != bool_ && b.dtype() != bool_) { + throw std::invalid_argument( + "Bitwise and not yet supported for integer types."); + } + return logical_and(a, b); + }, + "other"_a) + .def( + "__iand__", + [](array& a, const ScalarOrArray v) { + auto b = to_array(v, a.dtype()); + if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { + throw std::invalid_argument( + "Floating point types not allowed with bitwise and."); + } + if (a.dtype() != bool_ && b.dtype() != bool_) { + throw std::invalid_argument( + "Bitwise and not yet supported for integer types."); + } + a.overwrite_descriptor(logical_and(a, b)); + return a; }, "other"_a) .def( "__or__", [](const array& a, const ScalarOrArray v) { - return logical_or(a, to_array(v, a.dtype())); + auto b = to_array(v, a.dtype()); + if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { + throw std::invalid_argument( + "Floating point types not allowed with or bitwise or."); + } + if (a.dtype() != bool_ && b.dtype() != bool_) { + throw std::invalid_argument( + "Bitwise or not yet supported for integer types."); + } + return logical_or(a, b); + }, + "other"_a) + .def( + "__ior__", + [](array& a, const ScalarOrArray v) { + auto b = to_array(v, a.dtype()); + if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { + throw std::invalid_argument( + "Floating point types not allowed with or bitwise or."); + } + if (a.dtype() != bool_ && b.dtype() != bool_) { + throw std::invalid_argument( + "Bitwise or not yet supported for integer types."); + } + a.overwrite_descriptor(logical_or(a, b)); + return a; }, "other"_a) .def( diff --git a/python/tests/test_array.py b/python/tests/test_array.py index a227c8eb1..de49d979f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1313,6 +1313,65 @@ class TestArray(mlx_tests.MLXTestCase): rtol=0, ) + def test_logical_overloads(self): + with self.assertRaises(ValueError): + mx.array(1.0) & mx.array(1) + with self.assertRaises(ValueError): + mx.array(1.0) | mx.array(1) + + self.assertEqual((mx.array(True) & True).item(), True) + self.assertEqual((mx.array(True) & False).item(), False) + self.assertEqual((mx.array(True) | False).item(), True) + self.assertEqual((mx.array(False) | False).item(), False) + self.assertEqual((~mx.array(False)).item(), True) + + def test_inplace(self): + iops = [ + "__iadd__", + "__isub__", + "__imul__", + "__ifloordiv__", + "__imod__", + "__ipow__", + ] + + for op in iops: + a = mx.array([1, 2, 3]) + a_np = np.array(a) + b = a + b = getattr(a, op)(3) + self.assertTrue(mx.array_equal(a, b)) + out_np = getattr(a_np, op)(3) + self.assertTrue(np.array_equal(out_np, a)) + + with self.assertRaises(ValueError): + a = mx.array([1]) + a /= 1 + + a = mx.array([2.0]) + b = a + b /= 2 + self.assertEqual(b.item(), 1.0) + self.assertEqual(b.item(), a.item()) + + a = mx.array(True) + b = a + b &= False + self.assertEqual(b.item(), False) + self.assertEqual(b.item(), a.item()) + + a = mx.array(False) + b = a + b |= True + self.assertEqual(b.item(), True) + self.assertEqual(b.item(), a.item()) + + # In-place matmul on its own + a = mx.array([[1.0, 2.0], [3.0, 4.0]]) + b = a + b @= a + self.assertTrue(mx.array_equal(a, b)) + if __name__ == "__main__": unittest.main()