diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp
index e90658650..eb9e54f8c 100644
--- a/mlx/backend/metal/compiled.cpp
+++ b/mlx/backend/metal/compiled.cpp
@@ -202,10 +202,7 @@ void Compiled::eval_gpu(
   // Get the kernel if someone else built it already
   auto& s = stream();
   auto& d = metal::device(s.device);
-  auto lib = d.get_library(kernel_lib_);
-
-  // If not we have to build it ourselves
-  if (lib == nullptr) {
+  auto lib = d.get_library(kernel_lib_, [&]() {
     std::ostringstream kernel;
     kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
            << metal::ternary_ops();
@@ -252,9 +249,8 @@ void Compiled::eval_gpu(
         /* contiguous = */ false,
         /* ndim = */ 0,
         /* dynamic_dims = */ true);
-
-    lib = d.get_library(kernel_lib_, kernel.str());
-  }
+    return kernel.str();
+  });

   // Figure out which kernel we are using
   auto& output_shape = outputs[0].shape();
diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp
index d7e9b143f..68898e914 100644
--- a/mlx/backend/metal/custom_kernel.cpp
+++ b/mlx/backend/metal/custom_kernel.cpp
@@ -39,10 +39,8 @@ void CustomKernel::eval_gpu(
   auto& d = metal::device(s.device);

   const auto& lib_name = name_;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
-    lib = d.get_library(lib_name, metal::utils() + source_);
-  }
+  auto lib =
+      d.get_library(lib_name, [this] { return metal::utils() + source_; });
   auto kernel = d.get_kernel(name_, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 8dfcf81e5..3a565e902 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -219,7 +219,6 @@ void Device::new_queue(int index) {
   // Multiple threads can ask the device for queues
   // We lock this as a critical section for safety
-  const std::lock_guard<std::mutex> lock(mtx_);
   auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
   debug_set_stream_queue_label(q, index);
   if (!q) {
@@ -227,21 +226,21 @@ void Device::new_queue(int index) {
         "[metal::Device] Failed to make new command queue.");
   }
   queue_map_.insert({index, q});
+  buffer_map_.insert({index, {0, nullptr}});
+  encoder_map_.insert({index, nullptr});
 }

 int Device::get_command_buffer_ops(int index) {
-  auto bit = buffer_map_.find(index);
-  return bit->second.first;
+  return buffer_map_[index].first;
 }

 void Device::increment_command_buffer_ops(int index) {
-  auto bit = buffer_map_.find(index);
-  bit->second.first++;
+  buffer_map_[index].first++;
 }

 MTL::CommandBuffer* Device::get_command_buffer(int index) {
   auto bit = buffer_map_.find(index);
-  if (bit == buffer_map_.end()) {
+  if (bit->second.second == nullptr) {
     auto qit = queue_map_.find(index);
     if (qit == queue_map_.end()) {
       throw std::runtime_error(
@@ -258,7 +257,7 @@ MTL::CommandBuffer* Device::get_command_buffer(int index) {

     // Increment ref count so the buffer is not garbage collected
     cb->retain();
-    bit = buffer_map_.insert({index, {0, cb}}).first;
+    bit->second = {0, cb};
   }
   return bit->second.second;
 }
@@ -267,19 +266,18 @@ void Device::commit_command_buffer(int index) {
   auto bit = buffer_map_.find(index);
   bit->second.second->commit();
   bit->second.second->release();
-  buffer_map_.erase(bit);
+  bit->second = {0, nullptr};
 }

 void Device::end_encoding(int index) {
-  encoder_map_.erase(index);
+  encoder_map_[index] = nullptr;
 }

 CommandEncoder& Device::get_command_encoder(int index) {
   auto eit = encoder_map_.find(index);
-  if (eit == encoder_map_.end()) {
+  if (eit->second == nullptr) {
     auto cb = get_command_buffer(index);
-    eit =
-        encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
+    eit->second = std::make_unique<CommandEncoder>(cb);
   }
   return *(eit->second);
 }
@@ -293,20 +291,7 @@ void Device::register_library(
   }
 }

-MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
-  // Search for cached metal lib
-  MTL::Library* mtl_lib;
-  if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
-    mtl_lib = it->second;
-  } else { // Look for metallib alongside library
-    register_library(lib_name, get_colocated_mtllib_path(lib_name));
-    mtl_lib = library_map_[lib_name];
-  }
-
-  return mtl_lib;
-}
-
-MTL::Library* Device::get_library_(const std::string& source_string) {
+MTL::Library* Device::build_library_(const std::string& source_string) {
   auto pool = new_scoped_memory_pool();

   auto ns_code =
@@ -332,25 +317,6 @@
   return mtl_lib;
 }

-MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
-  auto pool = new_scoped_memory_pool();
-
-  NS::Error* error = nullptr;
-  auto mtl_lib = device_->newLibrary(desc, &error);
-
-  // Throw error if unable to compile library
-  if (!mtl_lib) {
-    std::ostringstream msg;
-    msg << "[metal::Device] Unable to build stitched metal library" << "\n";
-    if (error) {
-      msg << error->localizedDescription()->utf8String() << "\n";
-    }
-    throw std::runtime_error(msg.str());
-  }
-
-  return mtl_lib;
-}
-
 MTL::Function* Device::get_function_(
     const std::string& name,
     MTL::Library* mtl_lib) {
@@ -465,68 +431,32 @@ MTL::ComputePipelineState* Device::get_kernel_(
   return kernel;
 }

-MTL::Library* Device::get_library(const std::string& name) {
+MTL::Library* Device::get_library_(const std::string& name) {
+  std::shared_lock lock(library_mtx_);
   auto it = library_map_.find(name);
   return (it != library_map_.end()) ? it->second : nullptr;
 }

 MTL::Library* Device::get_library(
     const std::string& name,
-    const std::string& source,
-    bool cache /* = true */) {
-  if (cache) {
+    const std::function<std::string()>& builder) {
+  {
+    std::shared_lock rlock(library_mtx_);
     if (auto it = library_map_.find(name); it != library_map_.end()) {
       return it->second;
     }
   }

-  auto mtl_lib = get_library_(source);
-
-  if (cache) {
-    library_map_.insert({name, mtl_lib});
+  std::unique_lock wlock(library_mtx_);
+  if (auto it = library_map_.find(name); it != library_map_.end()) {
+    return it->second;
   }

+  auto mtl_lib = build_library_(builder());
+  library_map_.insert({name, mtl_lib});
   return mtl_lib;
 }

-MTL::Library* Device::get_library(
-    const std::string& name,
-    const MTL::StitchedLibraryDescriptor* desc,
-    bool cache /* = true */) {
-  if (cache) {
-    if (auto it = library_map_.find(name); it != library_map_.end()) {
-      return it->second;
-    }
-  }
-
-  auto mtl_lib = get_library_(desc);
-
-  if (cache) {
-    library_map_.insert({name, mtl_lib});
-  }
-
-  return mtl_lib;
-}
-
-MTL::Function* Device::get_function(
-    const std::string& base_name,
-    MTL::Library* mtl_lib,
-    const std::string& specialized_name /* = "" */,
-    const MTLFCList& func_consts /* = {} */) {
-  return get_function_(base_name, specialized_name, func_consts, mtl_lib);
-}
-
-MTL::Function* Device::get_function(
-    const std::string& base_name,
-    const std::string& lib_name /* = "mlx" */,
-    const std::string& specialized_name /* = "" */,
-    const MTLFCList& func_consts /* = {} */) {
-  // Search for cached metal lib
-  MTL::Library* mtl_lib = get_library_cache_(lib_name);
-
-  return get_function(base_name, mtl_lib, specialized_name, func_consts);
-}
-
 MTL::LinkedFunctions* Device::get_linked_functions_(
     const std::vector<MTL::Function*>& funcs) {
   if (funcs.empty()) {
@@ -547,34 +477,55 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
   return lfuncs;
 }

+MTL::ComputePipelineState* Device::get_kernel_(
+    const std::string& base_name,
+    MTL::Library* mtl_lib,
+    const std::string& hash_name,
+    const MTLFCList& func_consts /* = {} */,
+    const std::vector<MTL::Function*>& linked_functions /* = {} */) {
+  // Single writer allowed
+  std::unique_lock wlock(kernel_mtx_);
+
+  // Try loading again to avoid loading twice
+  if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
+    return it->second;
+  }
+
+  auto pool = new_scoped_memory_pool();
+
+  // Pull kernel from library
+  auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib);
+
+  // Compile kernel to compute pipeline
+  auto mtl_linked_funcs = get_linked_functions_(linked_functions);
+  auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);
+
+  mtl_function->release();
+  mtl_linked_funcs->release();
+
+  // Add kernel to cache
+  auto inserted = kernel_map_.insert({hash_name, kernel});
+
+  return kernel;
+}
+
 MTL::ComputePipelineState* Device::get_kernel(
     const std::string& base_name,
     MTL::Library* mtl_lib,
     const std::string& hash_name /* = "" */,
     const MTLFCList& func_consts /* = {} */,
     const std::vector<MTL::Function*>& linked_functions /* = {} */) {
-  auto pool = new_scoped_memory_pool();
-
-  // Look for cached kernel
   const auto& kname = hash_name.empty() ? base_name : hash_name;
-  if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
-    return it->second;
+  {
+    // Multiple readers allowed
+    std::shared_lock lock(kernel_mtx_);
+
+    // Look for cached kernel
+    if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
+      return it->second;
+    }
   }
-
-  // Pull kernel from library
-  auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
-
-  // Compile kernel to compute pipeline
-  auto mtl_linked_funcs = get_linked_functions_(linked_functions);
-  auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
-
-  mtl_function->release();
-  mtl_linked_funcs->release();
-
-  // Add kernel to cache
-  kernel_map_.insert({kname, kernel});
-
-  return kernel;
+  return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
 }

 MTL::ComputePipelineState* Device::get_kernel(
@@ -583,16 +534,19 @@ MTL::ComputePipelineState* Device::get_kernel(
     const std::string& hash_name /* = "" */,
     const MTLFCList& func_consts /* = {} */,
     const std::vector<MTL::Function*>& linked_functions /* = {} */) {
-  // Look for cached kernel
   const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
-  if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
-    return it->second;
+  {
+    // Multiple readers allowed
+    std::shared_lock lock(kernel_mtx_);
+
+    // Look for cached kernel
+    if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
+      return it->second;
+    }
   }

-  // Search for cached metal lib
-  MTL::Library* mtl_lib = get_library_cache_(lib_name);
-
-  return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
+  MTL::Library* mtl_lib = get_library_(lib_name);
+  return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
 }

 Device& device(mlx::core::Device) {
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index 2841e8103..7f851f929 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -7,6 +7,7 @@
 #include <filesystem>
 #include <functional>
 #include <mutex>
+#include <shared_mutex>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -114,29 +115,9 @@ class Device {
     }
   }

-  MTL::Library* get_library(const std::string& name);
-
   MTL::Library* get_library(
       const std::string& name,
-      const std::string& source_string,
-      bool cache = true);
-
-  MTL::Library* get_library(
-      const std::string& name,
-      const MTL::StitchedLibraryDescriptor* desc,
-      bool cache = true);
-
-  MTL::Function* get_function(
-      const std::string& base_name,
-      MTL::Library* mtl_lib,
-      const std::string& specialized_name = "",
-      const MTLFCList& func_consts = {});
-
-  MTL::Function* get_function(
-      const std::string& base_name,
-      const std::string& lib_name = "mlx",
-      const std::string& specialized_name = "",
-      const MTLFCList& func_consts = {});
+      const std::function<std::string()>& builder);

   MTL::ComputePipelineState* get_kernel(
       const std::string& base_name,
@@ -158,8 +139,8 @@ class Device {
  private:
   MTL::Library* get_library_cache_(const std::string& name);

-  MTL::Library* get_library_(const std::string& source_string);
-  MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
+  MTL::Library* get_library_(const std::string& name);
+  MTL::Library* build_library_(const std::string& source_string);

   MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);

@@ -181,13 +162,23 @@ class Device {
       const MTL::Function* mtl_function,
       const MTL::LinkedFunctions* linked_functions);

+  MTL::ComputePipelineState* get_kernel_(
+      const std::string& base_name,
+      MTL::Library* mtl_lib,
+      const std::string& hash_name,
+      const MTLFCList& func_consts = {},
+      const std::vector<MTL::Function*>& linked_functions = {});
+
   MTL::Device* device_;
   std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
   std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
   std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
+
+  std::shared_mutex kernel_mtx_;
   std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
+
+  std::shared_mutex library_mtx_;
   std::unordered_map<std::string, MTL::Library*> library_map_;
-  std::mutex mtx_;
 };

 Device& device(mlx::core::Device);
diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp
index b4ae377d5..dd89b415e 100644
--- a/mlx/backend/metal/hadamard.cpp
+++ b/mlx/backend/metal/hadamard.cpp
@@ -60,32 +60,6 @@ std::string gen_hadamard_codelet(int m) {
   return source.str();
 }

-void launch_hadamard(
-    const array& in,
-    array& out,
-    int batch_size,
-    int threads_per,
-    const std::string kernel_name,
-    float scale,
-    const Stream& s) {
-  auto& d = metal::device(s.device);
-
-  const auto& lib_name = kernel_name.substr(1);
-  auto lib = d.get_library(lib_name);
-  auto kernel = d.get_kernel(kernel_name, lib);
-  assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
-
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
-  compute_encoder.set_input_array(in, 0);
-  compute_encoder.set_output_array(out, 1);
-  compute_encoder->setBytes(&scale, sizeof(float), 2);
-
-  MTL::Size group_dims = MTL::Size(1, threads_per, 1);
-  MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
-}
-
 void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();

@@ -113,7 +87,8 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }

-  auto [n, m] = decompose_hadamard(in.shape(axis));
+  int n, m;
+  std::tie(n, m) = decompose_hadamard(in.shape(axis));

   if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
     throw std::invalid_argument(
@@ -129,8 +104,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto kernel_name = kname.str();
   auto& d = metal::device(s.device);
   const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     auto codelet = gen_hadamard_codelet(m);
     kernel_source << metal::utils() << codelet << metal::hadamard();
@@ -148,12 +122,31 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
         n,
         m,
         read_width);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });

   int batch_size = in.size() / n;
   int threads_per = n / max_radix;

