// Copyright © 2023-2024 Apple Inc. #pragma once #include #include #include #include #include #include #include #include #include "mlx/array.h" #include "python/src/convert.h" namespace mx = mlx::core; namespace nb = nanobind; using IntOrVec = std::variant>; using ScalarOrArray = std::variant< nb::bool_, nb::int_, nb::float_, // Must be above ndarray mx::array, // Must be above complex nb::ndarray, std::complex, ArrayLike>; inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { std::vector axes; if (std::holds_alternative(v)) { axes.resize(dims); std::iota(axes.begin(), axes.end(), 0); } else if (auto pv = std::get_if(&v); pv) { axes.push_back(*pv); } else { axes = std::get>(v); } return axes; } inline bool is_comparable_with_array(const ScalarOrArray& v) { // Checks if the value can be compared to an array (or is already an // mlx array) if (auto pv = std::get_if(&v); pv) { auto obj = (*pv).obj; return nb::isinstance(obj) || nb::hasattr(obj, "__mlx_array__"); } else { // If it's not an object, it's a scalar (nb::int_, nb::float_, etc.) // and can be compared to an array return true; } } inline nb::handle get_handle_of_object(const ScalarOrArray& v) { return std::get(v).obj.ptr(); } inline void throw_invalid_operation( const std::string& operation, const ScalarOrArray operand) { std::ostringstream msg; msg << "Cannot perform " << operation << " on an mlx.core.array and " << nb::type_name(get_handle_of_object(operand).type()).c_str(); throw std::invalid_argument(msg.str()); } mx::array to_array( const ScalarOrArray& v, std::optional dtype = std::nullopt); std::pair to_arrays( const ScalarOrArray& a, const ScalarOrArray& b); mx::array to_array_with_accessor(nb::object obj);