mlx/python/src/stream.cpp

146 lines
3.7 KiB
C++
Raw Normal View History

// Copyright © 2023-2024 Apple Inc.
2023-12-01 03:12:53 +08:00
2023-11-30 02:42:59 +08:00
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <utility>
#include <variant>

#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
2023-11-30 02:42:59 +08:00
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace nb = nanobind;
using namespace nb::literals;
2023-11-30 02:42:59 +08:00
using namespace mlx::core;
// Create the StreamContext on enter and delete on exit.
class PyStreamContext {
public:
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
_s = s;
}
void enter() {
_inner = new StreamContext(_s);
}
void exit() {
if (_inner != nullptr) {
delete _inner;
_inner = nullptr;
}
}
private:
StreamOrDevice _s;
StreamContext* _inner;
};
// Registers the stream-related Python bindings on module `m`:
//   * the `Stream` class,
//   * `default_stream`, `set_default_stream`, `new_stream`, `synchronize`,
//   * the `StreamContext` context-manager class and the `stream` factory.
void init_stream(nb::module_& m) {
  nb::class_<Stream>(
      m,
      "Stream",
      R"pbdoc(
      A stream for running operations on a given device.
      )pbdoc")
      .def(nb::init<int, Device>(), "index"_a, "device"_a)
      .def_ro("device", &Stream::device)
      .def(
          "__repr__",
          [](const Stream& s) {
            std::ostringstream os;
            os << s;
            return os.str();
          })
      // Accept any object so that comparing a Stream with a non-Stream
      // (e.g. `s == None`) yields False instead of a TypeError raised by
      // failed overload resolution.
      .def("__eq__", [](const Stream& s, const nb::object& other) {
        return nb::isinstance<Stream>(other) && s == nb::cast<Stream>(other);
      });

  // Allow a Device::DeviceType to be passed wherever a Device is expected.
  nb::implicitly_convertible<Device::DeviceType, Device>();

  m.def(
      "default_stream",
      &default_stream,
      "device"_a,
      R"pbdoc(Get the device's default stream.)pbdoc");
  m.def(
      "set_default_stream",
      &set_default_stream,
      "stream"_a,
      R"pbdoc(
        Set the default stream.

        This will make the given stream the default for the
        stream's device. It will not change the default device.

        Args:
          stream (Stream): Stream to make the default.
      )pbdoc");
  m.def(
      "new_stream",
      &new_stream,
      "device"_a,
      R"pbdoc(Make a new stream on the given device.)pbdoc");

  nb::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
        A context manager for setting the current device and stream.

        See :func:`stream` for usage.

        Args:
            s: The stream or device to set as the default.
  )pbdoc")
      .def(nb::init<StreamOrDevice>(), "s"_a)
      .def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
      .def(
          "__exit__",
          // The exception details are accepted to satisfy the context-manager
          // protocol but are not inspected: the context is always torn down.
          [](PyStreamContext& scm,
             const std::optional<nb::type_object>& /* exc_type */,
             const std::optional<nb::object>& /* exc_value */,
             const std::optional<nb::object>& /* traceback */) { scm.exit(); },
          "exc_type"_a = nb::none(),
          "exc_value"_a = nb::none(),
          "traceback"_a = nb::none());

  m.def(
      "stream",
      [](StreamOrDevice s) { return PyStreamContext(s); },
      "s"_a,
      R"pbdoc(
        Create a context manager to set the default device and stream.

        Args:
            s: The :obj:`Stream` or :obj:`Device` to set as the default.

        Returns:
            A context manager that sets the default device and stream.

        Example:

        .. code-block:: python

          import mlx.core as mx

          # Create a context manager for the default device and stream.
          with mx.stream(mx.cpu):
              # Operations here will use mx.cpu by default.
              pass
      )pbdoc");

  m.def(
      "synchronize",
      [](const std::optional<Stream>& s) {
        if (s) {
          synchronize(*s);
        } else {
          synchronize();
        }
      },
      "stream"_a = nb::none(),
      R"pbdoc(
      Synchronize with the given stream.

      Args:
        stream (Stream, optional): The stream to synchronize with. If ``None``
          then the default stream of the default device is used.
          Default: ``None``.
      )pbdoc");
}