Bump nanobind to 2.4 + fix (#1710)

* bump nanobind to 2.4 + fix

* fix
This commit is contained in:
Awni Hannun
2024-12-17 10:57:54 -08:00
committed by GitHub
parent a6b426422e
commit f110357aaa
11 changed files with 36 additions and 21 deletions

View File

@@ -28,7 +28,7 @@ mx::array to_array(
pv) {
return nd_array_to_mlx(*pv, dtype);
} else {
return to_array_with_accessor(std::get<nb::object>(v));
return to_array_with_accessor(std::get<ArrayLike>(v).obj);
}
}
@@ -42,14 +42,15 @@ std::pair<mx::array, mx::array> to_arrays(
// - If neither is an array convert to arrays but leave their types alone
auto is_mlx_array = [](const ScalarOrArray& x) {
return std::holds_alternative<mx::array>(x) ||
std::holds_alternative<nb::object>(x) &&
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
std::holds_alternative<ArrayLike>(x) &&
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__");
};
auto get_mlx_array = [](const ScalarOrArray& x) {
if (auto px = std::get_if<mx::array>(&x); px) {
return *px;
} else {
return nb::cast<mx::array>(std::get<nb::object>(x).attr("__mlx_array__"));
return nb::cast<mx::array>(
std::get<ArrayLike>(x).obj.attr("__mlx_array__"));
}
};