mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
* Nicer exceptions for ops on non-arrays
This commit is contained in:

committed by
GitHub

parent
3fc993f82d
commit
0caf35f4b8
@@ -690,6 +690,9 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__add__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("addition", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
return add(a, b);
|
||||
},
|
||||
@@ -697,6 +700,9 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__iadd__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace addition", v);
|
||||
}
|
||||
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@@ -705,18 +711,27 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__radd__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("addition", v);
|
||||
}
|
||||
return add(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__sub__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("subtraction", v);
|
||||
}
|
||||
return subtract(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__isub__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace subtraction", v);
|
||||
}
|
||||
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@@ -725,18 +740,27 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__rsub__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("subtraction", v);
|
||||
}
|
||||
return subtract(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__mul__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("multiplication", v);
|
||||
}
|
||||
return multiply(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__imul__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace multiplication", v);
|
||||
}
|
||||
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@@ -745,18 +769,27 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__rmul__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("multiplication", v);
|
||||
}
|
||||
return multiply(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__truediv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("division", v);
|
||||
}
|
||||
return divide(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__itruediv__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace division", v);
|
||||
}
|
||||
if (!issubdtype(a.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"In place division cannot cast to non-floating point type.");
|
||||
@@ -769,30 +802,45 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__rtruediv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("division", v);
|
||||
}
|
||||
return divide(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__div__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("division", v);
|
||||
}
|
||||
return divide(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rdiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("division", v);
|
||||
}
|
||||
return divide(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__floordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("floor division", v);
|
||||
}
|
||||
return floor_divide(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ifloordiv__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace floor division", v);
|
||||
}
|
||||
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@@ -801,6 +849,9 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__rfloordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("floor division", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
return floor_divide(b, a);
|
||||
},
|
||||
@@ -808,12 +859,18 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__mod__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("modulus", v);
|
||||
}
|
||||
return remainder(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__imod__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace modulus", v);
|
||||
}
|
||||
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@@ -822,6 +879,9 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__rmod__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("modulus", v);
|
||||
}
|
||||
return remainder(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
@@ -838,24 +898,36 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__lt__",
|
||||
[](const array& a, const ScalarOrArray v) -> array {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("less than", v);
|
||||
}
|
||||
return less(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__le__",
|
||||
[](const array& a, const ScalarOrArray v) -> array {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("less than or equal", v);
|
||||
}
|
||||
return less_equal(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__gt__",
|
||||
[](const array& a, const ScalarOrArray v) -> array {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("greater than", v);
|
||||
}
|
||||
return greater(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ge__",
|
||||
[](const array& a, const ScalarOrArray v) -> array {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("greater than or equal", v);
|
||||
}
|
||||
return greater_equal(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
@@ -897,18 +969,27 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__pow__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("power", v);
|
||||
}
|
||||
return power(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rpow__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("power", v);
|
||||
}
|
||||
return power(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ipow__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace power", v);
|
||||
}
|
||||
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@@ -930,6 +1011,9 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__and__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("bitwise and", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
@@ -946,6 +1030,9 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__iand__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace bitwise and", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
@@ -964,6 +1051,9 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__or__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("bitwise or", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
@@ -980,6 +1070,9 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__ior__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace bitwise or", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#pragma once
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
@@ -56,6 +57,19 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
||||
}
|
||||
}
|
||||
|
||||
inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
|
||||
return std::get<nb::object>(v).ptr();
|
||||
}
|
||||
|
||||
inline void throw_invalid_operation(
|
||||
const std::string& operation,
|
||||
const ScalarOrArray operand) {
|
||||
std::ostringstream msg;
|
||||
msg << "Cannot perform " << operation << " on an mlx.core.array and "
|
||||
<< nb::type_name(get_handle_of_object(operand).type()).c_str();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
inline array to_array(
|
||||
const ScalarOrArray& v,
|
||||
std::optional<Dtype> dtype = std::nullopt) {
|
||||
|
Reference in New Issue
Block a user