Switch to nanobind (#839)

* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
This commit is contained in:
Awni Hannun
2024-03-18 20:12:25 -07:00
committed by GitHub
parent d39ed54f8e
commit 9a8ee00246
34 changed files with 2343 additions and 2344 deletions

View File

@@ -1,7 +1,10 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <chrono>
#include "python/src/utils.h"
@@ -9,8 +12,8 @@
#include "mlx/ops.h"
#include "mlx/random.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
using namespace mlx::core::random;
@@ -25,22 +28,22 @@ class PyKeySequence {
}
array next() {
auto out = split(py::cast<array>(state_[0]));
auto out = split(nb::cast<array>(state_[0]));
state_[0] = out.first;
return out.second;
}
py::list state() {
nb::list state() {
return state_;
}
void release() {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
state_.release().dec_ref();
}
private:
py::list state_;
nb::list state_;
};
PyKeySequence& default_key() {
@@ -54,7 +57,7 @@ PyKeySequence& default_key() {
return ks;
}
void init_random(py::module_& parent_module) {
void init_random(nb::module_& parent_module) {
auto m = parent_module.def_submodule(
"random",
"mlx.core.random: functionality related to random number generation");
@@ -85,10 +88,10 @@ void init_random(py::module_& parent_module) {
)pbdoc");
m.def(
"split",
py::overload_cast<const array&, int, StreamOrDevice>(&random::split),
nb::overload_cast<const array&, int, StreamOrDevice>(&random::split),
"key"_a,
"num"_a = 2,
"stream"_a = none,
"stream"_a = nb::none(),
R"pbdoc(
Split a PRNG key into sub keys.
@@ -119,9 +122,9 @@ void init_random(py::module_& parent_module) {
"low"_a = 0,
"high"_a = 1,
"shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32},
"key"_a = none,
"stream"_a = none,
"dtype"_a.none() = float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Generate uniformly distributed random numbers.
@@ -151,11 +154,11 @@ void init_random(py::module_& parent_module) {
return normal(shape, type.value_or(float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32},
"dtype"_a.none() = float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
"key"_a = none,
"stream"_a = none,
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Generate normally distributed random numbers.
@@ -184,9 +187,9 @@ void init_random(py::module_& parent_module) {
"low"_a,
"high"_a,
"shape"_a = std::vector<int>{},
"dtype"_a = int32,
"key"_a = none,
"stream"_a = none,
"dtype"_a.none() = int32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Generate random integers from the given interval.
@@ -219,9 +222,9 @@ void init_random(py::module_& parent_module) {
}
},
"p"_a = 0.5,
"shape"_a = none,
"key"_a = none,
"stream"_a = none,
"shape"_a = nb::none(),
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Generate Bernoulli random values.
@@ -259,10 +262,10 @@ void init_random(py::module_& parent_module) {
},
"lower"_a,
"upper"_a,
"shape"_a = none,
"dtype"_a = std::optional{float32},
"key"_a = none,
"stream"_a = none,
"shape"_a = nb::none(),
"dtype"_a.none() = float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Generate values from a truncated normal distribution.
@@ -292,9 +295,9 @@ void init_random(py::module_& parent_module) {
return gumbel(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32},
"stream"_a = none,
"key"_a = none,
"dtype"_a.none() = float32,
"stream"_a = nb::none(),
"key"_a = nb::none(),
R"pbdoc(
Sample from the standard Gumbel distribution.
@@ -331,10 +334,10 @@ void init_random(py::module_& parent_module) {
},
"logits"_a,
"axis"_a = -1,
"shape"_a = none,
"num_samples"_a = none,
"key"_a = none,
"stream"_a = none,
"shape"_a = nb::none(),
"num_samples"_a = nb::none(),
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Sample from a categorical distribution.
@@ -359,6 +362,6 @@ void init_random(py::module_& parent_module) {
array: The ``shape``-sized output array with type ``uint32``.
)pbdoc");
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));
}