mlx/python/src/utils.cpp
2024-02-14 14:14:58 -08:00

82 lines
2.1 KiB
C++

#include "mlx/utils.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <optional>
#include <utility>
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
// Slightly different from the original, with python context on init we are not
// in the context yet. Only create the inner context on enter then delete on
// exit.
class PyStreamContext {
public:
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
_s = s;
}
void enter() {
_inner = new StreamContext(_s);
}
void exit() {
if (_inner != nullptr) {
delete _inner;
_inner = nullptr;
}
}
private:
StreamOrDevice _s;
StreamContext* _inner;
};
// Registers the `StreamContext` class and the `stream` factory function on
// the Python module `m`.
void init_utils(py::module_& m) {
  py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
        A context manager for setting the current device and stream.

        See :func:`stream` for usage.

        Args:
            s: The stream or device to set as the default.
      )pbdoc")
      .def(py::init<StreamOrDevice>(), "s"_a)
      .def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
      .def(
          "__exit__",
          // Returns None (falsy), so Python does not suppress any exception
          // raised inside the `with` block; the exception arguments are
          // accepted only to satisfy the context-manager protocol.
          [](PyStreamContext& scm,
             const std::optional<py::type>& exc_type,
             const std::optional<py::object>& exc_value,
             const std::optional<py::object>& traceback) { scm.exit(); });
  m.def(
      "stream",
      [](StreamOrDevice s) { return PyStreamContext(s); },
      "s"_a,
      // Note: RST requires whitespace after `::` in a directive; without it
      // `.. code-block::python` is parsed as a comment and the example is
      // hidden from the rendered docs.
      R"pbdoc(
        Create a context manager to set the default device and stream.

        Args:
            s: The :obj:`Stream` or :obj:`Device` to set as the default.

        Returns:
            A context manager that sets the default device and stream.

        Example:

        .. code-block:: python

          import mlx.core as mx

          # Create a context manager for the default device and stream.
          with mx.stream(mx.cpu):
              # Operations here will use mx.cpu by default.
              pass
      )pbdoc");
}