From 375446453e6ffc6997df673e42874dd23c00a2b4 Mon Sep 17 00:00:00 2001
From: Jagrit Digani
Date: Tue, 30 Jan 2024 15:42:36 -0800
Subject: [PATCH] Update Compute Pipeline Creation API (#581)

* Add option to specialize metal functions on function constants
* Update Compute Pipeline Creation API
* Add options to make libraries from source and stitching
* Update function specialization name options
---
Review notes (text between "---" and the diffstat is ignored by `git am`):
- Objects returned by metal-cpp convenience constructors are autoreleased, so
  the explicit release() calls on the function descriptor and on the linked
  functions object were dropped; the latter could also be null.
- The function constants are released before the error check instead of
  leaking on throw, and the alloc/init'ed compute pipeline descriptor is now
  released.
- get_function() treats an empty specialized name as "use the base name",
  matching get_kernel(); previously the default argument always took the
  specialization path.
- Error messages for library builds read "Unable to build ...".
- Hunk headers and the diffstat were updated for the added lines; the index
  hashes below are those of the original patch.

 mlx/backend/metal/device.cpp | 320 +++++++++++++++++++++++++++++++++--
 mlx/backend/metal/device.h   |  75 +++++++-
 2 files changed, 376 insertions(+), 19 deletions(-)

diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 060197343..999c67084 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-24 Apple Inc.
 
 #include <dlfcn.h>
 #include <cstdlib>
@@ -242,37 +242,144 @@ void Device::register_library(
   }
 }
 
-MTL::ComputePipelineState* Device::get_kernel(
-    const std::string& name,
-    const std::string& lib_name /* = "mlx" */) {
-  auto pool = new_scoped_memory_pool();
-  // Look for cached kernel
-  if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
-    return it->second;
-  }
-
-  // Prepare new kernel
-
+// Return the library cached under `lib_name`. On a cache miss the library is
+// first registered through register_library(lib_name).
+MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
   // Search for cached metal lib
   MTL::Library* mtl_lib;
-  if (auto it = library_map_.find(name); it != library_map_.end()) {
+  if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
     mtl_lib = it->second;
   } else { // Look for metallib alongside library
     register_library(lib_name);
     mtl_lib = library_map_[lib_name];
   }
 
