diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 784dc6329..6c151117c 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -15,7 +15,8 @@ MLX was developed with contributions from the following individuals:
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
-- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`.
+- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
+
diff --git a/python/src/array.cpp b/python/src/array.cpp
index 0ef27880b..789b8c00f 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -690,6 +690,9 @@ void init_array(nb::module_& m) {
.def(
"__add__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("addition", v);
+ }
auto b = to_array(v, a.dtype());
return add(a, b);
},
@@ -697,6 +700,9 @@ void init_array(nb::module_& m) {
.def(
"__iadd__",
[](array& a, const ScalarOrArray v) -> array& {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("inplace addition", v);
+ }
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
return a;
},
@@ -705,18 +711,27 @@ void init_array(nb::module_& m) {
.def(
"__radd__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("addition", v);
+ }
return add(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__sub__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("subtraction", v);
+ }
return subtract(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__isub__",
[](array& a, const ScalarOrArray v) -> array& {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("inplace subtraction", v);
+ }
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
return a;
},
@@ -725,18 +740,27 @@ void init_array(nb::module_& m) {
.def(
"__rsub__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("subtraction", v);
+ }
return subtract(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__mul__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("multiplication", v);
+ }
return multiply(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__imul__",
[](array& a, const ScalarOrArray v) -> array& {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("inplace multiplication", v);
+ }
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
return a;
},
@@ -745,18 +769,27 @@ void init_array(nb::module_& m) {
.def(
"__rmul__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("multiplication", v);
+ }
return multiply(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__truediv__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("division", v);
+ }
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__itruediv__",
[](array& a, const ScalarOrArray v) -> array& {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("inplace division", v);
+ }
if (!issubdtype(a.dtype(), inexact)) {
throw std::invalid_argument(
"In place division cannot cast to non-floating point type.");
@@ -769,30 +802,45 @@ void init_array(nb::module_& m) {
.def(
"__rtruediv__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("division", v);
+ }
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__div__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("division", v);
+ }
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__rdiv__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("division", v);
+ }
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__floordiv__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("floor division", v);
+ }
return floor_divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ifloordiv__",
[](array& a, const ScalarOrArray v) -> array& {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("inplace floor division", v);
+ }
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
return a;
},
@@ -801,6 +849,9 @@ void init_array(nb::module_& m) {
.def(
"__rfloordiv__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("floor division", v);
+ }
auto b = to_array(v, a.dtype());
return floor_divide(b, a);
},
@@ -808,12 +859,18 @@ void init_array(nb::module_& m) {
.def(
"__mod__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("modulus", v);
+ }
return remainder(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__imod__",
[](array& a, const ScalarOrArray v) -> array& {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("inplace modulus", v);
+ }
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
return a;
},
@@ -822,6 +879,9 @@ void init_array(nb::module_& m) {
.def(
"__rmod__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("modulus", v);
+ }
return remainder(to_array(v, a.dtype()), a);
},
"other"_a)
@@ -838,24 +898,36 @@ void init_array(nb::module_& m) {
.def(
"__lt__",
[](const array& a, const ScalarOrArray v) -> array {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("less than", v);
+ }
return less(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__le__",
[](const array& a, const ScalarOrArray v) -> array {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("less than or equal", v);
+ }
return less_equal(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__gt__",
[](const array& a, const ScalarOrArray v) -> array {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("greater than", v);
+ }
return greater(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ge__",
[](const array& a, const ScalarOrArray v) -> array {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("greater than or equal", v);
+ }
return greater_equal(a, to_array(v, a.dtype()));
},
"other"_a)
@@ -897,18 +969,27 @@ void init_array(nb::module_& m) {
.def(
"__pow__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("power", v);
+ }
return power(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__rpow__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("power", v);
+ }
return power(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__ipow__",
[](array& a, const ScalarOrArray v) -> array& {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("inplace power", v);
+ }
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
return a;
},
@@ -930,6 +1011,9 @@ void init_array(nb::module_& m) {
.def(
"__and__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("bitwise and", v);
+ }
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
@@ -946,6 +1030,9 @@ void init_array(nb::module_& m) {
.def(
"__iand__",
[](array& a, const ScalarOrArray v) -> array& {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("inplace bitwise and", v);
+ }
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
@@ -964,6 +1051,9 @@ void init_array(nb::module_& m) {
.def(
"__or__",
[](const array& a, const ScalarOrArray v) {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("bitwise or", v);
+ }
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
@@ -980,6 +1070,9 @@ void init_array(nb::module_& m) {
.def(
"__ior__",
[](array& a, const ScalarOrArray v) -> array& {
+ if (!is_comparable_with_array(v)) {
+ throw_invalid_operation("inplace bitwise or", v);
+ }
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
diff --git a/python/src/utils.h b/python/src/utils.h
index 41b422b98..35ac53a52 100644
--- a/python/src/utils.h
+++ b/python/src/utils.h
@@ -2,6 +2,7 @@
#pragma once
#include
#include
+#include
#include
#include
@@ -56,6 +57,19 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
}
}
+inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
+ return std::get(v).ptr();
+}
+
+inline void throw_invalid_operation(
+ const std::string& operation,
+ const ScalarOrArray operand) {
+ std::ostringstream msg;
+ msg << "Cannot perform " << operation << " on an mlx.core.array and "
+ << nb::type_name(get_handle_of_object(operand).type()).c_str();
+ throw std::invalid_argument(msg.str());
+}
+
inline array to_array(
const ScalarOrArray& v,
std::optional dtype = std::nullopt) {
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index 8f95edf2b..912f3bbb1 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -203,6 +203,29 @@ class TestInequality(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
a >= tpl_
+ def test_invalid_op_on_array(self):
+ str_ = "hello"
+ a = mx.array([1, 2.5, 3.25])
+ lst_ = [1, 2.1, 3.25]
+ tpl_ = (1, 2.5, 3.25)
+
+ with self.assertRaises(ValueError):
+ a * str_
+ with self.assertRaises(ValueError):
+ a *= str_
+ with self.assertRaises(ValueError):
+ a /= lst_
+ with self.assertRaises(ValueError):
+ a // lst_
+ with self.assertRaises(ValueError):
+ a % lst_
+ with self.assertRaises(ValueError):
+ a**tpl_
+ with self.assertRaises(ValueError):
+ a & tpl_
+ with self.assertRaises(ValueError):
+ a | str_
+
class TestArray(mlx_tests.MLXTestCase):
def test_array_basics(self):