mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Make the GPU device more thread safe (#1478)
* gpu stream safety * comment * fix
This commit is contained in:
@@ -219,7 +219,6 @@ void Device::new_queue(int index) {
|
||||
|
||||
// Multiple threads can ask the device for queues
|
||||
// We lock this as a critical section for safety
|
||||
const std::lock_guard<std::mutex> lock(mtx_);
|
||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||
debug_set_stream_queue_label(q, index);
|
||||
if (!q) {
|
||||
@@ -227,21 +226,21 @@ void Device::new_queue(int index) {
|
||||
"[metal::Device] Failed to make new command queue.");
|
||||
}
|
||||
queue_map_.insert({index, q});
|
||||
buffer_map_.insert({index, {0, nullptr}});
|
||||
encoder_map_.insert({index, nullptr});
|
||||
}
|
||||
|
||||
int Device::get_command_buffer_ops(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
return bit->second.first;
|
||||
return buffer_map_[index].first;
|
||||
}
|
||||
|
||||
void Device::increment_command_buffer_ops(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
bit->second.first++;
|
||||
buffer_map_[index].first++;
|
||||
}
|
||||
|
||||
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
if (bit == buffer_map_.end()) {
|
||||
if (bit->second.second == nullptr) {
|
||||
auto qit = queue_map_.find(index);
|
||||
if (qit == queue_map_.end()) {
|
||||
throw std::runtime_error(
|
||||
@@ -258,7 +257,7 @@ MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
cb->retain();
|
||||
|
||||
bit = buffer_map_.insert({index, {0, cb}}).first;
|
||||
bit->second = {0, cb};
|
||||
}
|
||||
return bit->second.second;
|
||||
}
|
||||
@@ -267,19 +266,18 @@ void Device::commit_command_buffer(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
bit->second.second->commit();
|
||||
bit->second.second->release();
|
||||
buffer_map_.erase(bit);
|
||||
bit->second = {0, nullptr};
|
||||
}
|
||||
|
||||
void Device::end_encoding(int index) {
|
||||
encoder_map_.erase(index);
|
||||
encoder_map_[index] = nullptr;
|
||||
}
|
||||
|
||||
CommandEncoder& Device::get_command_encoder(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit == encoder_map_.end()) {
|
||||
if (eit->second == nullptr) {
|
||||
auto cb = get_command_buffer(index);
|
||||
eit =
|
||||
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
|
||||
eit->second = std::make_unique<CommandEncoder>(cb);
|
||||
}
|
||||
return *(eit->second);
|
||||
}
|
||||
@@ -293,20 +291,7 @@ void Device::register_library(
|
||||
}
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib;
|
||||
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
|
||||
mtl_lib = it->second;
|
||||
} else { // Look for metallib alongside library
|
||||
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
||||
mtl_lib = library_map_[lib_name];
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_(const std::string& source_string) {
|
||||
MTL::Library* Device::build_library_(const std::string& source_string) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
auto ns_code =
|
||||
@@ -332,25 +317,6 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_lib = device_->newLibrary(desc, &error);
|
||||
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function_(
|
||||
const std::string& name,
|
||||
MTL::Library* mtl_lib) {
|
||||
@@ -465,68 +431,32 @@ MTL::ComputePipelineState* Device::get_kernel_(
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(const std::string& name) {
|
||||
MTL::Library* Device::get_library_(const std::string& name) {
|
||||
std::shared_lock lock(library_mtx_);
|
||||
auto it = library_map_.find(name);
|
||||
return (it != library_map_.end()) ? it->second : nullptr;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const std::string& source,
|
||||
bool cache /* = true */) {
|
||||
if (cache) {
|
||||
const std::function<std::string(void)>& builder) {
|
||||
{
|
||||
std::shared_lock rlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
auto mtl_lib = get_library_(source);
|
||||
|
||||
if (cache) {
|
||||
library_map_.insert({name, mtl_lib});
|
||||
std::unique_lock wlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto mtl_lib = build_library_(builder());
|
||||
library_map_.insert({name, mtl_lib});
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const MTL::StitchedLibraryDescriptor* desc,
|
||||
bool cache /* = true */) {
|
||||
if (cache) {
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
auto mtl_lib = get_library_(desc);
|
||||
|
||||
if (cache) {
|
||||
library_map_.insert({name, mtl_lib});
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& specialized_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */) {
|
||||
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name /* = "mlx" */,
|
||||
const std::string& specialized_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */) {
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_cache_(lib_name);
|
||||
|
||||
return get_function(base_name, mtl_lib, specialized_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
const std::vector<MTL::Function*>& funcs) {
|
||||
if (funcs.empty()) {
|
||||
@@ -547,34 +477,55 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
return lfuncs;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel_(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& hash_name,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
// Single writer allowed
|
||||
std::unique_lock wlock(kernel_mtx_);
|
||||
|
||||
// Try loading again to avoid loading twice
|
||||
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
// Pull kernel from library
|
||||
auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib);
|
||||
|
||||
// Compile kernel to compute pipeline
|
||||
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
|
||||
auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);
|
||||
|
||||
mtl_function->release();
|
||||
mtl_linked_funcs->release();
|
||||
|
||||
// Add kernel to cache
|
||||
auto inserted = kernel_map_.insert({hash_name, kernel});
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
// Look for cached kernel
|
||||
const auto& kname = hash_name.empty() ? base_name : hash_name;
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
{
|
||||
// Multiple readers allowed
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
// Pull kernel from library
|
||||
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
|
||||
|
||||
// Compile kernel to compute pipeline
|
||||
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
|
||||
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
|
||||
|
||||
mtl_function->release();
|
||||
mtl_linked_funcs->release();
|
||||
|
||||
// Add kernel to cache
|
||||
kernel_map_.insert({kname, kernel});
|
||||
|
||||
return kernel;
|
||||
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
@@ -583,16 +534,19 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
// Look for cached kernel
|
||||
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
{
|
||||
// Multiple readers allowed
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_cache_(lib_name);
|
||||
|
||||
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
MTL::Library* mtl_lib = get_library_(lib_name);
|
||||
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device) {
|
||||
|
||||
Reference in New Issue
Block a user