mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Switch to nanobind (#839)
* mostly builds * most tests pass * fix circle build * add back buffer protocol * includes * fix for py38 * limit to cpu device * include * fix stubs * move signatures for docs * stubgen + docs fix * doc for compiled function, comments
This commit is contained in:
@@ -1,81 +0,0 @@
|
||||
|
||||
#include "mlx/utils.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <optional>
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
// Slightly different from the original, with python context on init we are not
|
||||
// in the context yet. Only create the inner context on enter then delete on
|
||||
// exit.
|
||||
class PyStreamContext {
|
||||
public:
|
||||
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
|
||||
if (std::holds_alternative<std::monostate>(s)) {
|
||||
throw std::runtime_error(
|
||||
"[StreamContext] Invalid argument, please specify a stream or device.");
|
||||
}
|
||||
_s = s;
|
||||
}
|
||||
|
||||
void enter() {
|
||||
_inner = new StreamContext(_s);
|
||||
}
|
||||
|
||||
void exit() {
|
||||
if (_inner != nullptr) {
|
||||
delete _inner;
|
||||
_inner = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
StreamOrDevice _s;
|
||||
StreamContext* _inner;
|
||||
};
|
||||
|
||||
void init_utils(py::module_& m) {
|
||||
py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
|
||||
A context manager for setting the current device and stream.
|
||||
|
||||
See :func:`stream` for usage.
|
||||
|
||||
Args:
|
||||
s: The stream or device to set as the default.
|
||||
)pbdoc")
|
||||
.def(py::init<StreamOrDevice>(), "s"_a)
|
||||
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
|
||||
.def(
|
||||
"__exit__",
|
||||
[](PyStreamContext& scm,
|
||||
const std::optional<py::type>& exc_type,
|
||||
const std::optional<py::object>& exc_value,
|
||||
const std::optional<py::object>& traceback) { scm.exit(); });
|
||||
m.def(
|
||||
"stream",
|
||||
[](StreamOrDevice s) { return PyStreamContext(s); },
|
||||
"s"_a,
|
||||
R"pbdoc(
|
||||
Create a context manager to set the default device and stream.
|
||||
|
||||
Args:
|
||||
s: The :obj:`Stream` or :obj:`Device` to set as the default.
|
||||
|
||||
Returns:
|
||||
A context manager that sets the default device and stream.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
# Create a context manager for the default device and stream.
|
||||
with mx.stream(mx.cpu):
|
||||
# Operations here will use mx.cpu by default.
|
||||
pass
|
||||
)pbdoc");
|
||||
}
|
||||
Reference in New Issue
Block a user