mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Adds device context manager (#679)
This commit is contained in:
		| @@ -1,5 +1,4 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| from collections import defaultdict | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -14,6 +14,7 @@ pybind11_add_module( | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp | ||||
| ) | ||||
|  | ||||
| if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) | ||||
|   | ||||
| @@ -12,7 +12,8 @@ using namespace py::literals; | ||||
| using namespace mlx::core; | ||||
|  | ||||
| void init_device(py::module_& m) { | ||||
|   auto device_class = py::class_<Device>(m, "Device"); | ||||
|   auto device_class = py::class_<Device>( | ||||
|       m, "Device", R"pbdoc(A device to run operations on.)pbdoc"); | ||||
|   py::enum_<Device::DeviceType>(m, "DeviceType") | ||||
|       .value("cpu", Device::DeviceType::cpu) | ||||
|       .value("gpu", Device::DeviceType::gpu) | ||||
| @@ -39,6 +40,13 @@ void init_device(py::module_& m) { | ||||
|  | ||||
|   py::implicitly_convertible<Device::DeviceType, Device>(); | ||||
|  | ||||
|   m.def("default_device", &default_device); | ||||
|   m.def("set_default_device", &set_default_device, "device"_a); | ||||
|   m.def( | ||||
|       "default_device", | ||||
|       &default_device, | ||||
|       R"pbdoc(Get the default device.)pbdoc"); | ||||
|   m.def( | ||||
|       "set_default_device", | ||||
|       &set_default_device, | ||||
|       "device"_a, | ||||
|       R"pbdoc(Set the default device.)pbdoc"); | ||||
| } | ||||
|   | ||||
| @@ -18,6 +18,7 @@ void init_fft(py::module_&); | ||||
| void init_linalg(py::module_&); | ||||
| void init_constants(py::module_&); | ||||
| void init_extensions(py::module_&); | ||||
| void init_utils(py::module_&); | ||||
|  | ||||
| PYBIND11_MODULE(core, m) { | ||||
|   m.doc() = "mlx: A framework for machine learning on Apple silicon."; | ||||
| @@ -35,5 +36,7 @@ PYBIND11_MODULE(core, m) { | ||||
|   init_linalg(m); | ||||
|   init_constants(m); | ||||
|   init_extensions(m); | ||||
|   init_utils(m); | ||||
|  | ||||
|   m.attr("__version__") = TOSTRING(_VERSION_); | ||||
| } | ||||
|   | ||||
| @@ -12,7 +12,12 @@ using namespace py::literals; | ||||
| using namespace mlx::core; | ||||
|  | ||||
| void init_stream(py::module_& m) { | ||||
|   py::class_<Stream>(m, "Stream") | ||||
|   py::class_<Stream>( | ||||
|       m, | ||||
|       "Stream", | ||||
|       R"pbdoc( | ||||
|       A stream for running operations on a given device. | ||||
|       )pbdoc") | ||||
|       .def(py::init<int, Device>(), "index"_a, "device"_a) | ||||
|       .def_readonly("device", &Stream::device) | ||||
|       .def( | ||||
| @@ -28,7 +33,27 @@ void init_stream(py::module_& m) { | ||||
|  | ||||
|   py::implicitly_convertible<Device::DeviceType, Device>(); | ||||
|  | ||||
|   m.def("default_stream", &default_stream, "device"_a); | ||||
|   m.def("set_default_stream", &set_default_stream, "stream"_a); | ||||
|   m.def("new_stream", &new_stream, "device"_a); | ||||
|   m.def( | ||||
|       "default_stream", | ||||
|       &default_stream, | ||||
|       "device"_a, | ||||
|       R"pbdoc(Get the device's default stream.)pbdoc"); | ||||
|   m.def( | ||||
|       "set_default_stream", | ||||
|       &set_default_stream, | ||||
|       "stream"_a, | ||||
|       R"pbdoc( | ||||
|         Set the default stream. | ||||
|  | ||||
|         This will make the given stream the default for the | ||||
|         streams device. It will not change the default device. | ||||
|  | ||||
|         Args: | ||||
|           stream (stream): Stream to make the default. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "new_stream", | ||||
|       &new_stream, | ||||
|       "device"_a, | ||||
|       R"pbdoc(Make a new stream on the given device.)pbdoc"); | ||||
| } | ||||
|   | ||||
							
								
								
									
										81
									
								
								python/src/utils.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								python/src/utils.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,81 @@ | ||||
|  | ||||
| #include "mlx/utils.h" | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <pybind11/stl.h> | ||||
| #include <optional> | ||||
|  | ||||
| namespace py = pybind11; | ||||
| using namespace py::literals; | ||||
| using namespace mlx::core; | ||||
|  | ||||
| // Slightly different from the original, with python context on init we are not | ||||
| // in the context yet. Only create the inner context on enter then delete on | ||||
| // exit. | ||||
| class PyStreamContext { | ||||
|  public: | ||||
|   PyStreamContext(StreamOrDevice s) : _inner(nullptr) { | ||||
|     if (std::holds_alternative<std::monostate>(s)) { | ||||
|       throw std::runtime_error( | ||||
|           "[StreamContext] Invalid argument, please specify a stream or device."); | ||||
|     } | ||||
|     _s = s; | ||||
|   } | ||||
|  | ||||
|   void enter() { | ||||
|     _inner = new StreamContext(_s); | ||||
|   } | ||||
|  | ||||
|   void exit() { | ||||
|     if (_inner != nullptr) { | ||||
|       delete _inner; | ||||
|       _inner = nullptr; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   StreamOrDevice _s; | ||||
|   StreamContext* _inner; | ||||
| }; | ||||
|  | ||||
| void init_utils(py::module_& m) { | ||||
|   py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc( | ||||
|         A context manager for setting the current device and stream. | ||||
|  | ||||
|         See :func:`stream` for usage. | ||||
|  | ||||
|         Args: | ||||
|             s: The stream or device to set as the default. | ||||
|   )pbdoc") | ||||
|       .def(py::init<StreamOrDevice>(), "s"_a) | ||||
|       .def("__enter__", [](PyStreamContext& scm) { scm.enter(); }) | ||||
|       .def( | ||||
|           "__exit__", | ||||
|           [](PyStreamContext& scm, | ||||
|              const std::optional<py::type>& exc_type, | ||||
|              const std::optional<py::object>& exc_value, | ||||
|              const std::optional<py::object>& traceback) { scm.exit(); }); | ||||
|   m.def( | ||||
|       "stream", | ||||
|       [](StreamOrDevice s) { return PyStreamContext(s); }, | ||||
|       "s"_a, | ||||
|       R"pbdoc( | ||||
|         Create a context manager to set the default device and stream. | ||||
|  | ||||
|         Args: | ||||
|             s: The :obj:`Stream` or :obj:`Device` to set as the default. | ||||
|  | ||||
|         Returns: | ||||
|             A context manager that sets the default device and stream. | ||||
|  | ||||
|         Example: | ||||
|  | ||||
|         .. code-block::python | ||||
|  | ||||
|           import mlx.core as mx | ||||
|  | ||||
|           # Create a context manager for the default device and stream. | ||||
|           with mx.stream(mx.cpu): | ||||
|               # Operations here will use mx.cpu by default. | ||||
|               pass | ||||
|       )pbdoc"); | ||||
| } | ||||
| @@ -38,6 +38,17 @@ class TestDevice(mlx_tests.MLXTestCase): | ||||
|         # Restore device | ||||
|         mx.set_default_device(device) | ||||
|  | ||||
|     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") | ||||
|     def test_device_context(self): | ||||
|         default = mx.default_device() | ||||
|         diff = mx.cpu if default == mx.gpu else mx.gpu | ||||
|         self.assertNotEqual(default, diff) | ||||
|         with mx.stream(diff): | ||||
|             a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2))) | ||||
|             mx.eval(a) | ||||
|             self.assertEqual(mx.default_device(), diff) | ||||
|         self.assertEqual(mx.default_device(), default) | ||||
|  | ||||
|     def test_op_on_device(self): | ||||
|         x = mx.array(1.0) | ||||
|         y = mx.array(1.0) | ||||
|   | ||||
| @@ -19,72 +19,73 @@ class TestFFT(mlx_tests.MLXTestCase): | ||||
|             self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) | ||||
|  | ||||
|     def test_fft(self): | ||||
|         default = mx.default_device() | ||||
|         mx.set_default_device(mx.cpu) | ||||
|  | ||||
|         def check_mx_np(op_mx, op_np, a_np, **kwargs): | ||||
|             out_np = op_np(a_np, **kwargs) | ||||
|             a_mx = mx.array(a_np) | ||||
|             out_mx = op_mx(a_mx, **kwargs) | ||||
|             self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) | ||||
|  | ||||
|         r = np.random.rand(100).astype(np.float32) | ||||
|         i = np.random.rand(100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np) | ||||
|         with mx.stream(mx.cpu): | ||||
|             r = np.random.rand(100).astype(np.float32) | ||||
|             i = np.random.rand(100).astype(np.float32) | ||||
|             a_np = r + 1j * i | ||||
|             check_mx_np(mx.fft.fft, np.fft.fft, a_np) | ||||
|  | ||||
|         # Check with slicing and padding | ||||
|         r = np.random.rand(100).astype(np.float32) | ||||
|         i = np.random.rand(100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) | ||||
|             # Check with slicing and padding | ||||
|             r = np.random.rand(100).astype(np.float32) | ||||
|             i = np.random.rand(100).astype(np.float32) | ||||
|             a_np = r + 1j * i | ||||
|             check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) | ||||
|             check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) | ||||
|  | ||||
|         # Check different axes | ||||
|         r = np.random.rand(100, 100).astype(np.float32) | ||||
|         i = np.random.rand(100, 100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) | ||||
|             # Check different axes | ||||
|             r = np.random.rand(100, 100).astype(np.float32) | ||||
|             i = np.random.rand(100, 100).astype(np.float32) | ||||
|             a_np = r + 1j * i | ||||
|             check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) | ||||
|             check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) | ||||
|  | ||||
|         # Check real fft | ||||
|         a_np = np.random.rand(100).astype(np.float32) | ||||
|         check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) | ||||
|         check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) | ||||
|             # Check real fft | ||||
|             a_np = np.random.rand(100).astype(np.float32) | ||||
|             check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) | ||||
|             check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) | ||||
|             check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) | ||||
|  | ||||
|         # Check real inverse | ||||
|         r = np.random.rand(100, 100).astype(np.float32) | ||||
|         i = np.random.rand(100, 100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) | ||||
|         check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) | ||||
|         check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) | ||||
|         check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) | ||||
|  | ||||
|         mx.set_default_device(default) | ||||
|             # Check real inverse | ||||
|             r = np.random.rand(100, 100).astype(np.float32) | ||||
|             i = np.random.rand(100, 100).astype(np.float32) | ||||
|             a_np = r + 1j * i | ||||
|             check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) | ||||
|             check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) | ||||
|             check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) | ||||
|             check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) | ||||
|             check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) | ||||
|             check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) | ||||
|  | ||||
|     def test_fftn(self): | ||||
|         default = mx.default_device() | ||||
|         mx.set_default_device(mx.cpu) | ||||
|         with mx.stream(mx.cpu): | ||||
|             r = np.random.randn(8, 8, 8).astype(np.float32) | ||||
|             i = np.random.randn(8, 8, 8).astype(np.float32) | ||||
|             a = r + 1j * i | ||||
|  | ||||
|         r = np.random.randn(8, 8, 8).astype(np.float32) | ||||
|         i = np.random.randn(8, 8, 8).astype(np.float32) | ||||
|         a = r + 1j * i | ||||
|             axes = [None, (1, 2), (2, 1), (0, 2)] | ||||
|             shapes = [None, (10, 5), (5, 10)] | ||||
|             ops = [ | ||||
|                 "fft2", | ||||
|                 "ifft2", | ||||
|                 "rfft2", | ||||
|                 "irfft2", | ||||
|                 "fftn", | ||||
|                 "ifftn", | ||||
|                 "rfftn", | ||||
|                 "irfftn", | ||||
|             ] | ||||
|  | ||||
|         axes = [None, (1, 2), (2, 1), (0, 2)] | ||||
|         shapes = [None, (10, 5), (5, 10)] | ||||
|         ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"] | ||||
|  | ||||
|         for op, ax, s in itertools.product(ops, axes, shapes): | ||||
|             x = a | ||||
|             if op in ["rfft2", "rfftn"]: | ||||
|                 x = r | ||||
|             self.check_mx_np(op, x, axes=ax, s=s) | ||||
|  | ||||
|         mx.set_default_device(default) | ||||
|             for op, ax, s in itertools.product(ops, axes, shapes): | ||||
|                 x = a | ||||
|                 if op in ["rfft2", "rfftn"]: | ||||
|                     x = r | ||||
|                 self.check_mx_np(op, x, axes=ax, s=s) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Diogo
					Diogo