mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	* Nicer exceptions for ops on non-arrays
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							3fc993f82d
						
					
				
				
					commit
					0caf35f4b8
				
			@@ -690,6 +690,9 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__add__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("addition", v);
 | 
			
		||||
            }
 | 
			
		||||
            auto b = to_array(v, a.dtype());
 | 
			
		||||
            return add(a, b);
 | 
			
		||||
          },
 | 
			
		||||
@@ -697,6 +700,9 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__iadd__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("inplace addition", v);
 | 
			
		||||
            }
 | 
			
		||||
            a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -705,18 +711,27 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__radd__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("addition", v);
 | 
			
		||||
            }
 | 
			
		||||
            return add(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__sub__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("subtraction", v);
 | 
			
		||||
            }
 | 
			
		||||
            return subtract(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__isub__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("inplace subtraction", v);
 | 
			
		||||
            }
 | 
			
		||||
            a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -725,18 +740,27 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__rsub__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("subtraction", v);
 | 
			
		||||
            }
 | 
			
		||||
            return subtract(to_array(v, a.dtype()), a);
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__mul__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("multiplication", v);
 | 
			
		||||
            }
 | 
			
		||||
            return multiply(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__imul__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("inplace multiplication", v);
 | 
			
		||||
            }
 | 
			
		||||
            a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -745,18 +769,27 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__rmul__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("multiplication", v);
 | 
			
		||||
            }
 | 
			
		||||
            return multiply(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__truediv__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("division", v);
 | 
			
		||||
            }
 | 
			
		||||
            return divide(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__itruediv__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("inplace division", v);
 | 
			
		||||
            }
 | 
			
		||||
            if (!issubdtype(a.dtype(), inexact)) {
 | 
			
		||||
              throw std::invalid_argument(
 | 
			
		||||
                  "In place division cannot cast to non-floating point type.");
 | 
			
		||||
@@ -769,30 +802,45 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__rtruediv__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("division", v);
 | 
			
		||||
            }
 | 
			
		||||
            return divide(to_array(v, a.dtype()), a);
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__div__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("division", v);
 | 
			
		||||
            }
 | 
			
		||||
            return divide(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__rdiv__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("division", v);
 | 
			
		||||
            }
 | 
			
		||||
            return divide(to_array(v, a.dtype()), a);
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__floordiv__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("floor division", v);
 | 
			
		||||
            }
 | 
			
		||||
            return floor_divide(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__ifloordiv__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("inplace floor division", v);
 | 
			
		||||
            }
 | 
			
		||||
            a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -801,6 +849,9 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__rfloordiv__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("floor division", v);
 | 
			
		||||
            }
 | 
			
		||||
            auto b = to_array(v, a.dtype());
 | 
			
		||||
            return floor_divide(b, a);
 | 
			
		||||
          },
 | 
			
		||||
@@ -808,12 +859,18 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__mod__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("modulus", v);
 | 
			
		||||
            }
 | 
			
		||||
            return remainder(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__imod__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("inplace modulus", v);
 | 
			
		||||
            }
 | 
			
		||||
            a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -822,6 +879,9 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__rmod__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("modulus", v);
 | 
			
		||||
            }
 | 
			
		||||
            return remainder(to_array(v, a.dtype()), a);
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
@@ -838,24 +898,36 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__lt__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) -> array {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("less than", v);
 | 
			
		||||
            }
 | 
			
		||||
            return less(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__le__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) -> array {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("less than or equal", v);
 | 
			
		||||
            }
 | 
			
		||||
            return less_equal(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__gt__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) -> array {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("greater than", v);
 | 
			
		||||
            }
 | 
			
		||||
            return greater(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__ge__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) -> array {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("greater than or equal", v);
 | 
			
		||||
            }
 | 
			
		||||
            return greater_equal(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
@@ -897,18 +969,27 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__pow__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("power", v);
 | 
			
		||||
            }
 | 
			
		||||
            return power(a, to_array(v, a.dtype()));
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__rpow__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("power", v);
 | 
			
		||||
            }
 | 
			
		||||
            return power(to_array(v, a.dtype()), a);
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__ipow__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("inplace power", v);
 | 
			
		||||
            }
 | 
			
		||||
            a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -930,6 +1011,9 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__and__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("bitwise and", v);
 | 
			
		||||
            }
 | 
			
		||||
            auto b = to_array(v, a.dtype());
 | 
			
		||||
            if (issubdtype(a.dtype(), inexact) ||
 | 
			
		||||
                issubdtype(b.dtype(), inexact)) {
 | 
			
		||||
@@ -946,6 +1030,9 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__iand__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("inplace bitwise and", v);
 | 
			
		||||
            }
 | 
			
		||||
            auto b = to_array(v, a.dtype());
 | 
			
		||||
            if (issubdtype(a.dtype(), inexact) ||
 | 
			
		||||
                issubdtype(b.dtype(), inexact)) {
 | 
			
		||||
@@ -964,6 +1051,9 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__or__",
 | 
			
		||||
          [](const array& a, const ScalarOrArray v) {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("bitwise or", v);
 | 
			
		||||
            }
 | 
			
		||||
            auto b = to_array(v, a.dtype());
 | 
			
		||||
            if (issubdtype(a.dtype(), inexact) ||
 | 
			
		||||
                issubdtype(b.dtype(), inexact)) {
 | 
			
		||||
@@ -980,6 +1070,9 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          "__ior__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            if (!is_comparable_with_array(v)) {
 | 
			
		||||
              throw_invalid_operation("inplace bitwise or", v);
 | 
			
		||||
            }
 | 
			
		||||
            auto b = to_array(v, a.dtype());
 | 
			
		||||
            if (issubdtype(a.dtype(), inexact) ||
 | 
			
		||||
                issubdtype(b.dtype(), inexact)) {
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
#include <numeric>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <variant>
 | 
			
		||||
 | 
			
		||||
#include <nanobind/nanobind.h>
 | 
			
		||||
@@ -56,6 +57,19 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
 | 
			
		||||
  return std::get<nb::object>(v).ptr();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline void throw_invalid_operation(
 | 
			
		||||
    const std::string& operation,
 | 
			
		||||
    const ScalarOrArray operand) {
 | 
			
		||||
  std::ostringstream msg;
 | 
			
		||||
  msg << "Cannot perform " << operation << " on an mlx.core.array and "
 | 
			
		||||
      << nb::type_name(get_handle_of_object(operand).type()).c_str();
 | 
			
		||||
  throw std::invalid_argument(msg.str());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline array to_array(
 | 
			
		||||
    const ScalarOrArray& v,
 | 
			
		||||
    std::optional<Dtype> dtype = std::nullopt) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user