mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Switch to nanobind (#839)
* mostly builds * most tests pass * fix circle build * add back buffer protocol * includes * fix for py38 * limit to cpu device * include * fix stubs * move signatures for docs * stubgen + docs fix * doc for compiled function, comments
This commit is contained in:
@@ -1,32 +1,29 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <variant>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "mlx/linalg.h"
|
||||
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
|
||||
namespace {
|
||||
py::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||
nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||
const auto result = svd(a, s);
|
||||
return py::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void init_linalg(py::module_& parent_module) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
void init_linalg(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"linalg", "mlx.core.linalg: linear algebra routines.");
|
||||
|
||||
@@ -59,16 +56,15 @@ void init_linalg(py::module_& parent_module) {
|
||||
return norm(a, ord, axis, keepdims, stream);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
"ord"_a = none,
|
||||
"axis"_a = none,
|
||||
nb::arg(),
|
||||
"ord"_a = nb::none(),
|
||||
"axis"_a = nb::none(),
|
||||
"keepdims"_a = false,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Matrix or vector norm.
|
||||
|
||||
This function computes vector or matrix norms depending on the value of
|
||||
@@ -188,11 +184,11 @@ void init_linalg(py::module_& parent_module) {
|
||||
"qr",
|
||||
&qr,
|
||||
"a"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
|
||||
R"pbdoc(
|
||||
qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)
|
||||
|
||||
The QR factorization of the input matrix.
|
||||
|
||||
This function supports arrays with at least 2 dimensions. The matrices
|
||||
@@ -221,11 +217,11 @@ void init_linalg(py::module_& parent_module) {
|
||||
"svd",
|
||||
&svd_helper,
|
||||
"a"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
|
||||
R"pbdoc(
|
||||
svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)
|
||||
|
||||
The Singular Value Decomposition (SVD) of the input matrix.
|
||||
|
||||
This function supports arrays with at least 2 dimensions. When the input
|
||||
@@ -245,11 +241,11 @@ void init_linalg(py::module_& parent_module) {
|
||||
"inv",
|
||||
&inv,
|
||||
"a"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Compute the inverse of a square matrix.
|
||||
|
||||
This function supports arrays with at least 2 dimensions. When the input
|
||||
|
||||
Reference in New Issue
Block a user