// Copyright © 2023 Apple Inc. #pragma once #include #include #include #include #include #include "mlx/array.h" namespace py = pybind11; using namespace mlx::core; using IntOrVec = std::variant>; using ScalarOrArray = std::variant, array>; static constexpr std::monostate none{}; inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { std::vector axes; if (std::holds_alternative(v)) { axes.resize(dims); std::iota(axes.begin(), axes.end(), 0); } else if (auto pv = std::get_if(&v); pv) { axes.push_back(*pv); } else { axes = std::get>(v); } return axes; } inline array to_array( const ScalarOrArray& v, std::optional dtype = std::nullopt) { if (auto pv = std::get_if(&v); pv) { return array(py::cast(*pv), dtype.value_or(bool_)); } else if (auto pv = std::get_if(&v); pv) { auto out_t = dtype.value_or(int32); // bool_ is an exception and is always promoted return array(py::cast(*pv), (out_t == bool_) ? int32 : out_t); } else if (auto pv = std::get_if(&v); pv) { auto out_t = dtype.value_or(float32); return array( py::cast(*pv), is_floating_point(out_t) ? out_t : float32); } else if (auto pv = std::get_if>(&v); pv) { return array(static_cast(*pv), complex64); } else { return std::get(v); } } inline std::pair to_arrays( const ScalarOrArray& a, const ScalarOrArray& b) { // Four cases: // - If both a and b are arrays leave their types alone // - If a is an array but b is not, treat b as a weak python type // - If b is an array but a is not, treat a as a weak python type // - If neither is an array convert to arrays but leave their types alone if (auto pa = std::get_if(&a); pa) { if (auto pb = std::get_if(&b); pb) { return {*pa, *pb}; } return {*pa, to_array(b, pa->dtype())}; } else if (auto pb = std::get_if(&b); pb) { return {to_array(a, pb->dtype()), *pb}; } else { return {to_array(a), to_array(b)}; } }