// Copyright © 2024 Apple Inc. #include "python/src/utils.h" #include "mlx/ops.h" #include "mlx/utils.h" #include "python/src/convert.h" mx::array to_array( const ScalarOrArray& v, std::optional dtype /* = std::nullopt */) { if (auto pv = std::get_if(&v); pv) { return mx::array(nb::cast(*pv), dtype.value_or(mx::bool_)); } else if (auto pv = std::get_if(&v); pv) { auto val = nb::cast(*pv); auto default_type = (val > std::numeric_limits::max() || val < std::numeric_limits::min()) ? mx::int64 : mx::int32; auto out_t = dtype.value_or(default_type); if (mx::issubdtype(out_t, mx::integer) && out_t.size() < 8) { auto info = mx::iinfo(out_t); if (val < info.min || val > static_cast(info.max)) { std::ostringstream msg; msg << "Converting " << val << " to " << out_t << " would result in overflow."; throw std::invalid_argument(msg.str()); } } // bool_ is an exception and is always promoted return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t); } else if (auto pv = std::get_if(&v); pv) { auto out_t = dtype.value_or(mx::float32); return mx::array( nb::cast(*pv), mx::issubdtype(out_t, mx::floating) ? out_t : mx::float32); } else if (auto pv = std::get_if>(&v); pv) { return mx::array(static_cast(*pv), mx::complex64); } else if (auto pv = std::get_if(&v); pv) { return *pv; } else if (auto pv = std::get_if< nb::ndarray>(&v); pv) { return nd_array_to_mlx(*pv, dtype); } else { return to_array_with_accessor(std::get(v).obj); } } std::pair to_arrays( const ScalarOrArray& a, const ScalarOrArray& b) { // Four cases: // - If both a and b are arrays leave their types alone // - If a is an array but b is not, treat b as a weak python type // - If b is an array but a is not, treat a as a weak python type // - If neither is an array convert to arrays but leave their types alone auto is_mlx_array = [](const ScalarOrArray& x) { return std::holds_alternative(x) || std::holds_alternative(x) && nb::hasattr(std::get(x).obj, "__mlx_array__"); }; auto get_mlx_array = [](const ScalarOrArray& x) { if (auto px = std::get_if(&x); px) { return *px; } else { return nb::cast( std::get(x).obj.attr("__mlx_array__")); } }; if (is_mlx_array(a)) { auto arr_a = get_mlx_array(a); if (is_mlx_array(b)) { auto arr_b = get_mlx_array(b); return {arr_a, arr_b}; } return {arr_a, to_array(b, arr_a.dtype())}; } else if (is_mlx_array(b)) { auto arr_b = get_mlx_array(b); return {to_array(a, arr_b.dtype()), arr_b}; } else { return {to_array(a), to_array(b)}; } } mx::array to_array_with_accessor(nb::object obj) { if (nb::isinstance(obj)) { return nb::cast(obj); } else if (nb::hasattr(obj, "__mlx_array__")) { return nb::cast(obj.attr("__mlx_array__")()); } else { std::ostringstream msg; msg << "Invalid type " << nb::type_name(obj.type()).c_str() << " received in array initialization."; throw std::invalid_argument(msg.str()); } }