// Mirror of https://github.com/ml-explore/mlx.git
// (synced 2025-06-25 09:51:17 +08:00)
#include <memory>
#include <optional>
#include <utility>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "mlx/utils.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

// Slightly different from the original, with python context on init we are not
|
||
|
// in the context yet. Only create the inner context on enter then delete on
|
||
|
// exit.
|
||
|
class PyStreamContext {
|
||
|
public:
|
||
|
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
|
||
|
if (std::holds_alternative<std::monostate>(s)) {
|
||
|
throw std::runtime_error(
|
||
|
"[StreamContext] Invalid argument, please specify a stream or device.");
|
||
|
}
|
||
|
_s = s;
|
||
|
}
|
||
|
|
||
|
void enter() {
|
||
|
_inner = new StreamContext(_s);
|
||
|
}
|
||
|
|
||
|
void exit() {
|
||
|
if (_inner != nullptr) {
|
||
|
delete _inner;
|
||
|
_inner = nullptr;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
StreamOrDevice _s;
|
||
|
StreamContext* _inner;
|
||
|
};
|
||
|
|
||
|
// Registers the stream/device context-manager utilities on module `m`:
// the `StreamContext` class (wrapping PyStreamContext) and the `stream()`
// factory function that returns one.
void init_utils(py::module_& m) {
  py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
        A context manager for setting the current device and stream.

        See :func:`stream` for usage.

        Args:
            s: The stream or device to set as the default.
      )pbdoc")
      .def(py::init<StreamOrDevice>(), "s"_a)
      .def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
      .def(
          "__exit__",
          // Returns None (falsy), so any exception raised inside the `with`
          // body is not suppressed and propagates after exit() runs.
          [](PyStreamContext& scm,
             const std::optional<py::type>& /* exc_type */,
             const std::optional<py::object>& /* exc_value */,
             const std::optional<py::object>& /* traceback */) {
            scm.exit();
          });
  m.def(
      "stream",
      [](StreamOrDevice s) { return PyStreamContext(s); },
      "s"_a,
      R"pbdoc(
        Create a context manager to set the default device and stream.

        Args:
            s: The :obj:`Stream` or :obj:`Device` to set as the default.

        Returns:
            A context manager that sets the default device and stream.

        Example:

        .. code-block:: python

          import mlx.core as mx

          # Create a context manager for the default device and stream.
          with mx.stream(mx.cpu):
              # Operations here will use mx.cpu by default.
              pass
      )pbdoc");
}