diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index 49783200a..a21853bf5 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -4,12 +4,15 @@ #include "mlx/backend/gpu/available.h" #include "mlx/backend/gpu/eval.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/thread_safey.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" namespace mlx::core::gpu { +std::mutex metal_operation_mutex; + bool is_available() { return true; } @@ -30,6 +33,7 @@ inline void check_error(MTL::CommandBuffer* cbuf) { } void eval(array& arr) { + std::lock_guard lock(metal_operation_mutex); auto pool = metal::new_scoped_memory_pool(); auto s = arr.primitive().stream(); auto& d = metal::device(s.device); @@ -78,6 +82,7 @@ void eval(array& arr) { } void finalize(Stream s) { + std::lock_guard lock(metal_operation_mutex); auto pool = metal::new_scoped_memory_pool(); auto& d = metal::device(s.device); auto cb = d.get_command_buffer(s.index); @@ -88,6 +93,7 @@ void finalize(Stream s) { } void synchronize(Stream s) { + std::lock_guard lock(metal_operation_mutex); auto pool = metal::new_scoped_memory_pool(); auto& d = metal::device(s.device); auto cb = d.get_command_buffer(s.index); diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index eb7f1b58a..e7905105a 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -2,6 +2,7 @@ #include "mlx/event.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/thread_safey.h" #include "mlx/scheduler.h" namespace mlx::core { @@ -27,6 +28,7 @@ void Event::wait(Stream stream) { if (stream.device == Device::cpu) { scheduler::enqueue(stream, [*this]() mutable { wait(); }); } else { + std::lock_guard lock(gpu::metal_operation_mutex); auto& d = metal::device(stream.device); d.end_encoding(stream.index); auto command_buffer = d.get_command_buffer(stream.index); @@ -41,6 +43,7 @@ void Event::signal(Stream stream) { static_cast(event_.get())->setSignaledValue(value()); }); } else { + std::lock_guard lock(gpu::metal_operation_mutex); auto& d = metal::device(stream.device); d.end_encoding(stream.index); auto command_buffer = d.get_command_buffer(stream.index); diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index d4a88d983..4b9b8f27f 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. #include "mlx/fence.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/thread_safey.h" #include "mlx/scheduler.h" #include "mlx/utils.h" @@ -68,6 +69,7 @@ void Fence::wait(Stream stream, const array& x) { return; } + std::lock_guard lock(gpu::metal_operation_mutex); auto& d = metal::device(stream.device); auto idx = stream.index; @@ -116,6 +118,7 @@ void Fence::update(Stream stream, const array& x) { return; } + std::lock_guard lock(gpu::metal_operation_mutex); auto& d = metal::device(stream.device); auto idx = stream.index; if (!f.use_fast) { diff --git a/mlx/backend/metal/thread_safey.h b/mlx/backend/metal/thread_safey.h new file mode 100644 index 000000000..0666a64d4 --- /dev/null +++ b/mlx/backend/metal/thread_safey.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace mlx::core::gpu { + extern std::mutex metal_operation_mutex; +} \ No newline at end of file