// Copyright © 2023-2024 Apple Inc. #include #include #include #include #include #include "mlx/stream.h" #include "mlx/utils.h" namespace nb = nanobind; using namespace nb::literals; using namespace mlx::core; // Create the StreamContext on enter and delete on exit. class PyStreamContext { public: PyStreamContext(StreamOrDevice s) : _inner(nullptr) { if (std::holds_alternative(s)) { throw std::runtime_error( "[StreamContext] Invalid argument, please specify a stream or device."); } _s = s; } void enter() { _inner = new StreamContext(_s); } void exit() { if (_inner != nullptr) { delete _inner; _inner = nullptr; } } private: StreamOrDevice _s; StreamContext* _inner; }; void init_stream(nb::module_& m) { nb::class_( m, "Stream", R"pbdoc( A stream for running operations on a given device. )pbdoc") .def_ro("device", &Stream::device) .def( "__repr__", [](const Stream& s) { std::ostringstream os; os << s; return os.str(); }) .def("__eq__", [](const Stream& s, const nb::object& other) { return nb::isinstance(other) && s == nb::cast(other); }); nb::implicitly_convertible(); m.def( "default_stream", &default_stream, "device"_a, R"pbdoc(Get the device's default stream.)pbdoc"); m.def( "set_default_stream", &set_default_stream, "stream"_a, R"pbdoc( Set the default stream. This will make the given stream the default for the streams device. It will not change the default device. Args: stream (stream): Stream to make the default. )pbdoc"); m.def( "new_stream", &new_stream, "device"_a, R"pbdoc(Make a new stream on the given device.)pbdoc"); nb::class_(m, "StreamContext", R"pbdoc( A context manager for setting the current device and stream. See :func:`stream` for usage. Args: s: The stream or device to set as the default. )pbdoc") .def(nb::init(), "s"_a) .def("__enter__", [](PyStreamContext& scm) { scm.enter(); }) .def( "__exit__", [](PyStreamContext& scm, const std::optional& exc_type, const std::optional& exc_value, const std::optional& traceback) { scm.exit(); }, "exc_type"_a = nb::none(), "exc_value"_a = nb::none(), "traceback"_a = nb::none()); m.def( "stream", [](StreamOrDevice s) { return PyStreamContext(s); }, "s"_a, R"pbdoc( Create a context manager to set the default device and stream. Args: s: The :obj:`Stream` or :obj:`Device` to set as the default. Returns: A context manager that sets the default device and stream. Example: .. code-block::python import mlx.core as mx # Create a context manager for the default device and stream. with mx.stream(mx.cpu): # Operations here will use mx.cpu by default. pass )pbdoc"); m.def( "synchronize", [](const std::optional& s) { s ? synchronize(s.value()) : synchronize(); }, "stream"_a = nb::none(), R"pbdoc( Synchronize with the given stream. Args: stream (Stream, optional): The stream to synchronize with. If ``None`` then the default stream of the default device is used. Default: ``None``. )pbdoc"); }