+  return mtl_lib;
+}
+
+// Compile `source_string` into a new Metal library. Throws
+// std::runtime_error (with Metal's error description when one is reported)
+// if the library cannot be built.
+MTL::Library* Device::get_library_(const std::string& source_string) {
+  auto pool = new_scoped_memory_pool();
+
+  auto ns_code =
+      NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
+
+  NS::Error* error = nullptr;
+  auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
+
+  // Throw error if unable to compile library
+  if (!mtl_lib) {
+    std::ostringstream msg;
+    msg << "[metal::Device] Unable to build metal library from source"
+        << "\n";
+    if (error) {
+      msg << error->localizedDescription()->utf8String() << "\n";
+    }
+    throw std::runtime_error(msg.str());
+  }
+
+  return mtl_lib;
+}
+
+// Build a new Metal library from the stitched library descriptor `desc`.
+// Throws std::runtime_error if the library cannot be built.
+MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
+  auto pool = new_scoped_memory_pool();
+
+  NS::Error* error = nullptr;
+  auto mtl_lib = device_->newLibrary(desc, &error);
+
+  // Throw error if unable to compile library
+  if (!mtl_lib) {
+    std::ostringstream msg;
+    msg << "[metal::Device] Unable to build stitched metal library"
+        << "\n";
+    if (error) {
+      msg << error->localizedDescription()->utf8String() << "\n";
+    }
+    throw std::runtime_error(msg.str());
+  }
+
+  return mtl_lib;
+}
+
+// Look up the function `name` in `mtl_lib` without any specialization. The
+// result is returned unchecked (the get_kernel_ overloads reject null).
+MTL::Function* Device::get_function_(
+    const std::string& name,
+    MTL::Library* mtl_lib) {
   // Pull kernel from library
   auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
   auto mtl_function = mtl_lib->newFunction(ns_name);
 
+  return mtl_function;
+}
+
+// Look up `name` in `mtl_lib`, specialized with `func_consts` and exposed
+// under `specialized_name`. With no constants and an unchanged name this is
+// the plain lookup. Throws std::runtime_error if the function cannot be built.
+MTL::Function* Device::get_function_(
+    const std::string& name,
+    const std::string& specialized_name,
+    const MTLFCList& func_consts,
+    MTL::Library* mtl_lib) {
+  if (func_consts.empty() && (specialized_name == name)) {
+    return get_function_(name, mtl_lib);
+  }
+
+  // Prepare function constants
+  auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();
+
+  for (const auto& [value, type, index] : func_consts) {
+    mtl_func_consts->setConstantValue(value, type, index);
+  }
+
+  // Prepare function desc
+  auto desc = MTL::FunctionDescriptor::functionDescriptor();
+  desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));
+  desc->setSpecializedName(
+      NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));
+  desc->setConstantValues(mtl_func_consts);
+
+  // Pull kernel from library
+  NS::Error* error = nullptr;
+  auto mtl_function = mtl_lib->newFunction(desc, &error);
+
+  // We own the constants (alloc/init), so release them before any throw.
+  // `desc` comes from a convenience constructor and is autoreleased; it must
+  // not be released here.
+  mtl_func_consts->release();
+
+  // Throw error if unable to build metal function
+  if (!mtl_function) {
+    std::ostringstream msg;
+    msg << "[metal::Device] Unable to load function " << name << "\n";
+    if (error) {
+      msg << error->localizedDescription()->utf8String() << "\n";
+    }
+    throw std::runtime_error(msg.str());
+  }
+
+  return mtl_function;
+}
+
+// Compile `mtl_function` into a compute pipeline state. Throws
+// std::runtime_error mentioning `name` if the function is null or the
+// pipeline cannot be created.
+MTL::ComputePipelineState* Device::get_kernel_(
+    const std::string& name,
+    const MTL::Function* mtl_function) {
   // Compile kernel to compute pipeline
   NS::Error* error = nullptr;
   MTL::ComputePipelineState* kernel;
+
   if (mtl_function) {
     kernel = device_->newComputePipelineState(mtl_function, &error);
-    mtl_function->release();
   }
+
+  // Throw error if unable to compile metal function
   if (!mtl_function || !kernel) {
     std::ostringstream msg;
     msg << "[metal::Device] Unable to load kernel " << name << "\n";
@@ -282,11 +389,192 @@ MTL::ComputePipelineState* Device::get_kernel(
     throw std::runtime_error(msg.str());
   }
 
-  // Add kernel to cache
-  kernel_map_.insert({name, kernel});
   return kernel;
 }
 
+// Compile `mtl_function` into a compute pipeline state with
+// `linked_functions` attached. Falls back to the plain overload when
+// `linked_functions` is null. Throws std::runtime_error on failure.
+MTL::ComputePipelineState* Device::get_kernel_(
+    const std::string& name,
+    const MTL::Function* mtl_function,
+    const MTL::LinkedFunctions* linked_functions) {
+  // Check inputs
+  if (!linked_functions) {
+    return get_kernel_(name, mtl_function);
+  }
+
+  if (!mtl_function) {
+    std::ostringstream msg;
+    msg << "[metal::Device] Unable to load kernel " << name << "\n";
+    throw std::runtime_error(msg.str());
+  }
+
+  // Prepare compute pipeline state descriptor
+  auto desc = MTL::ComputePipelineDescriptor::alloc()->init();
+  desc->setComputeFunction(mtl_function);
+  desc->setLinkedFunctions(linked_functions);
+
+  // Compile kernel to compute pipeline
+  NS::Error* error = nullptr;
+  auto kernel = device_->newComputePipelineState(
+      desc, MTL::PipelineOptionNone, nullptr, &error);
+
+  // The descriptor was created with alloc/init and is no longer needed
+  desc->release();
+
+  // Throw error if unable to compile metal function
+  if (!kernel) {
+    std::ostringstream msg;
+    msg << "[metal::Device] Unable to load kernel " << name << "\n";
+    if (error) {
+      msg << error->localizedDescription()->utf8String() << "\n";
+    }
+    throw std::runtime_error(msg.str());
+  }
+
+  return kernel;
+}
+
+// Return the library built from `source`. When `cache` is true the library
+// is looked up in, and stored into, the library cache under `name`.
+MTL::Library* Device::get_library(
+    const std::string& name,
+    const std::string& source,
+    bool cache /* = true */) {
+  if (cache) {
+    if (auto it = library_map_.find(name); it != library_map_.end()) {
+      return it->second;
+    }
+  }
+
+  auto mtl_lib = get_library_(source);
+
+  if (cache) {
+    library_map_.insert({name, mtl_lib});
+  }
+
+  return mtl_lib;
+}
+
+// Return the library built from the stitched descriptor `desc`, using the
+// library cache under `name` when `cache` is true.
+MTL::Library* Device::get_library(
+    const std::string& name,
+    const MTL::StitchedLibraryDescriptor* desc,
+    bool cache /* = true */) {
+  if (cache) {
+    if (auto it = library_map_.find(name); it != library_map_.end()) {
+      return it->second;
+    }
+  }
+
+  auto mtl_lib = get_library_(desc);
+
+  if (cache) {
+    library_map_.insert({name, mtl_lib});
+  }
+
+  return mtl_lib;
+}
+
+// Get `base_name` from `mtl_lib`, specialized with `func_consts`. An empty
+// `specialized_name` means "keep the base name", as get_kernel does.
+MTL::Function* Device::get_function(
+    const std::string& base_name,
+    MTL::Library* mtl_lib,
+    const std::string& specialized_name /* = "" */,
+    const MTLFCList& func_consts /* = {} */) {
+  const auto& sname = specialized_name.empty() ? base_name : specialized_name;
+  return get_function_(base_name, sname, func_consts, mtl_lib);
+}
+
+// Get `base_name` from the library registered as `lib_name`; the remaining
+// arguments are forwarded to the library-pointer overload.
+MTL::Function* Device::get_function(
+    const std::string& base_name,
+    const std::string& lib_name /* = "mlx" */,
+    const std::string& specialized_name /* = "" */,
+    const MTLFCList& func_consts /* = {} */) {
+  // Search for cached metal lib
+  MTL::Library* mtl_lib = get_library_cache_(lib_name);
+
+  return get_function(base_name, mtl_lib, specialized_name, func_consts);
+}
+
+// Wrap `funcs` in an MTL::LinkedFunctions object as its private functions.
+// Returns nullptr when `funcs` is empty; the result is autoreleased.
+MTL::LinkedFunctions* Device::get_linked_functions_(
+    const std::vector<MTL::Function*>& funcs) {
+  if (funcs.empty()) {
+    return nullptr;
+  }
+
+  auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
+
+  std::vector<NS::Object*> objs(funcs.size());
+  for (size_t i = 0; i < funcs.size(); i++) {
+    objs[i] = funcs[i];
+  }
+
+  NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size());
+
+  lfuncs->setPrivateFunctions(funcs_arr);
+
+  return lfuncs;
+}
+
+// Return the pipeline state for `base_name` from `mtl_lib`, building and
+// caching it on first use. The cache key (and specialized function name) is
+// `hash_name`, or `base_name` when `hash_name` is empty.
+MTL::ComputePipelineState* Device::get_kernel(
+    const std::string& base_name,
+    MTL::Library* mtl_lib,
+    const std::string& hash_name /* = "" */,
+    const MTLFCList& func_consts /* = {} */,
+    const std::vector<MTL::Function*>& linked_functions /* = {} */) {
+  auto pool = new_scoped_memory_pool();
+
+  // Look for cached kernel
+  const auto& kname = hash_name.empty() ? base_name : hash_name;
+  if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
+    return it->second;
+  }
+
+  // Pull kernel from library
+  auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
+
+  // Compile kernel to compute pipeline
+  auto mtl_linked_funcs = get_linked_functions_(linked_functions);
+  auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
+  mtl_function->release();
+  // mtl_linked_funcs may be null and is autoreleased: the pool above owns it
+
+  // Add kernel to cache
+  kernel_map_.insert({kname, kernel});
+  return kernel;
+}
+
+// Same lookup keyed on `hash_name`/`base_name`, with the library resolved
+// from `lib_name` through the library cache on a kernel-cache miss.
+MTL::ComputePipelineState* Device::get_kernel(
+    const std::string& base_name,
+    const std::string& lib_name /* = "mlx" */,
+    const std::string& hash_name /* = "" */,
+    const MTLFCList& func_consts /* = {} */,
+    const std::vector<MTL::Function*>& linked_functions /* = {} */) {
+  // Look for cached kernel
+  const auto& kname = hash_name.empty() ? base_name : hash_name;
+  if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
+    return it->second;
+  }
+
+  // Search for cached metal lib
+  MTL::Library* mtl_lib = get_library_cache_(lib_name);
+
+  return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
+}
+
 Device& device(mlx::core::Device) {
   static Device metal_device;
   return metal_device;
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index 45449a332..6acfe9332 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-24 Apple Inc.
 
 #pragma once
 
@@ -31,6 +31,11 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
   return mtllib_path;
 }
 