+  auto& compute_encoder = d.get_command_encoder(s.index);
+
+  auto launch_hadamard = [&](const array& in,
+                             array& out,
+                             const std::string& kernel_name,
+                             float scale) {
+    auto kernel = d.get_kernel(kernel_name, lib);
+    assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
+
+    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(out, 1);
+    compute_encoder->setBytes(&scale, sizeof(float), 2);
+
+    MTL::Size group_dims = MTL::Size(1, threads_per, 1);
+    MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+  };
+
   if (m > 1) {
     // When m is greater than 1, we decompose the
     // computation into two uploads to the GPU:
@@ -171,29 +164,14 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
     temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
     copies.push_back(temp);

-    launch_hadamard(
-        in_contiguous,
-        temp,
-        batch_size,
-        threads_per,
-        "n" + kernel_name,
-        1.0,
-        s);
+    launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);

     // Metal sometimes reports 256 max threads per group for hadamard_m kernel
     threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
     batch_size = in.size() / m / read_width / threads_per;
-    launch_hadamard(
-        temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
+    launch_hadamard(temp, out, "m" + kernel_name, scale_);
   } else {
-    launch_hadamard(
-        in_contiguous,
-        out,
-        batch_size,
-        threads_per,
-        "n" + kernel_name,
-        scale_,
-        s);
+    launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
   }

   if (!copies.empty()) {
diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp
index 6288025e4..37c511b39 100644
--- a/mlx/backend/metal/indexing.cpp
+++ b/mlx/backend/metal/indexing.cpp
@@ -64,8 +64,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
     kernel_name = lib_name;
   }

