From 28902ece4e195cdf59b39c95586349ed8b8082aa Mon Sep 17 00:00:00 2001 From: Andrew Sweet Date: Wed, 30 Apr 2025 16:17:12 -0700 Subject: [PATCH 1/2] updated, simplified mutex for thread safety --- mlx/backend/metal/eval.cpp | 6 ++++++ mlx/backend/metal/event.cpp | 3 +++ mlx/backend/metal/fence.cpp | 3 +++ mlx/backend/metal/thread_safey.h | 7 +++++++ 4 files changed, 19 insertions(+) create mode 100644 mlx/backend/metal/thread_safey.h diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index 49783200a..a21853bf5 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -4,12 +4,15 @@ #include "mlx/backend/gpu/available.h" #include "mlx/backend/gpu/eval.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/thread_safey.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" namespace mlx::core::gpu { +std::mutex metal_operation_mutex; + bool is_available() { return true; } @@ -30,6 +33,7 @@ inline void check_error(MTL::CommandBuffer* cbuf) { } void eval(array& arr) { + std::lock_guard lock(metal_operation_mutex); auto pool = metal::new_scoped_memory_pool(); auto s = arr.primitive().stream(); auto& d = metal::device(s.device); @@ -78,6 +82,7 @@ void eval(array& arr) { } void finalize(Stream s) { + std::lock_guard lock(metal_operation_mutex); auto pool = metal::new_scoped_memory_pool(); auto& d = metal::device(s.device); auto cb = d.get_command_buffer(s.index); @@ -88,6 +93,7 @@ void finalize(Stream s) { } void synchronize(Stream s) { + std::lock_guard lock(metal_operation_mutex); auto pool = metal::new_scoped_memory_pool(); auto& d = metal::device(s.device); auto cb = d.get_command_buffer(s.index); diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index eb7f1b58a..e7905105a 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -2,6 +2,7 @@ #include "mlx/event.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/thread_safey.h" #include "mlx/scheduler.h" namespace mlx::core { @@ -27,6 +28,7 @@ void Event::wait(Stream stream) { if (stream.device == Device::cpu) { scheduler::enqueue(stream, [*this]() mutable { wait(); }); } else { + std::lock_guard lock(gpu::metal_operation_mutex); auto& d = metal::device(stream.device); d.end_encoding(stream.index); auto command_buffer = d.get_command_buffer(stream.index); @@ -41,6 +43,7 @@ void Event::signal(Stream stream) { static_cast(event_.get())->setSignaledValue(value()); }); } else { + std::lock_guard lock(gpu::metal_operation_mutex); auto& d = metal::device(stream.device); d.end_encoding(stream.index); auto command_buffer = d.get_command_buffer(stream.index); diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index d4a88d983..4b9b8f27f 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. #include "mlx/fence.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/thread_safey.h" #include "mlx/scheduler.h" #include "mlx/utils.h" @@ -68,6 +69,7 @@ void Fence::wait(Stream stream, const array& x) { return; } + std::lock_guard lock(gpu::metal_operation_mutex); auto& d = metal::device(stream.device); auto idx = stream.index; @@ -116,6 +118,7 @@ void Fence::update(Stream stream, const array& x) { return; } + std::lock_guard lock(gpu::metal_operation_mutex); auto& d = metal::device(stream.device); auto idx = stream.index; if (!f.use_fast) { diff --git a/mlx/backend/metal/thread_safey.h b/mlx/backend/metal/thread_safey.h new file mode 100644 index 000000000..0666a64d4 --- /dev/null +++ b/mlx/backend/metal/thread_safey.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace mlx::core::gpu { + extern std::mutex metal_operation_mutex; +} \ No newline at end of file From c8d4d974472315ef6316dbf4d63cd71883f076b9 Mon Sep 17 00:00:00 2001 From: Andrew Sweet Date: Wed, 7 May 2025 11:59:54 -0700 Subject: [PATCH 2/2] tests added --- mlx/backend/metal/thread_safey.h | 2 +- tests/CMakeLists.txt | 2 +- tests/array_tests.cpp | 1 + tests/metal_thread_safety_tests.cpp | 250 ++++++++++++++++++++++++++++ 4 files changed, 253 insertions(+), 2 deletions(-) create mode 100644 tests/metal_thread_safety_tests.cpp diff --git a/mlx/backend/metal/thread_safey.h b/mlx/backend/metal/thread_safey.h index 0666a64d4..d4724a654 100644 --- a/mlx/backend/metal/thread_safey.h +++ b/mlx/backend/metal/thread_safey.h @@ -3,5 +3,5 @@ #include namespace mlx::core::gpu { - extern std::mutex metal_operation_mutex; +extern std::mutex metal_operation_mutex; } \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cf0ba3d5d..aa06ed248 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,7 +10,7 @@ FetchContent_MakeAvailable(doctest) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) if(MLX_BUILD_METAL) - set(METAL_TEST_SOURCES gpu_tests.cpp) + set(METAL_TEST_SOURCES gpu_tests.cpp metal_thread_safety_tests.cpp) endif() include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake) diff --git a/tests/array_tests.cpp b/tests/array_tests.cpp index b31da9899..2e4f44a3d 100644 --- a/tests/array_tests.cpp +++ b/tests/array_tests.cpp @@ -589,6 +589,7 @@ TEST_CASE("test array shared buffer") { array b = array(buf_b, shape, float32, deleter); eval(a + b); + synchronize(); // ensure all operations complete before test ends } TEST_CASE("test make empty array") { diff --git a/tests/metal_thread_safety_tests.cpp b/tests/metal_thread_safety_tests.cpp new file mode 100644 index 000000000..44feb8f2b --- /dev/null +++ b/tests/metal_thread_safety_tests.cpp @@ -0,0 +1,250 @@ +#include "doctest/doctest.h" +#include "mlx/mlx.h" +#include "mlx/backend/metal/device.h" + +#include +#include +#include +#include +#include +#include + +using namespace mlx::core; + +// Helper function to run operations across multiple threads with pre-created streams +void run_in_threads(int num_threads, const std::function& func, + const std::vector& streams) { + std::vector threads; + threads.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back(func, i, streams[i % streams.size()]); + } + for (auto& t : threads) { + if (t.joinable()) { + t.join(); + } + } +} + +// Helper function for tasks not requiring streams (e.g., using default stream) +void run_in_threads_default(int num_threads, const std::function& func) { + std::vector threads; + threads.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back(func, i); + } + for (auto& t : threads) { + if (t.joinable()) { + t.join(); + } + } +} + +// Thread-safe result collection +struct TestResults { + std::mutex mutex; + std::vector shape_checks; + std::vector availability_checks; + std::vector value_checks; + std::vector expected_values; + std::vector actual_values; + + void record_result(bool shape_ok, bool available_ok, bool value_ok, + float expected, float actual) { + std::lock_guard lock(mutex); + shape_checks.push_back(shape_ok); + availability_checks.push_back(available_ok); + value_checks.push_back(value_ok); + expected_values.push_back(expected); + actual_values.push_back(actual); + } +}; + +TEST_CASE("test metal concurrent eval operations") { + Device D_GPU = Device::gpu; + const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 8; + const int ops_per_thread = 10; + const int array_size = 32; + std::atomic completed_ops{0}; + TestResults results; + + // Pre-create streams to avoid concurrent stream creation + std::vector streams; + for (int i = 0; i < num_threads; ++i) { + streams.push_back(new_stream(D_GPU)); + } + synchronize(); // Ensure stream creation is complete + + auto task = [&](int thread_id, Stream s) { + try { + for (int i = 0; i < ops_per_thread; ++i) { + float val1 = static_cast(thread_id * ops_per_thread + i + 1); + float val2 = val1 * 2.0f; + + auto x = full({array_size, array_size}, val1, s); + auto y = full({array_size, array_size}, val2, s); + auto z = add(x, y); + eval(z); + + bool shape_ok = (z.shape() == Shape{array_size, array_size}); + bool available_ok = z.is_available(); + + // Get a value from the array + int mid = array_size/2; + auto sample = slice(z, {mid, mid}, {mid+1, mid+1}); + float actual = sample.item(); + float expected = val1 + val2; + + bool values_match = (std::abs(actual - expected) < 1e-5); + + results.record_result(shape_ok, available_ok, values_match, expected, actual); + + if (shape_ok && available_ok && values_match) { + completed_ops++; + } + } + } catch (const std::exception& e) { + std::cerr << "Thread " << thread_id << " exception: " << e.what() << std::endl; + } + }; + + // Run the threads with pre-created streams + CHECK_NOTHROW(run_in_threads(num_threads, task, streams)); + + // Check all results outside of threads + for (size_t i = 0; i < results.shape_checks.size(); ++i) { + CAPTURE(i); // Help identify which operation failed + CHECK(results.shape_checks[i]); + CHECK(results.availability_checks[i]); + CHECK(results.value_checks[i]); + if (!results.value_checks[i]) { + CAPTURE(results.expected_values[i]); + CAPTURE(results.actual_values[i]); + } + } + + // Verify all operations completed successfully + CHECK_EQ(completed_ops.load(), num_threads * ops_per_thread); +} + +TEST_CASE("test metal high contention on default stream eval") { + Device D_GPU = Device::gpu; + const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 8; + const int ops_per_thread = 5; + const int array_size = 16; + Stream default_gpu_stream = default_stream(D_GPU); + std::atomic successful_ops{0}; + std::vector thread_errors; + std::mutex errors_mutex; + TestResults results; + + auto task = [&](int thread_id) { + try { + for (int i = 0; i < ops_per_thread; ++i) { + float val = static_cast(thread_id * 100 + i + 1); + auto x = full({array_size, array_size}, val, default_gpu_stream); + auto y = full({array_size, array_size}, val * 0.5f, default_gpu_stream); + auto z = multiply(x, y); + eval(z); + + // Sample a value + auto sample = slice(z, {0, 0}, {1, 1}); + float actual = sample.item(); + float expected = val * val * 0.5f; + + bool shape_ok = (z.shape() == Shape{array_size, array_size}); + bool available_ok = z.is_available(); + bool values_match = (std::abs(actual - expected) < 1e-5); + + results.record_result(shape_ok, available_ok, values_match, expected, actual); + + if (shape_ok && available_ok && values_match) { + successful_ops++; + } + } + } catch (const std::exception& e) { + std::lock_guard lock(errors_mutex); + thread_errors.push_back(std::string("Thread ") + + std::to_string(thread_id) + + " exception: " + e.what()); + } + }; + + // Use the default helper for this test since it uses the default stream + CHECK_NOTHROW(run_in_threads_default(num_threads, task)); + + // Check for thread errors + CHECK(thread_errors.empty()); + if (!thread_errors.empty()) { + for (const auto& err : thread_errors) { + CAPTURE(err); + } + } + + // Check all results + for (size_t i = 0; i < results.shape_checks.size(); ++i) { + CAPTURE(i); + CHECK(results.shape_checks[i]); + CHECK(results.availability_checks[i]); + CHECK(results.value_checks[i]); + if (!results.value_checks[i]) { + CAPTURE(results.expected_values[i]); + CAPTURE(results.actual_values[i]); + } + } + + // Verify operation count + CHECK_EQ(successful_ops.load(), num_threads * ops_per_thread); +} + +TEST_CASE("test metal concurrent graph eval from different threads") { + Device D_GPU = Device::gpu; + const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 4; // Keep modest for clarity + const int array_size = 64; + TestResults all_results; + + // Pre-create streams + std::vector streams; + for (int i = 0; i < num_threads; ++i) { + streams.push_back(new_stream(D_GPU)); + } + synchronize(); + + auto task = [&](int thread_id, Stream s) { + try { + float val1_base = static_cast(thread_id + 1) * 10.0f; + auto x = full({array_size, array_size}, val1_base, s); + auto y = full({array_size, array_size}, val1_base + 1.0f, s); + auto z = add(x, y); + auto w = multiply(z, x); + eval(w); + + float expected_val = (val1_base + (val1_base + 1.0f)) * val1_base; + auto sample = slice(w, {0,0}, {1,1}); + float actual_val = sample.item(); + + bool shape_ok = (w.shape() == Shape{array_size, array_size}); + bool available_ok = w.is_available(); + bool value_ok = (std::abs(actual_val - expected_val) < 1e-4); + + all_results.record_result(shape_ok, available_ok, value_ok, expected_val, actual_val); + + } catch (const std::exception& e) { + std::cerr << "Thread " << thread_id << " exception in concurrent graph eval: " << e.what() << std::endl; + } + }; + + CHECK_NOTHROW(run_in_threads(num_threads, task, streams)); + + CHECK_EQ(all_results.shape_checks.size(), num_threads); // One result per thread + for (size_t i = 0; i < num_threads; ++i) { + CAPTURE(i); + CHECK(all_results.shape_checks[i]); + CHECK(all_results.availability_checks[i]); + CHECK(all_results.value_checks[i]); + if (!all_results.value_checks[i]) { + CAPTURE(all_results.expected_values[i]); + CAPTURE(all_results.actual_values[i]); + } + } +} \ No newline at end of file