// Copyright © 2023-2024 Apple Inc. #pragma once #include #include #include #include #include #include #include #include "mlx/array.h" namespace nb = nanobind; using namespace mlx::core; using IntOrVec = std::variant>; using ScalarOrArray = std:: variant, nb::object>; inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { std::vector axes; if (std::holds_alternative(v)) { axes.resize(dims); std::iota(axes.begin(), axes.end(), 0); } else if (auto pv = std::get_if(&v); pv) { axes.push_back(*pv); } else { axes = std::get>(v); } return axes; } inline array to_array_with_accessor(nb::object obj) { if (nb::isinstance(obj)) { return nb::cast(obj); } else if (nb::hasattr(obj, "__mlx_array__")) { return nb::cast(obj.attr("__mlx_array__")()); } else { std::ostringstream msg; msg << "Invalid type " << nb::type_name(obj.type()).c_str() << " received in array initialization."; throw std::invalid_argument(msg.str()); } } inline bool is_comparable_with_array(const ScalarOrArray& v) { // Checks if the value can be compared to an array (or is already an // mlx array) if (auto pv = std::get_if(&v); pv) { return nb::isinstance(*pv) || nb::hasattr(*pv, "__mlx_array__"); } else { // If it's not an object, it's a scalar (nb::int_, nb::float_, etc.) // and can be compared to an array return true; } } inline nb::handle get_handle_of_object(const ScalarOrArray& v) { return std::get(v).ptr(); } inline void throw_invalid_operation( const std::string& operation, const ScalarOrArray operand) { std::ostringstream msg; msg << "Cannot perform " << operation << " on an mlx.core.array and " << nb::type_name(get_handle_of_object(operand).type()).c_str(); throw std::invalid_argument(msg.str()); } inline array to_array( const ScalarOrArray& v, std::optional dtype = std::nullopt) { if (auto pv = std::get_if(&v); pv) { return array(nb::cast(*pv), dtype.value_or(bool_)); } else if (auto pv = std::get_if(&v); pv) { auto out_t = dtype.value_or(int32); // bool_ is an exception and is always promoted return array(nb::cast(*pv), (out_t == bool_) ? int32 : out_t); } else if (auto pv = std::get_if(&v); pv) { auto out_t = dtype.value_or(float32); return array( nb::cast(*pv), issubdtype(out_t, floating) ? out_t : float32); } else if (auto pv = std::get_if>(&v); pv) { return array(static_cast(*pv), complex64); } else { return to_array_with_accessor(std::get(v)); } } inline std::pair to_arrays( const ScalarOrArray& a, const ScalarOrArray& b) { // Four cases: // - If both a and b are arrays leave their types alone // - If a is an array but b is not, treat b as a weak python type // - If b is an array but a is not, treat a as a weak python type // - If neither is an array convert to arrays but leave their types alone if (auto pa = std::get_if(&a); pa) { auto arr_a = to_array_with_accessor(*pa); if (auto pb = std::get_if(&b); pb) { auto arr_b = to_array_with_accessor(*pb); return {arr_a, arr_b}; } return {arr_a, to_array(b, arr_a.dtype())}; } else if (auto pb = std::get_if(&b); pb) { auto arr_b = to_array_with_accessor(*pb); return {to_array(a, arr_b.dtype()), arr_b}; } else { return {to_array(a), to_array(b)}; } }