mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -28,7 +28,7 @@ mx::array to_array(
|
||||
pv) {
|
||||
return nd_array_to_mlx(*pv, dtype);
|
||||
} else {
|
||||
return to_array_with_accessor(std::get<nb::object>(v));
|
||||
return to_array_with_accessor(std::get<ArrayLike>(v).obj);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,14 +42,15 @@ std::pair<mx::array, mx::array> to_arrays(
|
||||
// - If neither is an array convert to arrays but leave their types alone
|
||||
auto is_mlx_array = [](const ScalarOrArray& x) {
|
||||
return std::holds_alternative<mx::array>(x) ||
|
||||
std::holds_alternative<nb::object>(x) &&
|
||||
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
|
||||
std::holds_alternative<ArrayLike>(x) &&
|
||||
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__");
|
||||
};
|
||||
auto get_mlx_array = [](const ScalarOrArray& x) {
|
||||
if (auto px = std::get_if<mx::array>(&x); px) {
|
||||
return *px;
|
||||
} else {
|
||||
return nb::cast<mx::array>(std::get<nb::object>(x).attr("__mlx_array__"));
|
||||
return nb::cast<mx::array>(
|
||||
std::get<ArrayLike>(x).obj.attr("__mlx_array__"));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user