+// List of function constants as (value pointer, data type, index) tuples,
+// passed in that order to MTL::FunctionConstantValues::setConstantValue.
+using MTLFCList =
+    std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
+
 class Device {
  public:
   Device();
@@ -59,14 +64,78 @@ class Device {
       const std::function<std::string(const std::string&)>& lib_path_func =
           get_colocated_mtllib_path);
 
-  MTL::ComputePipelineState* get_kernel(
+  // Library compiled from source, cached under `name` when `cache` is true
+  MTL::Library* get_library(
       const std::string& name,
-      const std::string& lib_name = "mlx");
+      const std::string& source_string,
+      bool cache = true);
+
+  // Library built from a stitched descriptor, cached like the overload above
+  MTL::Library* get_library(
+      const std::string& name,
+      const MTL::StitchedLibraryDescriptor* desc,
+      bool cache = true);
+
+  // Function `base_name` from `mtl_lib`, optionally specialized on constants
+  MTL::Function* get_function(
+      const std::string& base_name,
+      MTL::Library* mtl_lib,
+      const std::string& specialized_name = "",
+      const MTLFCList& func_consts = {});
+
+  // Function `base_name` from the library registered as `lib_name`
+  MTL::Function* get_function(
+      const std::string& base_name,
+      const std::string& lib_name = "mlx",
+      const std::string& specialized_name = "",
+      const MTLFCList& func_consts = {});
+
+  // Cached pipeline state, keyed on `hash_name` (or `base_name` if empty)
+  MTL::ComputePipelineState* get_kernel(
+      const std::string& base_name,
+      MTL::Library* mtl_lib,
+      const std::string& hash_name = "",
+      const MTLFCList& func_consts = {},
+      const std::vector<MTL::Function*>& linked_functions = {});
+
+  // Cached pipeline state from the library registered as `lib_name`
+  MTL::ComputePipelineState* get_kernel(
+      const std::string& base_name,
+      const std::string& lib_name = "mlx",
+      const std::string& hash_name = "",
+      const MTLFCList& func_consts = {},
+      const std::vector<MTL::Function*>& linked_functions = {});
 
   MTL::ArgumentEncoder* argument_encoder(
       const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
 
  private:
+  // Implementation helpers for the public getters above
+  MTL::Library* get_library_cache_(const std::string& name);
+
+  MTL::Library* get_library_(const std::string& source_string);
+  MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
+
+  MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
+
+  MTL::Function* get_function_(
+      const std::string& name,
+      const std::string& specialized_name,
+      const MTLFCList& func_consts,
+      MTL::Library* mtl_lib);
+
+  MTL::LinkedFunctions* get_linked_functions_(
+      const std::vector<MTL::Function*>& funcs);
+
+  MTL::ComputePipelineState* get_kernel_(
+      const std::string& name,
+      const MTL::Function* mtl_function);
+
+  MTL::ComputePipelineState* get_kernel_(
+      const std::string& name,
+      const MTL::Function* mtl_function,
+      const MTL::LinkedFunctions* linked_functions);
+
   MTL::Device* device_;
   std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
   std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;