From f110357aaaaf668cf55c7d5137204ead42702c12 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 17 Dec 2024 10:57:54 -0800 Subject: [PATCH] Bump nanobind to 2.4 + fix (#1710) * bump nanobind to 2.4 + fix * fix --- .circleci/config.yml | 8 ++++---- CMakeLists.txt | 3 +-- examples/extensions/CMakeLists.txt | 3 +-- examples/extensions/pyproject.toml | 2 +- pyproject.toml | 2 +- python/src/array.cpp | 9 +++++++++ python/src/convert.cpp | 2 +- python/src/convert.h | 7 ++++++- python/src/utils.cpp | 9 +++++---- python/src/utils.h | 10 ++++++---- setup.py | 2 +- 11 files changed, 36 insertions(+), 21 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f1f4ca28e..559b0f4c6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -85,7 +85,7 @@ jobs: name: Install dependencies command: | pip install --upgrade cmake - pip install nanobind==2.2.0 + pip install nanobind==2.4.0 pip install numpy sudo apt-get update sudo apt-get install libblas-dev liblapack-dev liblapacke-dev @@ -137,7 +137,7 @@ jobs: source env/bin/activate pip install --upgrade pip pip install --upgrade cmake - pip install nanobind==2.2.0 + pip install nanobind==2.4.0 pip install numpy pip install torch pip install tensorflow @@ -226,7 +226,7 @@ jobs: source env/bin/activate pip install --upgrade pip pip install --upgrade cmake - pip install nanobind==2.2.0 + pip install nanobind==2.4.0 pip install --upgrade setuptools pip install numpy pip install twine @@ -291,7 +291,7 @@ jobs: source env/bin/activate pip install --upgrade pip pip install --upgrade cmake - pip install nanobind==2.2.0 + pip install nanobind==2.4.0 pip install --upgrade setuptools pip install numpy pip install auditwheel diff --git a/CMakeLists.txt b/CMakeLists.txt index 5fc960b5c..60cf657e2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,8 +239,7 @@ if(MLX_BUILD_PYTHON_BINDINGS) execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir OUTPUT_STRIP_TRAILING_WHITESPACE - OUTPUT_VARIABLE NB_DIR) - list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") + OUTPUT_VARIABLE nanobind_ROOT) find_package(nanobind CONFIG REQUIRED) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) endif() diff --git a/examples/extensions/CMakeLists.txt b/examples/extensions/CMakeLists.txt index 1bdb03488..db2ba9b59 100644 --- a/examples/extensions/CMakeLists.txt +++ b/examples/extensions/CMakeLists.txt @@ -18,8 +18,7 @@ find_package( execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir OUTPUT_STRIP_TRAILING_WHITESPACE - OUTPUT_VARIABLE NB_DIR) -list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") + OUTPUT_VARIABLE nanobind_ROOT) find_package(nanobind CONFIG REQUIRED) # ----------------------------- Extensions ----------------------------- diff --git a/examples/extensions/pyproject.toml b/examples/extensions/pyproject.toml index 5393fe35b..9f0f09752 100644 --- a/examples/extensions/pyproject.toml +++ b/examples/extensions/pyproject.toml @@ -3,6 +3,6 @@ requires = [ "setuptools>=42", "cmake>=3.25", "mlx>=0.18.0", - "nanobind==2.2.0", + "nanobind==2.4.0", ] build-backend = "setuptools.build_meta" diff --git a/pyproject.toml b/pyproject.toml index 7fac87b94..ad0d2e328 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [build-system] requires = [ "setuptools>=42", - "nanobind==2.2.0", + "nanobind==2.4.0", "cmake>=3.25", ] build-backend = "setuptools.build_meta" diff --git a/python/src/array.cpp b/python/src/array.cpp index 9d871af33..d1c56ae55 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -193,6 +193,15 @@ void init_array(nb::module_& m) { .def("maximum", &ArrayAt::maximum, "value"_a) .def("minimum", &ArrayAt::minimum, "value"_a); + nb::class_( + m, + "ArrayLike", + R"pbdoc( + Any Python object which has an ``__mlx__array__`` method that + returns an :obj:`array`. + )pbdoc") + .def(nb::init_implicit()); + nb::class_( m, "ArrayIterator", diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 04c4f05b6..67cc27bb8 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -477,7 +477,7 @@ mx::array create_array(ArrayInitType v, std::optional t) { } else if (auto pv = std::get_if(&v); pv) { return mx::astype(*pv, t.value_or((*pv).dtype())); } else { - auto arr = to_array_with_accessor(std::get(v)); + auto arr = to_array_with_accessor(std::get(v).obj); return mx::astype(arr, t.value_or(arr.dtype())); } } diff --git a/python/src/convert.h b/python/src/convert.h index 44a090c2b..f5016c8af 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -12,6 +12,11 @@ namespace mx = mlx::core; namespace nb = nanobind; +struct ArrayLike { + ArrayLike(nb::object obj) : obj(obj) {}; + nb::object obj; +}; + using ArrayInitType = std::variant< nb::bool_, nb::int_, @@ -23,7 +28,7 @@ using ArrayInitType = std::variant< std::complex, nb::list, nb::tuple, - nb::object>; + ArrayLike>; mx::array nd_array_to_mlx( nb::ndarray nd_array, diff --git a/python/src/utils.cpp b/python/src/utils.cpp index 959cd98a6..70dbb3ddc 100644 --- a/python/src/utils.cpp +++ b/python/src/utils.cpp @@ -28,7 +28,7 @@ mx::array to_array( pv) { return nd_array_to_mlx(*pv, dtype); } else { - return to_array_with_accessor(std::get(v)); + return to_array_with_accessor(std::get(v).obj); } } @@ -42,14 +42,15 @@ std::pair to_arrays( // - If neither is an array convert to arrays but leave their types alone auto is_mlx_array = [](const ScalarOrArray& x) { return std::holds_alternative(x) || - std::holds_alternative(x) && - nb::hasattr(std::get(x), "__mlx_array__"); + std::holds_alternative(x) && + nb::hasattr(std::get(x).obj, "__mlx_array__"); }; auto get_mlx_array = [](const ScalarOrArray& x) { if (auto px = std::get_if(&x); px) { return *px; } else { - return nb::cast(std::get(x).attr("__mlx_array__")); + return nb::cast( + std::get(x).obj.attr("__mlx_array__")); } }; diff --git a/python/src/utils.h b/python/src/utils.h index 38e474746..583d72a08 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -11,6 +11,7 @@ #include #include "mlx/array.h" +#include "python/src/convert.h" namespace mx = mlx::core; namespace nb = nanobind; @@ -25,7 +26,7 @@ using ScalarOrArray = std::variant< // Must be above complex nb::ndarray, std::complex, - nb::object>; + ArrayLike>; inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { std::vector axes; @@ -43,8 +44,9 @@ inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { inline bool is_comparable_with_array(const ScalarOrArray& v) { // Checks if the value can be compared to an array (or is already an // mlx array) - if (auto pv = std::get_if(&v); pv) { - return nb::isinstance(*pv) || nb::hasattr(*pv, "__mlx_array__"); + if (auto pv = std::get_if(&v); pv) { + auto obj = (*pv).obj; + return nb::isinstance(obj) || nb::hasattr(obj, "__mlx_array__"); } else { // If it's not an object, it's a scalar (nb::int_, nb::float_, etc.) // and can be compared to an array @@ -53,7 +55,7 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) { } inline nb::handle get_handle_of_object(const ScalarOrArray& v) { - return std::get(v).ptr(); + return std::get(v).obj.ptr(); } inline void throw_invalid_operation( diff --git a/setup.py b/setup.py index e336ad8cf..f6fa171ba 100644 --- a/setup.py +++ b/setup.py @@ -179,7 +179,7 @@ if __name__ == "__main__": include_package_data=True, extras_require={ "dev": [ - "nanobind==2.2.0", + "nanobind==2.4.0", "numpy", "pre-commit", "setuptools>=42",