mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
updated, simplified mutex for thread safety
This commit is contained in:
parent
e496c5a4b4
commit
28902ece4e
@ -4,12 +4,15 @@
|
|||||||
#include "mlx/backend/gpu/available.h"
|
#include "mlx/backend/gpu/available.h"
|
||||||
#include "mlx/backend/gpu/eval.h"
|
#include "mlx/backend/gpu/eval.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/thread_safey.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
namespace mlx::core::gpu {
|
namespace mlx::core::gpu {
|
||||||
|
|
||||||
|
std::mutex metal_operation_mutex;
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -30,6 +33,7 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void eval(array& arr) {
|
void eval(array& arr) {
|
||||||
|
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto s = arr.primitive().stream();
|
auto s = arr.primitive().stream();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
@ -78,6 +82,7 @@ void eval(array& arr) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void finalize(Stream s) {
|
void finalize(Stream s) {
|
||||||
|
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
auto cb = d.get_command_buffer(s.index);
|
auto cb = d.get_command_buffer(s.index);
|
||||||
@ -88,6 +93,7 @@ void finalize(Stream s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void synchronize(Stream s) {
|
void synchronize(Stream s) {
|
||||||
|
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
auto cb = d.get_command_buffer(s.index);
|
auto cb = d.get_command_buffer(s.index);
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/event.h"
|
#include "mlx/event.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/thread_safey.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -27,6 +28,7 @@ void Event::wait(Stream stream) {
|
|||||||
if (stream.device == Device::cpu) {
|
if (stream.device == Device::cpu) {
|
||||||
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
||||||
} else {
|
} else {
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
d.end_encoding(stream.index);
|
d.end_encoding(stream.index);
|
||||||
auto command_buffer = d.get_command_buffer(stream.index);
|
auto command_buffer = d.get_command_buffer(stream.index);
|
||||||
@ -41,6 +43,7 @@ void Event::signal(Stream stream) {
|
|||||||
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
d.end_encoding(stream.index);
|
d.end_encoding(stream.index);
|
||||||
auto command_buffer = d.get_command_buffer(stream.index);
|
auto command_buffer = d.get_command_buffer(stream.index);
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include "mlx/fence.h"
|
#include "mlx/fence.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/thread_safey.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
@ -68,6 +69,7 @@ void Fence::wait(Stream stream, const array& x) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
auto idx = stream.index;
|
auto idx = stream.index;
|
||||||
|
|
||||||
@ -116,6 +118,7 @@ void Fence::update(Stream stream, const array& x) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
auto idx = stream.index;
|
auto idx = stream.index;
|
||||||
if (!f.use_fast) {
|
if (!f.use_fast) {
|
||||||
|
7
mlx/backend/metal/thread_safey.h
Normal file
7
mlx/backend/metal/thread_safey.h
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
namespace mlx::core::gpu {
|
||||||
|
extern std::mutex metal_operation_mutex;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user