From 35431a4ac8fbe286e42225f6bb5d521dbb8d7334 Mon Sep 17 00:00:00 2001 From: Diogo Date: Wed, 14 Feb 2024 17:14:58 -0500 Subject: [PATCH] Adds device context manager (#679) --- ACKNOWLEDGMENTS.md | 2 +- docs/src/conf.py | 1 + docs/src/python/devices_and_streams.rst | 3 +- mlx/ops.cpp | 10 --- mlx/ops.h | 6 +- mlx/utils.cpp | 10 +++ mlx/utils.h | 26 ++++++ python/mlx/utils.py | 1 - python/src/CMakeLists.txt | 1 + python/src/device.cpp | 14 +++- python/src/mlx.cpp | 3 + python/src/stream.cpp | 33 +++++++- python/src/utils.cpp | 81 ++++++++++++++++++ python/tests/test_device.py | 11 +++ python/tests/test_fft.py | 105 ++++++++++++------------ 15 files changed, 230 insertions(+), 77 deletions(-) create mode 100644 python/src/utils.cpp diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 2a3c6c612..36aedc77a 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals: - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. -- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support +- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. diff --git a/docs/src/conf.py b/docs/src/conf.py index bec2c976c..0654cf53c 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -26,6 +26,7 @@ extensions = [ python_use_unqualified_type_names = True autosummary_generate = True +autosummary_filename_map = {"mlx.core.Stream": "stream_class"} intersphinx_mapping = { "https://docs.python.org/3": None, diff --git a/docs/src/python/devices_and_streams.rst b/docs/src/python/devices_and_streams.rst index bb9dfae2f..e16ab9875 100644 --- a/docs/src/python/devices_and_streams.rst +++ b/docs/src/python/devices_and_streams.rst @@ -9,9 +9,10 @@ Devices and Streams :toctree: _autosummary Device + Stream default_device set_default_device - Stream default_stream new_stream set_default_stream + stream diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 01ee6d388..549d26512 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -59,16 +59,6 @@ Dtype at_least_float(const Dtype& d) { } // namespace -Stream to_stream(StreamOrDevice s) { - if (std::holds_alternative(s)) { - return default_stream(default_device()); - } else if (std::holds_alternative(s)) { - return default_stream(std::get(s)); - } else { - return std::get(s); - } -} - array arange( double start, double stop, diff --git a/mlx/ops.h b/mlx/ops.h index a4b1dd1ef..f7036b8c6 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -3,18 +3,14 @@ #pragma once #include -#include #include "mlx/array.h" #include "mlx/device.h" #include "mlx/stream.h" +#include "mlx/utils.h" namespace mlx::core { -using StreamOrDevice = std::variant; - -Stream to_stream(StreamOrDevice s); - /** Creation operations */ /** diff --git a/mlx/utils.cpp b/mlx/utils.cpp index eece43717..c6365beb9 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -7,6 +7,16 @@ namespace mlx::core { +Stream to_stream(StreamOrDevice s) { + if (std::holds_alternative(s)) { + return default_stream(default_device()); + } else if (std::holds_alternative(s)) { + return default_stream(std::get(s)); + } else { + return std::get(s); + } +} + void PrintFormatter::print(std::ostream& os, bool val) { if (capitalize_bool) { os << (val ? "True" : "False"); diff --git a/mlx/utils.h b/mlx/utils.h index f28970369..88f47e3e1 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "array.h" #include "device.h" #include "dtype.h" @@ -9,6 +11,30 @@ namespace mlx::core { +using StreamOrDevice = std::variant; +Stream to_stream(StreamOrDevice s); + +struct StreamContext { + public: + StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) { + if (std::holds_alternative(s)) { + throw std::runtime_error( + "[StreamContext] Invalid argument, please specify a stream or device."); + } + auto _s = to_stream(s); + set_default_device(_s.device); + set_default_stream(_s); + } + + ~StreamContext() { + set_default_device(_stream.device); + set_default_stream(_stream); + } + + private: + Stream _stream; +}; + struct PrintFormatter { inline void print(std::ostream& os, bool val); inline void print(std::ostream& os, int16_t val); diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 137a8aae4..802b03831 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,5 +1,4 @@ # Copyright © 2023 Apple Inc. - from collections import defaultdict diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 7dd862033..4df503a4a 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -14,6 +14,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) diff --git a/python/src/device.cpp b/python/src/device.cpp index 8c36f0f85..c88144520 100644 --- a/python/src/device.cpp +++ b/python/src/device.cpp @@ -12,7 +12,8 @@ using namespace py::literals; using namespace mlx::core; void init_device(py::module_& m) { - auto device_class = py::class_(m, "Device"); + auto device_class = py::class_( + m, "Device", R"pbdoc(A device to run operations on.)pbdoc"); py::enum_(m, "DeviceType") .value("cpu", Device::DeviceType::cpu) .value("gpu", Device::DeviceType::gpu) @@ -39,6 +40,13 @@ void init_device(py::module_& m) { py::implicitly_convertible(); - m.def("default_device", &default_device); - m.def("set_default_device", &set_default_device, "device"_a); + m.def( + "default_device", + &default_device, + R"pbdoc(Get the default device.)pbdoc"); + m.def( + "set_default_device", + &set_default_device, + "device"_a, + R"pbdoc(Set the default device.)pbdoc"); } diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index ee0f469f9..5fb9e74e2 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -18,6 +18,7 @@ void init_fft(py::module_&); void init_linalg(py::module_&); void init_constants(py::module_&); void init_extensions(py::module_&); +void init_utils(py::module_&); PYBIND11_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; @@ -35,5 +36,7 @@ PYBIND11_MODULE(core, m) { init_linalg(m); init_constants(m); init_extensions(m); + init_utils(m); + m.attr("__version__") = TOSTRING(_VERSION_); } diff --git a/python/src/stream.cpp b/python/src/stream.cpp index 7b1b2f55d..768795fc1 100644 --- a/python/src/stream.cpp +++ b/python/src/stream.cpp @@ -12,7 +12,12 @@ using namespace py::literals; using namespace mlx::core; void init_stream(py::module_& m) { - py::class_(m, "Stream") + py::class_( + m, + "Stream", + R"pbdoc( + A stream for running operations on a given device. + )pbdoc") .def(py::init(), "index"_a, "device"_a) .def_readonly("device", &Stream::device) .def( @@ -28,7 +33,27 @@ void init_stream(py::module_& m) { py::implicitly_convertible(); - m.def("default_stream", &default_stream, "device"_a); - m.def("set_default_stream", &set_default_stream, "stream"_a); - m.def("new_stream", &new_stream, "device"_a); + m.def( + "default_stream", + &default_stream, + "device"_a, + R"pbdoc(Get the device's default stream.)pbdoc"); + m.def( + "set_default_stream", + &set_default_stream, + "stream"_a, + R"pbdoc( + Set the default stream. + + This will make the given stream the default for the + streams device. It will not change the default device. + + Args: + stream (stream): Stream to make the default. + )pbdoc"); + m.def( + "new_stream", + &new_stream, + "device"_a, + R"pbdoc(Make a new stream on the given device.)pbdoc"); } diff --git a/python/src/utils.cpp b/python/src/utils.cpp new file mode 100644 index 000000000..c07016709 --- /dev/null +++ b/python/src/utils.cpp @@ -0,0 +1,81 @@ + +#include "mlx/utils.h" +#include +#include +#include + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +// Slightly different from the original, with python context on init we are not +// in the context yet. Only create the inner context on enter then delete on +// exit. +class PyStreamContext { + public: + PyStreamContext(StreamOrDevice s) : _inner(nullptr) { + if (std::holds_alternative(s)) { + throw std::runtime_error( + "[StreamContext] Invalid argument, please specify a stream or device."); + } + _s = s; + } + + void enter() { + _inner = new StreamContext(_s); + } + + void exit() { + if (_inner != nullptr) { + delete _inner; + _inner = nullptr; + } + } + + private: + StreamOrDevice _s; + StreamContext* _inner; +}; + +void init_utils(py::module_& m) { + py::class_(m, "StreamContext", R"pbdoc( + A context manager for setting the current device and stream. + + See :func:`stream` for usage. + + Args: + s: The stream or device to set as the default. + )pbdoc") + .def(py::init(), "s"_a) + .def("__enter__", [](PyStreamContext& scm) { scm.enter(); }) + .def( + "__exit__", + [](PyStreamContext& scm, + const std::optional& exc_type, + const std::optional& exc_value, + const std::optional& traceback) { scm.exit(); }); + m.def( + "stream", + [](StreamOrDevice s) { return PyStreamContext(s); }, + "s"_a, + R"pbdoc( + Create a context manager to set the default device and stream. + + Args: + s: The :obj:`Stream` or :obj:`Device` to set as the default. + + Returns: + A context manager that sets the default device and stream. + + Example: + + .. code-block::python + + import mlx.core as mx + + # Create a context manager for the default device and stream. + with mx.stream(mx.cpu): + # Operations here will use mx.cpu by default. + pass + )pbdoc"); +} diff --git a/python/tests/test_device.py b/python/tests/test_device.py index 8aac105bc..53826cad7 100644 --- a/python/tests/test_device.py +++ b/python/tests/test_device.py @@ -38,6 +38,17 @@ class TestDevice(mlx_tests.MLXTestCase): # Restore device mx.set_default_device(device) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_device_context(self): + default = mx.default_device() + diff = mx.cpu if default == mx.gpu else mx.gpu + self.assertNotEqual(default, diff) + with mx.stream(diff): + a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2))) + mx.eval(a) + self.assertEqual(mx.default_device(), diff) + self.assertEqual(mx.default_device(), default) + def test_op_on_device(self): x = mx.array(1.0) y = mx.array(1.0) diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 4be12e21f..14473afa1 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -19,72 +19,73 @@ class TestFFT(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) def test_fft(self): - default = mx.default_device() - mx.set_default_device(mx.cpu) - def check_mx_np(op_mx, op_np, a_np, **kwargs): out_np = op_np(a_np, **kwargs) a_mx = mx.array(a_np) out_mx = op_mx(a_mx, **kwargs) self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) - r = np.random.rand(100).astype(np.float32) - i = np.random.rand(100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.fft, np.fft.fft, a_np) + with mx.stream(mx.cpu): + r = np.random.rand(100).astype(np.float32) + i = np.random.rand(100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np) - # Check with slicing and padding - r = np.random.rand(100).astype(np.float32) - i = np.random.rand(100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) - check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) + # Check with slicing and padding + r = np.random.rand(100).astype(np.float32) + i = np.random.rand(100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) + check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) - # Check different axes - r = np.random.rand(100, 100).astype(np.float32) - i = np.random.rand(100, 100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) - check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) + # Check different axes + r = np.random.rand(100, 100).astype(np.float32) + i = np.random.rand(100, 100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) + check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) - # Check real fft - a_np = np.random.rand(100).astype(np.float32) - check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) - check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) - check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) + # Check real fft + a_np = np.random.rand(100).astype(np.float32) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) - # Check real inverse - r = np.random.rand(100, 100).astype(np.float32) - i = np.random.rand(100, 100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) - check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) - check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) - check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) - check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) - check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) - - mx.set_default_device(default) + # Check real inverse + r = np.random.rand(100, 100).astype(np.float32) + i = np.random.rand(100, 100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) def test_fftn(self): - default = mx.default_device() - mx.set_default_device(mx.cpu) + with mx.stream(mx.cpu): + r = np.random.randn(8, 8, 8).astype(np.float32) + i = np.random.randn(8, 8, 8).astype(np.float32) + a = r + 1j * i - r = np.random.randn(8, 8, 8).astype(np.float32) - i = np.random.randn(8, 8, 8).astype(np.float32) - a = r + 1j * i + axes = [None, (1, 2), (2, 1), (0, 2)] + shapes = [None, (10, 5), (5, 10)] + ops = [ + "fft2", + "ifft2", + "rfft2", + "irfft2", + "fftn", + "ifftn", + "rfftn", + "irfftn", + ] - axes = [None, (1, 2), (2, 1), (0, 2)] - shapes = [None, (10, 5), (5, 10)] - ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"] - - for op, ax, s in itertools.product(ops, axes, shapes): - x = a - if op in ["rfft2", "rfftn"]: - x = r - self.check_mx_np(op, x, axes=ax, s=s) - - mx.set_default_device(default) + for op, ax, s in itertools.product(ops, axes, shapes): + x = a + if op in ["rfft2", "rfftn"]: + x = r + self.check_mx_np(op, x, axes=ax, s=s) if __name__ == "__main__":