Switch to nanobind (#839)

* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
This commit is contained in:
Awni Hannun
2024-03-18 20:12:25 -07:00
committed by GitHub
parent d39ed54f8e
commit 9a8ee00246
34 changed files with 2343 additions and 2344 deletions

View File

@@ -1,81 +0,0 @@
#include "mlx/utils.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <optional>
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
// Slightly different from the original, with python context on init we are not
// in the context yet. Only create the inner context on enter then delete on
// exit.
class PyStreamContext {
public:
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
_s = s;
}
void enter() {
_inner = new StreamContext(_s);
}
void exit() {
if (_inner != nullptr) {
delete _inner;
_inner = nullptr;
}
}
private:
StreamOrDevice _s;
StreamContext* _inner;
};
void init_utils(py::module_& m) {
py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
A context manager for setting the current device and stream.
See :func:`stream` for usage.
Args:
s: The stream or device to set as the default.
)pbdoc")
.def(py::init<StreamOrDevice>(), "s"_a)
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
.def(
"__exit__",
[](PyStreamContext& scm,
const std::optional<py::type>& exc_type,
const std::optional<py::object>& exc_value,
const std::optional<py::object>& traceback) { scm.exit(); });
m.def(
"stream",
[](StreamOrDevice s) { return PyStreamContext(s); },
"s"_a,
R"pbdoc(
Create a context manager to set the default device and stream.
Args:
s: The :obj:`Stream` or :obj:`Device` to set as the default.
Returns:
A context manager that sets the default device and stream.
Example:
.. code-block::python
import mlx.core as mx
# Create a context manager for the default device and stream.
with mx.stream(mx.cpu):
# Operations here will use mx.cpu by default.
pass
)pbdoc");
}