mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
Switch to nanobind (#839)
* mostly builds * most tests pass * fix circle build * add back buffer protocol * includes * fix for py38 * limit to cpu device * include * fix stubs * move signatures for docs * stubgen + docs fix * doc for compiled function, comments
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <chrono>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
@@ -9,8 +12,8 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::random;
|
||||
|
||||
@@ -25,22 +28,22 @@ class PyKeySequence {
|
||||
}
|
||||
|
||||
array next() {
|
||||
auto out = split(py::cast<array>(state_[0]));
|
||||
auto out = split(nb::cast<array>(state_[0]));
|
||||
state_[0] = out.first;
|
||||
return out.second;
|
||||
}
|
||||
|
||||
py::list state() {
|
||||
nb::list state() {
|
||||
return state_;
|
||||
}
|
||||
|
||||
void release() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
state_.release().dec_ref();
|
||||
}
|
||||
|
||||
private:
|
||||
py::list state_;
|
||||
nb::list state_;
|
||||
};
|
||||
|
||||
PyKeySequence& default_key() {
|
||||
@@ -54,7 +57,7 @@ PyKeySequence& default_key() {
|
||||
return ks;
|
||||
}
|
||||
|
||||
void init_random(py::module_& parent_module) {
|
||||
void init_random(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"random",
|
||||
"mlx.core.random: functionality related to random number generation");
|
||||
@@ -85,10 +88,10 @@ void init_random(py::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"split",
|
||||
py::overload_cast<const array&, int, StreamOrDevice>(&random::split),
|
||||
nb::overload_cast<const array&, int, StreamOrDevice>(&random::split),
|
||||
"key"_a,
|
||||
"num"_a = 2,
|
||||
"stream"_a = none,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Split a PRNG key into sub keys.
|
||||
|
||||
@@ -119,9 +122,9 @@ void init_random(py::module_& parent_module) {
|
||||
"low"_a = 0,
|
||||
"high"_a = 1,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = std::optional{float32},
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"dtype"_a.none() = float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Generate uniformly distributed random numbers.
|
||||
|
||||
@@ -151,11 +154,11 @@ void init_random(py::module_& parent_module) {
|
||||
return normal(shape, type.value_or(float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = std::optional{float32},
|
||||
"dtype"_a.none() = float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Generate normally distributed random numbers.
|
||||
|
||||
@@ -184,9 +187,9 @@ void init_random(py::module_& parent_module) {
|
||||
"low"_a,
|
||||
"high"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = int32,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"dtype"_a.none() = int32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Generate random integers from the given interval.
|
||||
|
||||
@@ -219,9 +222,9 @@ void init_random(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"p"_a = 0.5,
|
||||
"shape"_a = none,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"shape"_a = nb::none(),
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Generate Bernoulli random values.
|
||||
|
||||
@@ -259,10 +262,10 @@ void init_random(py::module_& parent_module) {
|
||||
},
|
||||
"lower"_a,
|
||||
"upper"_a,
|
||||
"shape"_a = none,
|
||||
"dtype"_a = std::optional{float32},
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"shape"_a = nb::none(),
|
||||
"dtype"_a.none() = float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Generate values from a truncated normal distribution.
|
||||
|
||||
@@ -292,9 +295,9 @@ void init_random(py::module_& parent_module) {
|
||||
return gumbel(shape, type.value_or(float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = std::optional{float32},
|
||||
"stream"_a = none,
|
||||
"key"_a = none,
|
||||
"dtype"_a.none() = float32,
|
||||
"stream"_a = nb::none(),
|
||||
"key"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Sample from the standard Gumbel distribution.
|
||||
|
||||
@@ -331,10 +334,10 @@ void init_random(py::module_& parent_module) {
|
||||
},
|
||||
"logits"_a,
|
||||
"axis"_a = -1,
|
||||
"shape"_a = none,
|
||||
"num_samples"_a = none,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"shape"_a = nb::none(),
|
||||
"num_samples"_a = nb::none(),
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Sample from a categorical distribution.
|
||||
|
||||
@@ -359,6 +362,6 @@ void init_random(py::module_& parent_module) {
|
||||
array: The ``shape``-sized output array with type ``uint32``.
|
||||
)pbdoc");
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
|
||||
auto atexit = nb::module_::import_("atexit");
|
||||
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user