2024-05-07 07:02:49 +08:00
|
|
|
// Copyright © 2024 Apple Inc.
|
|
|
|
|
|
|
|
#include "python/src/utils.h"
|
|
|
|
#include "mlx/ops.h"
|
2025-03-28 10:54:56 +08:00
|
|
|
#include "mlx/utils.h"
|
2024-05-07 07:02:49 +08:00
|
|
|
#include "python/src/convert.h"
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
// Convert a Python scalar, MLX array, CPU ndarray, or array-like object
// (any alternative of ScalarOrArray) into an mx::array.
//
// `dtype` is an optional type hint for Python scalars (to_arrays passes the
// dtype of the other operand so scalars are "weakly" typed):
//   - bool    -> `dtype` if given, otherwise bool_.
//   - int     -> `dtype` if given (bool_ is promoted, see below), otherwise
//                int32, or int64 when the value does not fit in an int.
//                Throws std::invalid_argument if the value does not fit an
//                integer `dtype` narrower than 64 bits.
//   - float   -> `dtype` if it is a floating type, otherwise float32.
//   - complex -> always complex64 (`dtype` is ignored).
// mx::array inputs are returned unchanged, ndarray inputs are forwarded to
// nd_array_to_mlx together with `dtype`, and any other object is handled by
// to_array_with_accessor (which does not use `dtype`).
mx::array to_array(
    const ScalarOrArray& v,
    std::optional<mx::Dtype> dtype /* = std::nullopt */) {
  if (auto pv = std::get_if<nb::bool_>(&v); pv) {
    return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));
  } else if (auto pv = std::get_if<nb::int_>(&v); pv) {
    // Use a fixed-width 64-bit integer: `long` is only 32 bits on LLP64
    // platforms (e.g. Windows), where large Python ints would not fit.
    auto val = nb::cast<int64_t>(*pv);
    auto default_type = (val > std::numeric_limits<int>::max() ||
                         val < std::numeric_limits<int>::min())
        ? mx::int64
        : mx::int32;
    auto out_t = dtype.value_or(default_type);
    // Refuse values that would wrap when stored in a narrower integer type.
    // 64-bit integer types are not checked here.
    if (mx::issubdtype(out_t, mx::integer) && out_t.size() < 8) {
      auto info = mx::iinfo(out_t);
      if (val < info.min || val > static_cast<int64_t>(info.max)) {
        std::ostringstream msg;
        msg << "Converting " << val << " to " << out_t
            << " would result in overflow.";
        throw std::invalid_argument(msg.str());
      }
    }

    // bool_ is an exception and is always promoted. Promote to the
    // magnitude-based default (int32, or int64 for values outside the int
    // range) instead of unconditionally to int32, which would truncate.
    return mx::array(val, (out_t == mx::bool_) ? default_type : out_t);
  } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
    // Read the Python float at full double precision so a float64 target
    // dtype does not lose precision through an intermediate `float`.
    auto val = nb::cast<double>(*pv);
    auto out_t = dtype.value_or(mx::float32);
    return mx::array(
        val, mx::issubdtype(out_t, mx::floating) ? out_t : mx::float32);
  } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
    return mx::array(static_cast<mx::complex64_t>(*pv), mx::complex64);
  } else if (auto pv = std::get_if<mx::array>(&v); pv) {
    return *pv;
  } else if (auto pv = std::get_if<
                 nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
             pv) {
    return nd_array_to_mlx(*pv, dtype);
  } else {
    return to_array_with_accessor(std::get<ArrayLike>(v).obj);
  }
}
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
std::pair<mx::array, mx::array> to_arrays(
|
2024-05-07 07:02:49 +08:00
|
|
|
const ScalarOrArray& a,
|
|
|
|
const ScalarOrArray& b) {
|
|
|
|
// Four cases:
|
|
|
|
// - If both a and b are arrays leave their types alone
|
|
|
|
// - If a is an array but b is not, treat b as a weak python type
|
|
|
|
// - If b is an array but a is not, treat a as a weak python type
|
|
|
|
// - If neither is an array convert to arrays but leave their types alone
|
|
|
|
auto is_mlx_array = [](const ScalarOrArray& x) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return std::holds_alternative<mx::array>(x) ||
|
2024-12-18 02:57:54 +08:00
|
|
|
std::holds_alternative<ArrayLike>(x) &&
|
|
|
|
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__");
|
2024-05-07 07:02:49 +08:00
|
|
|
};
|
|
|
|
auto get_mlx_array = [](const ScalarOrArray& x) {
|
2024-12-12 07:45:39 +08:00
|
|
|
if (auto px = std::get_if<mx::array>(&x); px) {
|
2024-05-07 07:02:49 +08:00
|
|
|
return *px;
|
|
|
|
} else {
|
2024-12-18 02:57:54 +08:00
|
|
|
return nb::cast<mx::array>(
|
|
|
|
std::get<ArrayLike>(x).obj.attr("__mlx_array__"));
|
2024-05-07 07:02:49 +08:00
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
if (is_mlx_array(a)) {
|
|
|
|
auto arr_a = get_mlx_array(a);
|
|
|
|
if (is_mlx_array(b)) {
|
|
|
|
auto arr_b = get_mlx_array(b);
|
|
|
|
return {arr_a, arr_b};
|
|
|
|
}
|
|
|
|
return {arr_a, to_array(b, arr_a.dtype())};
|
|
|
|
} else if (is_mlx_array(b)) {
|
|
|
|
auto arr_b = get_mlx_array(b);
|
|
|
|
return {to_array(a, arr_b.dtype()), arr_b};
|
|
|
|
} else {
|
|
|
|
return {to_array(a), to_array(b)};
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
// Convert an arbitrary Python object into an mx::array.
//
// An object that already is an mx::array is cast and returned directly.
// Otherwise the object must provide an `__mlx_array__` attribute; it is
// called with no arguments and its result is cast to mx::array. Any other
// object raises std::invalid_argument with a message naming its type.
mx::array to_array_with_accessor(nb::object obj) {
  if (nb::isinstance<mx::array>(obj)) {
    return nb::cast<mx::array>(obj);
  }

  if (!nb::hasattr(obj, "__mlx_array__")) {
    std::ostringstream err;
    err << "Invalid type " << nb::type_name(obj.type()).c_str()
        << " received in array initialization.";
    throw std::invalid_argument(err.str());
  }

  nb::object accessor = obj.attr("__mlx_array__");
  return nb::cast<mx::array>(accessor());
}
|