// Copyright © 2023-2024 Apple Inc. #pragma once #include #include #include #include #include #include #include #include #include "mlx/array.h" namespace nb = nanobind; using namespace mlx::core; using IntOrVec = std::variant>; using ScalarOrArray = std::variant< nb::bool_, nb::int_, nb::float_, // Must be above ndarray array, // Must be above complex nb::ndarray, std::complex, nb::object>; inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { std::vector axes; if (std::holds_alternative(v)) { axes.resize(dims); std::iota(axes.begin(), axes.end(), 0); } else if (auto pv = std::get_if(&v); pv) { axes.push_back(*pv); } else { axes = std::get>(v); } return axes; } inline bool is_comparable_with_array(const ScalarOrArray& v) { // Checks if the value can be compared to an array (or is already an // mlx array) if (auto pv = std::get_if(&v); pv) { return nb::isinstance(*pv) || nb::hasattr(*pv, "__mlx_array__"); } else { // If it's not an object, it's a scalar (nb::int_, nb::float_, etc.) // and can be compared to an array return true; } } inline nb::handle get_handle_of_object(const ScalarOrArray& v) { return std::get(v).ptr(); } inline void throw_invalid_operation( const std::string& operation, const ScalarOrArray operand) { std::ostringstream msg; msg << "Cannot perform " << operation << " on an mlx.core.array and " << nb::type_name(get_handle_of_object(operand).type()).c_str(); throw std::invalid_argument(msg.str()); } array to_array( const ScalarOrArray& v, std::optional dtype = std::nullopt); std::pair to_arrays( const ScalarOrArray& a, const ScalarOrArray& b); array to_array_with_accessor(nb::object obj);