diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 784dc6329..6c151117c 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -15,7 +15,8 @@ MLX was developed with contributions from the following individuals: - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays. - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` -- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. +- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`. + diff --git a/python/src/array.cpp b/python/src/array.cpp index 0ef27880b..789b8c00f 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -690,6 +690,9 @@ void init_array(nb::module_& m) { .def( "__add__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("addition", v); + } auto b = to_array(v, a.dtype()); return add(a, b); }, @@ -697,6 +700,9 @@ void init_array(nb::module_& m) { .def( "__iadd__", [](array& a, const ScalarOrArray v) -> array& { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("inplace addition", v); + } a.overwrite_descriptor(add(a, to_array(v, a.dtype()))); return a; }, @@ -705,18 +711,27 @@ void init_array(nb::module_& m) { .def( "__radd__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("addition", v); + } return add(a, to_array(v, a.dtype())); }, "other"_a) .def( "__sub__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("subtraction", v); + } return subtract(a, to_array(v, a.dtype())); }, "other"_a) .def( "__isub__", [](array& a, const ScalarOrArray v) -> array& { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("inplace subtraction", v); + } a.overwrite_descriptor(subtract(a, to_array(v, a.dtype()))); return a; }, @@ -725,18 +740,27 @@ void init_array(nb::module_& m) { .def( "__rsub__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("subtraction", v); + } return subtract(to_array(v, a.dtype()), a); }, "other"_a) .def( "__mul__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("multiplication", v); + } return multiply(a, to_array(v, a.dtype())); }, "other"_a) .def( "__imul__", [](array& a, const ScalarOrArray v) -> array& { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("inplace multiplication", v); + } a.overwrite_descriptor(multiply(a, to_array(v, a.dtype()))); return a; }, @@ -745,18 +769,27 @@ void init_array(nb::module_& m) { .def( "__rmul__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("multiplication", v); + } return multiply(a, to_array(v, a.dtype())); }, "other"_a) .def( "__truediv__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("division", v); + } return divide(a, to_array(v, a.dtype())); }, "other"_a) .def( "__itruediv__", [](array& a, const ScalarOrArray v) -> array& { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("inplace division", v); + } if (!issubdtype(a.dtype(), inexact)) { throw std::invalid_argument( "In place division cannot cast to non-floating point type."); @@ -769,30 +802,45 @@ void init_array(nb::module_& m) { .def( "__rtruediv__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("division", v); + } return divide(to_array(v, a.dtype()), a); }, "other"_a) .def( "__div__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("division", v); + } return divide(a, to_array(v, a.dtype())); }, "other"_a) .def( "__rdiv__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("division", v); + } return divide(to_array(v, a.dtype()), a); }, "other"_a) .def( "__floordiv__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("floor division", v); + } return floor_divide(a, to_array(v, a.dtype())); }, "other"_a) .def( "__ifloordiv__", [](array& a, const ScalarOrArray v) -> array& { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("inplace floor division", v); + } a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype()))); return a; }, @@ -801,6 +849,9 @@ void init_array(nb::module_& m) { .def( "__rfloordiv__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("floor division", v); + } auto b = to_array(v, a.dtype()); return floor_divide(b, a); }, @@ -808,12 +859,18 @@ void init_array(nb::module_& m) { .def( "__mod__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("modulus", v); + } return remainder(a, to_array(v, a.dtype())); }, "other"_a) .def( "__imod__", [](array& a, const ScalarOrArray v) -> array& { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("inplace modulus", v); + } a.overwrite_descriptor(remainder(a, to_array(v, a.dtype()))); return a; }, @@ -822,6 +879,9 @@ void init_array(nb::module_& m) { .def( "__rmod__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("modulus", v); + } return remainder(to_array(v, a.dtype()), a); }, "other"_a) @@ -838,24 +898,36 @@ void init_array(nb::module_& m) { .def( "__lt__", [](const array& a, const ScalarOrArray v) -> array { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("less than", v); + } return less(a, to_array(v, a.dtype())); }, "other"_a) .def( "__le__", [](const array& a, const ScalarOrArray v) -> array { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("less than or equal", v); + } return less_equal(a, to_array(v, a.dtype())); }, "other"_a) .def( "__gt__", [](const array& a, const ScalarOrArray v) -> array { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("greater than", v); + } return greater(a, to_array(v, a.dtype())); }, "other"_a) .def( "__ge__", [](const array& a, const ScalarOrArray v) -> array { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("greater than or equal", v); + } return greater_equal(a, to_array(v, a.dtype())); }, "other"_a) @@ -897,18 +969,27 @@ void init_array(nb::module_& m) { .def( "__pow__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("power", v); + } return power(a, to_array(v, a.dtype())); }, "other"_a) .def( "__rpow__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("power", v); + } return power(to_array(v, a.dtype()), a); }, "other"_a) .def( "__ipow__", [](array& a, const ScalarOrArray v) -> array& { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("inplace power", v); + } a.overwrite_descriptor(power(a, to_array(v, a.dtype()))); return a; }, @@ -930,6 +1011,9 @@ void init_array(nb::module_& m) { .def( "__and__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("bitwise and", v); + } auto b = to_array(v, a.dtype()); if (issubdtype(a.dtype(), inexact) || issubdtype(b.dtype(), inexact)) { @@ -946,6 +1030,9 @@ void init_array(nb::module_& m) { .def( "__iand__", [](array& a, const ScalarOrArray v) -> array& { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("inplace bitwise and", v); + } auto b = to_array(v, a.dtype()); if (issubdtype(a.dtype(), inexact) || issubdtype(b.dtype(), inexact)) { @@ -964,6 +1051,9 @@ void init_array(nb::module_& m) { .def( "__or__", [](const array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("bitwise or", v); + } auto b = to_array(v, a.dtype()); if (issubdtype(a.dtype(), inexact) || issubdtype(b.dtype(), inexact)) { @@ -980,6 +1070,9 @@ void init_array(nb::module_& m) { .def( "__ior__", [](array& a, const ScalarOrArray v) -> array& { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("inplace bitwise or", v); + } auto b = to_array(v, a.dtype()); if (issubdtype(a.dtype(), inexact) || issubdtype(b.dtype(), inexact)) { diff --git a/python/src/utils.h b/python/src/utils.h index 41b422b98..35ac53a52 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -2,6 +2,7 @@ #pragma once #include #include +#include #include #include @@ -56,6 +57,19 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) { } } +inline nb::handle get_handle_of_object(const ScalarOrArray& v) { + return std::get(v).ptr(); +} + +inline void throw_invalid_operation( + const std::string& operation, + const ScalarOrArray operand) { + std::ostringstream msg; + msg << "Cannot perform " << operation << " on an mlx.core.array and " + << nb::type_name(get_handle_of_object(operand).type()).c_str(); + throw std::invalid_argument(msg.str()); +} + inline array to_array( const ScalarOrArray& v, std::optional dtype = std::nullopt) { diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 8f95edf2b..912f3bbb1 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -203,6 +203,29 @@ class TestInequality(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): a >= tpl_ + def test_invalid_op_on_array(self): + str_ = "hello" + a = mx.array([1, 2.5, 3.25]) + lst_ = [1, 2.1, 3.25] + tpl_ = (1, 2.5, 3.25) + + with self.assertRaises(ValueError): + a * str_ + with self.assertRaises(ValueError): + a *= str_ + with self.assertRaises(ValueError): + a /= lst_ + with self.assertRaises(ValueError): + a // lst_ + with self.assertRaises(ValueError): + a % lst_ + with self.assertRaises(ValueError): + a**tpl_ + with self.assertRaises(ValueError): + a & tpl_ + with self.assertRaises(ValueError): + a | str_ + class TestArray(mlx_tests.MLXTestCase): def test_array_basics(self):