-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::gather();
     std::string out_type_str = get_type_string(out.dtype());
@@ -83,8 +82,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
         idx_args,
         idx_arr,
         idx_ndim);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });

   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kernel_name, lib);
@@ -236,8 +235,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     kernel_name = kname.str();
   }

-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::reduce_utils()
                   << metal::scatter();
@@ -277,8 +275,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
         nidx,
         idx_args,
         idx_arr);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });

   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kernel_name, lib);
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index 37e301142..f9a998c5d 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -25,15 +25,15 @@ MTL::ComputePipelineState* get_arange_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     const array& out) {
-  const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(kernel_name, [&]() {
     std::ostringstream kernel_source;
-    kernel_source
-        << metal::utils() << metal::arange()
-        << fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    kernel_source << metal::utils() << metal::arange()
+                  << fmt::format(
+                         arange_kernels,
+                         kernel_name,
+                         get_type_string(out.dtype()));
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -43,8 +43,7 @@ MTL::ComputePipelineState* get_unary_kernel(
     Dtype out_type,
     const std::string op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
     kernel_source << get_template_definition(
@@ -55,8 +54,8 @@ MTL::ComputePipelineState* get_unary_kernel(
         "g_" + lib_name, "unary_g", get_type_string(out_type), op);
     kernel_source << get_template_definition(
         "gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -105,13 +104,12 @@ MTL::ComputePipelineState* get_binary_kernel(
     Dtype out_type,
     const std::string op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
     add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -122,14 +120,13 @@ MTL::ComputePipelineState* get_binary_two_kernel(
     Dtype out_type,
     const std::string op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::binary_ops()
                   << metal::binary_two();
     add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -139,8 +136,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
     Dtype type,
     const std::string op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
         {"v", "ternary_v"},
@@ -159,8 +155,8 @@ MTL::ComputePipelineState* get_ternary_kernel(
     }
     kernel_source << get_template_definition(
         "gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -170,8 +166,7 @@ MTL::ComputePipelineState* get_copy_kernel(
     const array& in,
     const array& out) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     auto in_type = get_type_string(in.dtype());
     auto out_type = get_type_string(out.dtype());
@@ -198,8 +193,8 @@ MTL::ComputePipelineState* get_copy_kernel(
            "gg_" + lib_name, "copy_gg", in_type, out_type)
        << get_template_definition(
            "ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -209,8 +204,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
     bool precise,
     const array& out) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&] {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::softmax()
                   << fmt::format(
                          softmax_kernels,
                          lib_name,
                          get_type_string(out.dtype()),
                          get_type_string(precise ? float32 : out.dtype()));
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -232,8 +226,7 @@ MTL::ComputePipelineState* get_scan_kernel(
     const array& in,
     const array& out) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::string op_name = "Cum" + reduce_type;
     op_name[3] = toupper(op_name[3]);
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::scan()
                   << fmt::format(
                          scan_kernels,
                          lib_name,
                          get_type_string(in.dtype()),
                          get_type_string(out.dtype()),
                          op_name,
                          inclusive,
                          reverse);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -259,8 +252,7 @@ MTL::ComputePipelineState* get_sort_kernel(
     int bn,
     int tn) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     auto in_type = get_type_string(in.dtype());
     auto out_type = get_type_string(out.dtype());
@@ -285,8 +277,8 @@ MTL::ComputePipelineState* get_sort_kernel(
           bn,
           tn);
     }
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -298,8 +290,7 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
     int bn,
     int tn) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::sort();
     std::array<std::pair<std::string, std::string>, 3> kernel_types = {
@@ -316,8 +307,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
           bn,
           tn);
     }
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -325,8 +316,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     const array& out) {
-  auto lib = d.get_library(kernel_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(kernel_name, [&]() {
     std::ostringstream kernel_source;
     std::string op_type = op_name(out);
     op_type[0] = std::toupper(op_name(out)[0]);
@@ -335,8 +325,8 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
     kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
     kernel_source << get_template_definition(
         kernel_name, "init_reduce", out_type, op);
-    lib = d.get_library(kernel_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -350,8 +340,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
     int ndim /* = -1 */,
     int bm /* = -1 */,
     int bn /* = -1 */) {
-  auto lib = d.get_library(kernel_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(kernel_name, [&]() {
     std::string op_type = op_name;
     op_type[0] = std::toupper(op_name[0]);
     std::ostringstream kernel_source;
@@ -369,8 +358,8 @@ MTL::ComputePipelineState* get_reduce_kernel(
       kernel_source << get_template_definition(
           kernel_name, func_name, in_type, out_type, op);
     }
-    lib = d.get_library(kernel_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   auto st = d.get_kernel(kernel_name, lib);
   return st;
 }

@@ -389,8 +378,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
     int wm,
     int wn) {
   const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::gemm()
                   << metal::steel_gemm_fused()
@@ -405,8 +393,8 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
                          "wn"_a = wn,
                          "trans_a"_a = transpose_a,
                          "trans_b"_a = transpose_b);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib, hash_name, func_consts);
 }

@@ -425,8 +413,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
     bool mn_aligned,
     bool k_aligned) {
   const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::gemm()
                   << metal::steel_gemm_splitk()
@@ -444,8 +431,8 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
                          "trans_b"_a = transpose_b,
                          "mn_aligned"_a = mn_aligned,
                          "k_aligned"_a = k_aligned);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -456,8 +443,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
     const array& out,
     bool axbpy) {
   const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::gemm()
                   << metal::steel_gemm_splitk()
@@ -467,8 +453,8 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
                          "name"_a = lib_name,
                          "atype"_a = get_type_string(in.dtype()),
                          "otype"_a = get_type_string(out.dtype()));
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -488,8 +474,7 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
     bool mn_aligned,
     bool k_aligned) {
   const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     auto out_mask_type = mask_out.has_value()
         ? get_type_string((*mask_out).dtype())
@@ -513,8 +498,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
                          "trans_b"_a = transpose_b,
                          "mn_aligned"_a = mn_aligned,
                          "k_aligned"_a = k_aligned);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -533,8 +518,7 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
     int tn,
     bool contiguous) {
   const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     auto out_mask_type = mask_out.has_value()
        ? get_type_string((*mask_out).dtype())
@@ -556,8 +540,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
                          "tn"_a = tn,
                          "trans"_a = transpose_mat ? "t_" : "",
                          "nc"_a = contiguous ? "0" : "1");
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -573,8 +557,7 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
     int n_channel_specialization,
     bool small_filter) {
   const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
                   << fmt::format(
@@ -588,8 +571,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
                          "wn"_a = wn,
                          "n_channels"_a = n_channel_specialization,
                          "small_filter"_a = small_filter);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -603,8 +586,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
     int wm,
     int wn) {
   const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::conv()
                   << metal::steel_conv_general()
@@ -617,8 +599,8 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
                          "bk"_a = bk,
                          "wm"_a = wm,
                          "wn"_a = wn);
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }

@@ -629,13 +611,12 @@ MTL::ComputePipelineState* get_fft_kernel(
     const metal::MTLFCList& func_consts,
     const std::string& template_def) {
   const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     std::string kernel_string;
     kernel_source << metal::fft() << template_def;
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib, hash_name, func_consts);
 }

@@ -644,13 +625,12 @@ MTL::ComputePipelineState* get_quantized_kernel(
     const std::string& kernel_name,
     const std::string& template_def) {
   const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name);
-  if (lib == nullptr) {
+  auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::gemm() << metal::quantized()
                   << template_def;
-    lib = d.get_library(lib_name, kernel_source.str());
-  }
+    return kernel_source.str();
+  });
   return d.get_kernel(kernel_name, lib);
 }
diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py
index 81b51209a..5856d84fd 100644
--- a/python/tests/test_eval.py
+++ b/python/tests/test_eval.py
@@ -122,6 +122,21 @@ class TestEval(mlx_tests.MLXTestCase):
             out = mx.vjp(fn, (x,), (y,))
             self.assertEqual(peak_mem, mx.metal.get_peak_memory())

+    def test_async_eval_with_multiple_streams(self):
+        x = mx.array([1.0])
+        y = mx.array([1.0])
+        a = mx.array([1.0])
+        b = mx.array([1.0])
+
+        d = mx.default_device()
+        s2 = mx.new_stream(d)
+
+        for _ in range(50):
+            for _ in range(20):
+                x = x + y
+            mx.async_eval(x)
+            mx.eval(a + b)
+

 if __name__ == "__main__":
     unittest.main()