Switch to nanobind (#839)

* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
This commit is contained in:
Awni Hannun
2024-03-18 20:12:25 -07:00
committed by GitHub
parent d39ed54f8e
commit 9a8ee00246
34 changed files with 2343 additions and 2344 deletions

View File

@@ -1,23 +1,22 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <numeric>
#include <optional>
#include <variant>
#include <pybind11/complex.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/complex.h>
#include <nanobind/stl/variant.h>
#include "mlx/array.h"
namespace py = pybind11;
namespace nb = nanobind;
using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = std::
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
static constexpr std::monostate none{};
variant<nb::bool_, nb::int_, nb::float_, std::complex<float>, nb::object>;
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
@@ -32,31 +31,36 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
return axes;
}
inline array to_array_with_accessor(py::object obj) {
if (py::hasattr(obj, "__mlx_array__")) {
return obj.attr("__mlx_array__")().cast<array>();
inline array to_array_with_accessor(nb::object obj) {
if (nb::hasattr(obj, "__mlx_array__")) {
return nb::cast<array>(obj.attr("__mlx_array__")());
} else if (nb::isinstance<array>(obj)) {
return nb::cast<array>(obj);
} else {
return obj.cast<array>();
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}
}
inline array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) {
if (auto pv = std::get_if<py::bool_>(&v); pv) {
return array(py::cast<bool>(*pv), dtype.value_or(bool_));
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
auto out_t = dtype.value_or(int32);
// bool_ is an exception and is always promoted
return array(py::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
auto out_t = dtype.value_or(float32);
return array(
py::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
nb::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64);
} else {
return to_array_with_accessor(std::get<py::object>(v));
return to_array_with_accessor(std::get<nb::object>(v));
}
}
@@ -68,14 +72,14 @@ inline std::pair<array, array> to_arrays(
// - If a is an array but b is not, treat b as a weak python type
// - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone
if (auto pa = std::get_if<py::object>(&a); pa) {
if (auto pa = std::get_if<nb::object>(&a); pa) {
auto arr_a = to_array_with_accessor(*pa);
if (auto pb = std::get_if<py::object>(&b); pb) {
if (auto pb = std::get_if<nb::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {arr_a, arr_b};
}
return {arr_a, to_array(b, arr_a.dtype())};
} else if (auto pb = std::get_if<py::object>(&b); pb) {
} else if (auto pb = std::get_if<nb::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {to_array(a, arr_b.dtype()), arr_b};
} else {