mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Adds device context manager (#679)
This commit is contained in:
		@@ -14,6 +14,7 @@ pybind11_add_module(
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,8 @@ using namespace py::literals;
 | 
			
		||||
using namespace mlx::core;
 | 
			
		||||
 | 
			
		||||
void init_device(py::module_& m) {
 | 
			
		||||
  auto device_class = py::class_<Device>(m, "Device");
 | 
			
		||||
  auto device_class = py::class_<Device>(
 | 
			
		||||
      m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
 | 
			
		||||
  py::enum_<Device::DeviceType>(m, "DeviceType")
 | 
			
		||||
      .value("cpu", Device::DeviceType::cpu)
 | 
			
		||||
      .value("gpu", Device::DeviceType::gpu)
 | 
			
		||||
@@ -39,6 +40,13 @@ void init_device(py::module_& m) {
 | 
			
		||||
 | 
			
		||||
  py::implicitly_convertible<Device::DeviceType, Device>();
 | 
			
		||||
 | 
			
		||||
  m.def("default_device", &default_device);
 | 
			
		||||
  m.def("set_default_device", &set_default_device, "device"_a);
 | 
			
		||||
  m.def(
 | 
			
		||||
      "default_device",
 | 
			
		||||
      &default_device,
 | 
			
		||||
      R"pbdoc(Get the default device.)pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "set_default_device",
 | 
			
		||||
      &set_default_device,
 | 
			
		||||
      "device"_a,
 | 
			
		||||
      R"pbdoc(Set the default device.)pbdoc");
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -18,6 +18,7 @@ void init_fft(py::module_&);
 | 
			
		||||
void init_linalg(py::module_&);
 | 
			
		||||
void init_constants(py::module_&);
 | 
			
		||||
void init_extensions(py::module_&);
 | 
			
		||||
void init_utils(py::module_&);
 | 
			
		||||
 | 
			
		||||
PYBIND11_MODULE(core, m) {
 | 
			
		||||
  m.doc() = "mlx: A framework for machine learning on Apple silicon.";
 | 
			
		||||
@@ -35,5 +36,7 @@ PYBIND11_MODULE(core, m) {
 | 
			
		||||
  init_linalg(m);
 | 
			
		||||
  init_constants(m);
 | 
			
		||||
  init_extensions(m);
 | 
			
		||||
  init_utils(m);
 | 
			
		||||
 | 
			
		||||
  m.attr("__version__") = TOSTRING(_VERSION_);
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,12 @@ using namespace py::literals;
 | 
			
		||||
using namespace mlx::core;
 | 
			
		||||
 | 
			
		||||
void init_stream(py::module_& m) {
 | 
			
		||||
  py::class_<Stream>(m, "Stream")
 | 
			
		||||
  py::class_<Stream>(
 | 
			
		||||
      m,
 | 
			
		||||
      "Stream",
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
      A stream for running operations on a given device.
 | 
			
		||||
      )pbdoc")
 | 
			
		||||
      .def(py::init<int, Device>(), "index"_a, "device"_a)
 | 
			
		||||
      .def_readonly("device", &Stream::device)
 | 
			
		||||
      .def(
 | 
			
		||||
@@ -28,7 +33,27 @@ void init_stream(py::module_& m) {
 | 
			
		||||
 | 
			
		||||
  py::implicitly_convertible<Device::DeviceType, Device>();
 | 
			
		||||
 | 
			
		||||
  m.def("default_stream", &default_stream, "device"_a);
 | 
			
		||||
  m.def("set_default_stream", &set_default_stream, "stream"_a);
 | 
			
		||||
  m.def("new_stream", &new_stream, "device"_a);
 | 
			
		||||
  m.def(
 | 
			
		||||
      "default_stream",
 | 
			
		||||
      &default_stream,
 | 
			
		||||
      "device"_a,
 | 
			
		||||
      R"pbdoc(Get the device's default stream.)pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "set_default_stream",
 | 
			
		||||
      &set_default_stream,
 | 
			
		||||
      "stream"_a,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Set the default stream.
 | 
			
		||||
 | 
			
		||||
        This will make the given stream the default for the
 | 
			
		||||
        streams device. It will not change the default device.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
          stream (stream): Stream to make the default.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "new_stream",
 | 
			
		||||
      &new_stream,
 | 
			
		||||
      "device"_a,
 | 
			
		||||
      R"pbdoc(Make a new stream on the given device.)pbdoc");
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										81
									
								
								python/src/utils.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								python/src/utils.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,81 @@
 | 
			
		||||
 | 
			
		||||
#include "mlx/utils.h"
 | 
			
		||||
#include <pybind11/pybind11.h>
 | 
			
		||||
#include <pybind11/stl.h>
 | 
			
		||||
#include <optional>
 | 
			
		||||
 | 
			
		||||
namespace py = pybind11;
 | 
			
		||||
using namespace py::literals;
 | 
			
		||||
using namespace mlx::core;
 | 
			
		||||
 | 
			
		||||
// Slightly different from the original, with python context on init we are not
 | 
			
		||||
// in the context yet. Only create the inner context on enter then delete on
 | 
			
		||||
// exit.
 | 
			
		||||
class PyStreamContext {
 | 
			
		||||
 public:
 | 
			
		||||
  PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
 | 
			
		||||
    if (std::holds_alternative<std::monostate>(s)) {
 | 
			
		||||
      throw std::runtime_error(
 | 
			
		||||
          "[StreamContext] Invalid argument, please specify a stream or device.");
 | 
			
		||||
    }
 | 
			
		||||
    _s = s;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enter() {
 | 
			
		||||
    _inner = new StreamContext(_s);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void exit() {
 | 
			
		||||
    if (_inner != nullptr) {
 | 
			
		||||
      delete _inner;
 | 
			
		||||
      _inner = nullptr;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  StreamOrDevice _s;
 | 
			
		||||
  StreamContext* _inner;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
void init_utils(py::module_& m) {
 | 
			
		||||
  py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
 | 
			
		||||
        A context manager for setting the current device and stream.
 | 
			
		||||
 | 
			
		||||
        See :func:`stream` for usage.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            s: The stream or device to set as the default.
 | 
			
		||||
  )pbdoc")
 | 
			
		||||
      .def(py::init<StreamOrDevice>(), "s"_a)
 | 
			
		||||
      .def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
 | 
			
		||||
      .def(
 | 
			
		||||
          "__exit__",
 | 
			
		||||
          [](PyStreamContext& scm,
 | 
			
		||||
             const std::optional<py::type>& exc_type,
 | 
			
		||||
             const std::optional<py::object>& exc_value,
 | 
			
		||||
             const std::optional<py::object>& traceback) { scm.exit(); });
 | 
			
		||||
  m.def(
 | 
			
		||||
      "stream",
 | 
			
		||||
      [](StreamOrDevice s) { return PyStreamContext(s); },
 | 
			
		||||
      "s"_a,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Create a context manager to set the default device and stream.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            s: The :obj:`Stream` or :obj:`Device` to set as the default.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A context manager that sets the default device and stream.
 | 
			
		||||
 | 
			
		||||
        Example:
 | 
			
		||||
 | 
			
		||||
        .. code-block::python
 | 
			
		||||
 | 
			
		||||
          import mlx.core as mx
 | 
			
		||||
 | 
			
		||||
          # Create a context manager for the default device and stream.
 | 
			
		||||
          with mx.stream(mx.cpu):
 | 
			
		||||
              # Operations here will use mx.cpu by default.
 | 
			
		||||
              pass
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user