#include "mlx/utils.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <memory>
#include <optional>
#include <stdexcept>
#include <utility>
#include <variant>

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

// Python-facing wrapper around StreamContext.
//
// Slightly different from the original: when the Python object is constructed
// we are not in the context yet. The inner StreamContext is only created in
// enter() and destroyed again in exit().
//
// The inner context is held by a std::unique_ptr, so it is released even when
// the wrapper is destroyed without a matching exit(), and the wrapper is
// move-only, which rules out two wrappers deleting the same inner context.
class PyStreamContext {
 public:
  // Stores the stream or device to activate later.
  //
  // Throws std::runtime_error when `s` holds neither a stream nor a device
  // (i.e. the variant is in its std::monostate alternative).
  PyStreamContext(StreamOrDevice s) : _s(std::move(s)) {
    if (std::holds_alternative<std::monostate>(_s)) {
      throw std::runtime_error(
          "[StreamContext] Invalid argument, please specify a stream or device.");
    }
  }

  // Enters the context by creating the inner StreamContext.
  void enter() {
    // Tear down a context left over from an earlier enter() *before* building
    // the replacement, so it is neither leaked nor destroyed after the new one
    // has been set up.
    _inner.reset();
    _inner = std::make_unique<StreamContext>(_s);
  }

  // Leaves the context by destroying the inner StreamContext. Calling exit()
  // without a preceding enter() is a no-op.
  void exit() {
    _inner.reset();
  }

 private:
  StreamOrDevice _s;
  std::unique_ptr<StreamContext> _inner;
};

// Registers the `StreamContext` class and the `stream` factory function on the
// given Python module.
void init_utils(py::module_& m) {
  py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
      A context manager for setting the current device and stream.

      See :func:`stream` for usage.

      Args:
          s: The stream or device to set as the default.
  )pbdoc")
      .def(py::init<StreamOrDevice>(), "s"_a)
      .def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
      .def(
          "__exit__",
          // The exception triple is accepted (each entry may be None) but
          // ignored. The lambda returns void, i.e. None on the Python side, so
          // an in-flight exception is not suppressed.
          [](PyStreamContext& scm,
             const std::optional<py::type>& exc_type,
             const std::optional<py::object>& exc_value,
             const std::optional<py::object>& traceback) { scm.exit(); });

  m.def(
      "stream",
      [](StreamOrDevice s) { return PyStreamContext(s); },
      "s"_a,
      R"pbdoc(
        Create a context manager to set the default device and stream.

        Args:
            s: The :obj:`Stream` or :obj:`Device` to set as the default.

        Returns:
            A context manager that sets the default device and stream.

        Example:

        .. code-block:: python

          import mlx.core as mx

          # Create a context manager for the default device and stream.
          with mx.stream(mx.cpu):
              # Operations here will use mx.cpu by default.
              pass
      )pbdoc");
}