From 3d405fb3b1c073fb7f12abf3d2433536991572e1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 22 Apr 2024 08:25:46 -0700 Subject: [PATCH] Add synchronize function (#1006) * add synchronize function * fix linux * fix linux * fix and fix docs * fix test * try synchronize in stream destroy * synchronize works for both cpu and gpu --- docs/src/dev/metal_debugger.rst | 7 +++-- docs/src/python/devices_and_streams.rst | 1 + examples/cpp/metal_capture.cpp | 2 +- mlx/backend/metal/metal.cpp | 36 +++++++++++++++++++++---- mlx/backend/metal/metal.h | 2 +- mlx/backend/metal/metal_impl.h | 5 ++++ mlx/backend/no_metal/metal.cpp | 12 ++++++--- mlx/scheduler.cpp | 15 +++++++++++ mlx/scheduler.h | 1 + mlx/stream.h | 6 +++++ python/src/metal.cpp | 4 +-- python/src/stream.cpp | 13 +++++++++ python/tests/test_metal.py | 6 +++-- tests/metal_tests.cpp | 8 +++--- 14 files changed, 95 insertions(+), 23 deletions(-) diff --git a/docs/src/dev/metal_debugger.rst b/docs/src/dev/metal_debugger.rst index 94d25258c..df4b98822 100644 --- a/docs/src/dev/metal_debugger.rst +++ b/docs/src/dev/metal_debugger.rst @@ -32,10 +32,9 @@ work. trace_file = "mlx_trace.gputrace" - if not mx.metal.start_capture(trace_file): - print("Make sure to run with MTL_CAPTURE_ENABLED=1 and " - f"that the path {trace_file} does not already exist.") - exit(1) + # Make sure to run with MTL_CAPTURE_ENABLED=1 and + # that the path trace_file does not already exist. + mx.metal.start_capture(trace_file) for _ in range(10): mx.eval(mx.add(a, b)) diff --git a/docs/src/python/devices_and_streams.rst b/docs/src/python/devices_and_streams.rst index e16ab9875..2a5adc05a 100644 --- a/docs/src/python/devices_and_streams.rst +++ b/docs/src/python/devices_and_streams.rst @@ -16,3 +16,4 @@ Devices and Streams new_stream set_default_stream stream + synchronize diff --git a/examples/cpp/metal_capture.cpp b/examples/cpp/metal_capture.cpp index 1033b614b..d31c49f96 100644 --- a/examples/cpp/metal_capture.cpp +++ b/examples/cpp/metal_capture.cpp @@ -11,7 +11,7 @@ int main() { // To use Metal debugging and profiling: // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON). // 2. Run with MTL_CAPTURE_ENABLED=1. - assert(metal::start_capture("mlx_trace.gputrace")); + metal::start_capture("mlx_trace.gputrace"); // Start at index two because the default GPU and CPU streams have indices // zero and one, respectively. This naming matches the label assigned to each diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 1de8ceec5..2cdbc49a5 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -1,6 +1,5 @@ // Copyright © 2023-2024 Apple Inc. #include -#include #include #include "mlx/backend/metal/device.h" @@ -115,13 +114,31 @@ std::function make_task(array arr, bool signal) { return task; } -bool start_capture(std::string path, id object) { +std::function make_synchronize_task( + Stream s, + std::shared_ptr> p) { + return [s, p = std::move(p)]() { + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + if (cb == nullptr) { + cb = d.new_command_buffer(s.index); + } else { + d.end_encoding(s.index); + } + d.commit_command_buffer(s.index); + cb->waitUntilCompleted(); + check_error(cb); + p->set_value(); + }; +} + +void start_capture(std::string path, id object) { auto pool = new_scoped_memory_pool(); auto descriptor = MTL::CaptureDescriptor::alloc()->init(); descriptor->setCaptureObject(object); - if (path.length() > 0) { + if (!path.empty()) { auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding); auto url = NS::URL::fileURLWithPath(string); descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument); @@ -129,15 +146,24 @@ bool start_capture(std::string path, id object) { } auto manager = MTL::CaptureManager::sharedCaptureManager(); - return manager->startCapture(descriptor, nullptr); + NS::Error* error; + bool started = manager->startCapture(descriptor, &error); + descriptor->release(); + if (!started) { + std::ostringstream msg; + msg << "[metal::start_capture] Failed to start: " + << error->localizedDescription()->utf8String(); + throw std::runtime_error(msg.str()); + } } -bool start_capture(std::string path) { +void start_capture(std::string path) { auto& device = metal::device(mlx::core::Device::gpu); return start_capture(path, device.mtl_device()); } void stop_capture() { + auto pool = new_scoped_memory_pool(); auto manager = MTL::CaptureManager::sharedCaptureManager(); manager->stopCapture(); } diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index fd417b3d7..86a47f37d 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -55,7 +55,7 @@ size_t set_memory_limit(size_t limit, bool relaxed = true); size_t set_cache_limit(size_t limit); /** Capture a GPU trace, saving it to an absolute file `path` */ -bool start_capture(std::string path = ""); +void start_capture(std::string path = ""); void stop_capture(); } // namespace mlx::core::metal diff --git a/mlx/backend/metal/metal_impl.h b/mlx/backend/metal/metal_impl.h index 885dd33e3..5c0a14357 100644 --- a/mlx/backend/metal/metal_impl.h +++ b/mlx/backend/metal/metal_impl.h @@ -2,6 +2,7 @@ #pragma once +#include #include #include "mlx/array.h" @@ -15,4 +16,8 @@ std::unique_ptr> new_scoped_memory_pool(); std::function make_task(array arr, bool signal); +std::function make_synchronize_task( + Stream s, + std::shared_ptr> p); + } // namespace mlx::core::metal diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index 4993a7313..0a0b635ae 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -22,6 +22,14 @@ std::function make_task(array arr, bool signal) { "[metal::make_task] Cannot make GPU task without metal backend"); } +std::function make_synchronize_task( + Stream s, + std::shared_ptr> p) { + throw std::runtime_error( + "[metal::make_synchronize_task] Cannot synchronize GPU" + " without metal backend"); +} + // No-ops when Metal is not available. size_t get_active_memory() { return 0; @@ -38,9 +46,7 @@ size_t set_memory_limit(size_t, bool) { size_t set_cache_limit(size_t) { return 0; } -bool start_capture(std::string path) { - return false; -} +void start_capture(std::string path) {} void stop_capture() {} } // namespace mlx::core::metal diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp index af66af06e..9e4342583 100644 --- a/mlx/scheduler.cpp +++ b/mlx/scheduler.cpp @@ -33,6 +33,21 @@ Stream new_stream() { return scheduler::scheduler().new_stream(default_device()); } +void synchronize(Stream s) { + auto p = std::make_shared>(); + std::future f = p->get_future(); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); }); + } else { + scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p))); + } + f.wait(); +} + +void synchronize() { + synchronize(default_stream(default_device())); +} + namespace scheduler { /** A singleton scheduler to manage devices, streams, and task execution. */ diff --git a/mlx/scheduler.h b/mlx/scheduler.h index f50a8c310..40e33f2be 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -27,6 +27,7 @@ struct StreamThread { : stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {} ~StreamThread() { + synchronize(stream); { std::unique_lock lk(mtx); stop = true; diff --git a/mlx/stream.h b/mlx/stream.h index d7b4268fd..2aec2c96f 100644 --- a/mlx/stream.h +++ b/mlx/stream.h @@ -29,4 +29,10 @@ inline bool operator!=(const Stream& lhs, const Stream& rhs) { return !(lhs == rhs); } +/* Synchronize with the default stream. */ +void synchronize(); + +/* Synchronize with the provided stream. */ +void synchronize(Stream); + } // namespace mlx::core diff --git a/python/src/metal.cpp b/python/src/metal.cpp index 53e14a228..6c7a27655 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -2,6 +2,7 @@ #include "mlx/backend/metal/metal.h" #include +#include #include namespace nb = nanobind; @@ -99,9 +100,6 @@ void init_metal(nb::module_& m) { Args: path (str): The path to save the capture which should have the extension ``.gputrace``. - - Returns: - bool: Whether the capture was successfully started. )pbdoc"); metal.def( "stop_capture", diff --git a/python/src/stream.cpp b/python/src/stream.cpp index c83d9b447..7eb8c6bf3 100644 --- a/python/src/stream.cpp +++ b/python/src/stream.cpp @@ -129,4 +129,17 @@ void init_stream(nb::module_& m) { # Operations here will use mx.cpu by default. pass )pbdoc"); + m.def( + "synchronize", + [](const std::optional& s) { + s ? synchronize(s.value()) : synchronize(); + }, + "stream"_a = nb::none(), + R"pbdoc( + Synchronize with the given stream. + + Args: + (Stream, optional): The stream to synchronize with. If ``None`` then + the default stream of the default device is used. Default: ``None``. + )pbdoc"); } diff --git a/python/tests/test_metal.py b/python/tests/test_metal.py index 51bceb38f..2b3b107b1 100644 --- a/python/tests/test_metal.py +++ b/python/tests/test_metal.py @@ -24,14 +24,16 @@ class TestMetal(mlx_tests.MLXTestCase): self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit) # Query active and peak memory - a = mx.zeros((4096,), stream=mx.cpu) + a = mx.zeros((4096,)) mx.eval(a) + mx.synchronize() active_mem = mx.metal.get_active_memory() self.assertTrue(active_mem >= 4096 * 4) - b = mx.zeros((4096,), stream=mx.cpu) + b = mx.zeros((4096,)) mx.eval(b) del b + mx.synchronize() new_active_mem = mx.metal.get_active_memory() self.assertEqual(new_active_mem, active_mem) diff --git a/tests/metal_tests.cpp b/tests/metal_tests.cpp index 976317f2f..1ce50dcd2 100644 --- a/tests/metal_tests.cpp +++ b/tests/metal_tests.cpp @@ -495,16 +495,16 @@ TEST_CASE("test metal memory info") { // Query active and peak memory { - // Do these tests on the CPU since deallocation is synchronized - // with the main thread. - auto a = zeros({4096}, Device::cpu); + auto a = zeros({4096}); eval(a); + synchronize(); auto active_mem = metal::get_active_memory(); CHECK(active_mem >= 4096 * 4); { - auto b = zeros({4096}, Device::cpu); + auto b = zeros({4096}); eval(b); } + synchronize(); auto new_active_mem = metal::get_active_memory(); CHECK_EQ(new_active_mem, active_mem); auto peak_mem = metal::get_peak_memory();