From 9231617eb337307870617606e8724a2e6738c775 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 8 Aug 2024 17:17:46 -0700 Subject: [PATCH] Move to nanobind v2 (#1316) --- examples/extensions/requirements.txt | 4 ++-- python/src/array.cpp | 22 +++++++++++----------- setup.py | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/extensions/requirements.txt b/examples/extensions/requirements.txt index cecbc3338..4ab5eaf93 100644 --- a/examples/extensions/requirements.txt +++ b/examples/extensions/requirements.txt @@ -1,4 +1,4 @@ setuptools>=42 cmake>=3.24 -mlx>=0.9.0 -nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4 +mlx>=0.16.2 +nanobind==2.0 diff --git a/python/src/array.cpp b/python/src/array.cpp index cd42dc26e..22b8f69c1 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -126,7 +126,7 @@ void init_array(nb::module_& m) { m.attr("float32") = nb::cast(float32); m.attr("bfloat16") = nb::cast(bfloat16); m.attr("complex64") = nb::cast(complex64); - nb::class_( + nb::enum_( m, "DtypeCategory", R"pbdoc( @@ -165,16 +165,16 @@ void init_array(nb::module_& m) { * :ref:`complex64 ` See also :func:`~mlx.core.issubdtype`. - )pbdoc"); - m.attr("complexfloating") = nb::cast(complexfloating); - m.attr("floating") = nb::cast(floating); - m.attr("inexact") = nb::cast(inexact); - m.attr("signedinteger") = nb::cast(signedinteger); - m.attr("unsignedinteger") = nb::cast(unsignedinteger); - m.attr("integer") = nb::cast(integer); - m.attr("number") = nb::cast(number); - m.attr("generic") = nb::cast(generic); - + )pbdoc") + .value("complexfloating", complexfloating) + .value("floating", floating) + .value("inexact", inexact) + .value("signedinteger", signedinteger) + .value("unsignedinteger", unsignedinteger) + .value("integer", integer) + .value("number", number) + .value("generic", generic) + .export_values(); nb::class_( m, "_ArrayAt", diff --git a/setup.py b/setup.py index 2e699e171..7bd1be4f7 100644 --- a/setup.py +++ b/setup.py @@ -176,7 +176,7 @@ if __name__ == "__main__": include_package_data=True, extras_require={ "dev": [ - "nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4", + "nanobind==2.0", "numpy", "pre-commit", "setuptools>=42",