Bump nanobind to 2.4 + fix (#1710)

* bump nanobind to 2.4 + fix

* fix
This commit is contained in:
Awni Hannun 2024-12-17 10:57:54 -08:00 committed by GitHub
parent a6b426422e
commit f110357aaa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 36 additions and 21 deletions

View File

@ -85,7 +85,7 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install nanobind==2.4.0
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
@ -137,7 +137,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
@ -226,7 +226,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install twine
@ -291,7 +291,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel

View File

@ -239,8 +239,7 @@ if(MLX_BUILD_PYTHON_BINDINGS)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

View File

@ -18,8 +18,7 @@ find_package(
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------

View File

@ -3,6 +3,6 @@ requires = [
"setuptools>=42",
"cmake>=3.25",
"mlx>=0.18.0",
"nanobind==2.2.0",
"nanobind==2.4.0",
]
build-backend = "setuptools.build_meta"

View File

@ -1,7 +1,7 @@
[build-system]
requires = [
"setuptools>=42",
"nanobind==2.2.0",
"nanobind==2.4.0",
"cmake>=3.25",
]
build-backend = "setuptools.build_meta"

View File

@ -193,6 +193,15 @@ void init_array(nb::module_& m) {
.def("maximum", &ArrayAt::maximum, "value"_a)
.def("minimum", &ArrayAt::minimum, "value"_a);
nb::class_<ArrayLike>(
m,
"ArrayLike",
R"pbdoc(
Any Python object which has an ``__mlx__array__`` method that
returns an :obj:`array`.
)pbdoc")
.def(nb::init_implicit<nb::object>());
nb::class_<ArrayPythonIterator>(
m,
"ArrayIterator",

View File

@ -477,7 +477,7 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
return mx::astype(*pv, t.value_or((*pv).dtype()));
} else {
auto arr = to_array_with_accessor(std::get<nb::object>(v));
auto arr = to_array_with_accessor(std::get<ArrayLike>(v).obj);
return mx::astype(arr, t.value_or(arr.dtype()));
}
}

View File

@ -12,6 +12,11 @@
namespace mx = mlx::core;
namespace nb = nanobind;
struct ArrayLike {
ArrayLike(nb::object obj) : obj(obj) {};
nb::object obj;
};
using ArrayInitType = std::variant<
nb::bool_,
nb::int_,
@ -23,7 +28,7 @@ using ArrayInitType = std::variant<
std::complex<float>,
nb::list,
nb::tuple,
nb::object>;
ArrayLike>;
mx::array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,

View File

@ -28,7 +28,7 @@ mx::array to_array(
pv) {
return nd_array_to_mlx(*pv, dtype);
} else {
return to_array_with_accessor(std::get<nb::object>(v));
return to_array_with_accessor(std::get<ArrayLike>(v).obj);
}
}
@ -42,14 +42,15 @@ std::pair<mx::array, mx::array> to_arrays(
// - If neither is an array convert to arrays but leave their types alone
auto is_mlx_array = [](const ScalarOrArray& x) {
return std::holds_alternative<mx::array>(x) ||
std::holds_alternative<nb::object>(x) &&
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
std::holds_alternative<ArrayLike>(x) &&
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__");
};
auto get_mlx_array = [](const ScalarOrArray& x) {
if (auto px = std::get_if<mx::array>(&x); px) {
return *px;
} else {
return nb::cast<mx::array>(std::get<nb::object>(x).attr("__mlx_array__"));
return nb::cast<mx::array>(
std::get<ArrayLike>(x).obj.attr("__mlx_array__"));
}
};

View File

@ -11,6 +11,7 @@
#include <nanobind/stl/variant.h>
#include "mlx/array.h"
#include "python/src/convert.h"
namespace mx = mlx::core;
namespace nb = nanobind;
@ -25,7 +26,7 @@ using ScalarOrArray = std::variant<
// Must be above complex
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
std::complex<float>,
nb::object>;
ArrayLike>;
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
@ -43,8 +44,9 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
inline bool is_comparable_with_array(const ScalarOrArray& v) {
// Checks if the value can be compared to an array (or is already an
// mlx array)
if (auto pv = std::get_if<nb::object>(&v); pv) {
return nb::isinstance<mx::array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
if (auto pv = std::get_if<ArrayLike>(&v); pv) {
auto obj = (*pv).obj;
return nb::isinstance<mx::array>(obj) || nb::hasattr(obj, "__mlx_array__");
} else {
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
// and can be compared to an array
@ -53,7 +55,7 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
}
inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
return std::get<nb::object>(v).ptr();
return std::get<ArrayLike>(v).obj.ptr();
}
inline void throw_invalid_operation(

View File

@ -179,7 +179,7 @@ if __name__ == "__main__":
include_package_data=True,
extras_require={
"dev": [
"nanobind==2.2.0",
"nanobind==2.4.0",
"numpy",
"pre-commit",
"setuptools>=42",