mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
Adds device context manager (#679)
This commit is contained in:
@@ -14,6 +14,7 @@ pybind11_add_module(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
)
|
||||
|
||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
||||
|
||||
@@ -12,7 +12,8 @@ using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_device(py::module_& m) {
|
||||
auto device_class = py::class_<Device>(m, "Device");
|
||||
auto device_class = py::class_<Device>(
|
||||
m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
|
||||
py::enum_<Device::DeviceType>(m, "DeviceType")
|
||||
.value("cpu", Device::DeviceType::cpu)
|
||||
.value("gpu", Device::DeviceType::gpu)
|
||||
@@ -39,6 +40,13 @@ void init_device(py::module_& m) {
|
||||
|
||||
py::implicitly_convertible<Device::DeviceType, Device>();
|
||||
|
||||
m.def("default_device", &default_device);
|
||||
m.def("set_default_device", &set_default_device, "device"_a);
|
||||
m.def(
|
||||
"default_device",
|
||||
&default_device,
|
||||
R"pbdoc(Get the default device.)pbdoc");
|
||||
m.def(
|
||||
"set_default_device",
|
||||
&set_default_device,
|
||||
"device"_a,
|
||||
R"pbdoc(Set the default device.)pbdoc");
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ void init_fft(py::module_&);
|
||||
void init_linalg(py::module_&);
|
||||
void init_constants(py::module_&);
|
||||
void init_extensions(py::module_&);
|
||||
void init_utils(py::module_&);
|
||||
|
||||
PYBIND11_MODULE(core, m) {
|
||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||
@@ -35,5 +36,7 @@ PYBIND11_MODULE(core, m) {
|
||||
init_linalg(m);
|
||||
init_constants(m);
|
||||
init_extensions(m);
|
||||
init_utils(m);
|
||||
|
||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||
}
|
||||
|
||||
@@ -12,7 +12,12 @@ using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_stream(py::module_& m) {
|
||||
py::class_<Stream>(m, "Stream")
|
||||
py::class_<Stream>(
|
||||
m,
|
||||
"Stream",
|
||||
R"pbdoc(
|
||||
A stream for running operations on a given device.
|
||||
)pbdoc")
|
||||
.def(py::init<int, Device>(), "index"_a, "device"_a)
|
||||
.def_readonly("device", &Stream::device)
|
||||
.def(
|
||||
@@ -28,7 +33,27 @@ void init_stream(py::module_& m) {
|
||||
|
||||
py::implicitly_convertible<Device::DeviceType, Device>();
|
||||
|
||||
m.def("default_stream", &default_stream, "device"_a);
|
||||
m.def("set_default_stream", &set_default_stream, "stream"_a);
|
||||
m.def("new_stream", &new_stream, "device"_a);
|
||||
m.def(
|
||||
"default_stream",
|
||||
&default_stream,
|
||||
"device"_a,
|
||||
R"pbdoc(Get the device's default stream.)pbdoc");
|
||||
m.def(
|
||||
"set_default_stream",
|
||||
&set_default_stream,
|
||||
"stream"_a,
|
||||
R"pbdoc(
|
||||
Set the default stream.
|
||||
|
||||
This will make the given stream the default for the
|
||||
streams device. It will not change the default device.
|
||||
|
||||
Args:
|
||||
stream (stream): Stream to make the default.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"new_stream",
|
||||
&new_stream,
|
||||
"device"_a,
|
||||
R"pbdoc(Make a new stream on the given device.)pbdoc");
|
||||
}
|
||||
|
||||
81
python/src/utils.cpp
Normal file
81
python/src/utils.cpp
Normal file
@@ -0,0 +1,81 @@
|
||||
|
||||
#include "mlx/utils.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <optional>
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
// Slightly different from the original, with python context on init we are not
|
||||
// in the context yet. Only create the inner context on enter then delete on
|
||||
// exit.
|
||||
class PyStreamContext {
|
||||
public:
|
||||
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
|
||||
if (std::holds_alternative<std::monostate>(s)) {
|
||||
throw std::runtime_error(
|
||||
"[StreamContext] Invalid argument, please specify a stream or device.");
|
||||
}
|
||||
_s = s;
|
||||
}
|
||||
|
||||
void enter() {
|
||||
_inner = new StreamContext(_s);
|
||||
}
|
||||
|
||||
void exit() {
|
||||
if (_inner != nullptr) {
|
||||
delete _inner;
|
||||
_inner = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
StreamOrDevice _s;
|
||||
StreamContext* _inner;
|
||||
};
|
||||
|
||||
void init_utils(py::module_& m) {
|
||||
py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
|
||||
A context manager for setting the current device and stream.
|
||||
|
||||
See :func:`stream` for usage.
|
||||
|
||||
Args:
|
||||
s: The stream or device to set as the default.
|
||||
)pbdoc")
|
||||
.def(py::init<StreamOrDevice>(), "s"_a)
|
||||
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
|
||||
.def(
|
||||
"__exit__",
|
||||
[](PyStreamContext& scm,
|
||||
const std::optional<py::type>& exc_type,
|
||||
const std::optional<py::object>& exc_value,
|
||||
const std::optional<py::object>& traceback) { scm.exit(); });
|
||||
m.def(
|
||||
"stream",
|
||||
[](StreamOrDevice s) { return PyStreamContext(s); },
|
||||
"s"_a,
|
||||
R"pbdoc(
|
||||
Create a context manager to set the default device and stream.
|
||||
|
||||
Args:
|
||||
s: The :obj:`Stream` or :obj:`Device` to set as the default.
|
||||
|
||||
Returns:
|
||||
A context manager that sets the default device and stream.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
# Create a context manager for the default device and stream.
|
||||
with mx.stream(mx.cpu):
|
||||
# Operations here will use mx.cpu by default.
|
||||
pass
|
||||
)pbdoc");
|
||||
}
|
||||
Reference in New Issue
Block a user