Switch to nanobind (#839)

* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
This commit is contained in:
Awni Hannun
2024-03-18 20:12:25 -07:00
committed by GitHub
parent d39ed54f8e
commit 9a8ee00246
34 changed files with 2343 additions and 2344 deletions

View File

@@ -1,32 +1,29 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <variant>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/linalg.h"
#include "python/src/load.h"
#include "python/src/utils.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
using namespace mlx::core::linalg;
namespace {
py::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
const auto result = svd(a, s);
return py::make_tuple(result.at(0), result.at(1), result.at(2));
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
}
} // namespace
void init_linalg(py::module_& parent_module) {
py::options options;
options.disable_function_signatures();
void init_linalg(nb::module_& parent_module) {
auto m = parent_module.def_submodule(
"linalg", "mlx.core.linalg: linear algebra routines.");
@@ -59,16 +56,15 @@ void init_linalg(py::module_& parent_module) {
return norm(a, ord, axis, keepdims, stream);
}
},
"a"_a,
py::pos_only(),
"ord"_a = none,
"axis"_a = none,
nb::arg(),
"ord"_a = nb::none(),
"axis"_a = nb::none(),
"keepdims"_a = false,
py::kw_only(),
"stream"_a = none,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Matrix or vector norm.
This function computes vector or matrix norms depending on the value of
@@ -188,11 +184,11 @@ void init_linalg(py::module_& parent_module) {
"qr",
&qr,
"a"_a,
py::kw_only(),
"stream"_a = none,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
R"pbdoc(
qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)
The QR factorization of the input matrix.
This function supports arrays with at least 2 dimensions. The matrices
@@ -221,11 +217,11 @@ void init_linalg(py::module_& parent_module) {
"svd",
&svd_helper,
"a"_a,
py::kw_only(),
"stream"_a = none,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
R"pbdoc(
svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)
The Singular Value Decomposition (SVD) of the input matrix.
This function supports arrays with at least 2 dimensions. When the input
@@ -245,11 +241,11 @@ void init_linalg(py::module_& parent_module) {
"inv",
&inv,
"a"_a,
py::kw_only(),
"stream"_a = none,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array
Compute the inverse of a square matrix.
This function supports arrays with at least 2 dimensions. When the input