// Mirror of https://github.com/ml-explore/mlx.git
// (synced 2025-06-25 01:41:17 +08:00) - 82 lines, 2.1 KiB, C++.
|
|
#include "mlx/utils.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <memory>
#include <optional>
#include <utility>
|
|
|
|
namespace py = pybind11;
|
|
using namespace py::literals;
|
|
using namespace mlx::core;
|
|
|
|
// Slightly different from the original, with python context on init we are not
|
|
// in the context yet. Only create the inner context on enter then delete on
|
|
// exit.
|
|
class PyStreamContext {
|
|
public:
|
|
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
|
|
if (std::holds_alternative<std::monostate>(s)) {
|
|
throw std::runtime_error(
|
|
"[StreamContext] Invalid argument, please specify a stream or device.");
|
|
}
|
|
_s = s;
|
|
}
|
|
|
|
void enter() {
|
|
_inner = new StreamContext(_s);
|
|
}
|
|
|
|
void exit() {
|
|
if (_inner != nullptr) {
|
|
delete _inner;
|
|
_inner = nullptr;
|
|
}
|
|
}
|
|
|
|
private:
|
|
StreamOrDevice _s;
|
|
StreamContext* _inner;
|
|
};
|
|
|
|
// Registers the Python `StreamContext` class and the `stream()` factory
// function on module `m`.
void init_utils(py::module_& m) {
  py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
        A context manager for setting the current device and stream.

        See :func:`stream` for usage.

        Args:
            s: The stream or device to set as the default.
      )pbdoc")
      .def(py::init<StreamOrDevice>(), "s"_a)
      .def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
      // Python passes the pending exception (if any) to __exit__. It is
      // ignored here, and the void return (None in Python) means any
      // exception is not suppressed and keeps propagating.
      .def(
          "__exit__",
          [](PyStreamContext& scm,
             const std::optional<py::type>& /* exc_type */,
             const std::optional<py::object>& /* exc_value */,
             const std::optional<py::object>& /* traceback */) {
            scm.exit();
          });
  m.def(
      "stream",
      [](StreamOrDevice s) { return PyStreamContext(std::move(s)); },
      "s"_a,
      R"pbdoc(
        Create a context manager to set the default device and stream.

        Args:
            s: The :obj:`Stream` or :obj:`Device` to set as the default.

        Returns:
            A context manager that sets the default device and stream.

        Example:

        .. code-block:: python

          import mlx.core as mx

          # Create a context manager for the default device and stream.
          with mx.stream(mx.cpu):
              # Operations here will use mx.cpu by default.
              pass
      )pbdoc");
}
|