Update Compute Pipeline Creation API (#581)

* Add option to specialize metal functions on function constants
* Update Compute Pipeline Creation API
* Add options to make libraries from source and stitching
* Update function specialization name options
This commit is contained in:
Jagrit Digani 2024-01-30 15:42:36 -08:00 committed by GitHub
parent 1895d34c20
commit 375446453e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 328 additions and 19 deletions

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-24 Apple Inc.
#include <dlfcn.h>
#include <cstdlib>
@ -242,37 +242,127 @@ void Device::register_library(
}
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& name,
const std::string& lib_name /* = "mlx" */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second;
}
// Prepare new kernel
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
// Search for cached metal lib
MTL::Library* mtl_lib;
if (auto it = library_map_.find(name); it != library_map_.end()) {
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name);
mtl_lib = library_map_[lib_name];
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const std::string& source_string) {
auto pool = new_scoped_memory_pool();
auto ns_code =
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build metal library from source"
<< "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
auto pool = new_scoped_memory_pool();
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(desc, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build stitched metal library"
<< "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Function* Device::get_function_(
const std::string& name,
MTL::Library* mtl_lib) {
// Pull kernel from library
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
auto mtl_function = mtl_lib->newFunction(ns_name);
return mtl_function;
}
MTL::Function* Device::get_function_(
const std::string& name,
const std::string& specialized_name,
const MTLFCList& func_consts,
MTL::Library* mtl_lib) {
if (func_consts.empty() && (specialized_name == name)) {
return get_function_(name, mtl_lib);
}
// Prepare function constants
auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();
for (auto [value, type, index] : func_consts) {
mtl_func_consts->setConstantValue(value, type, index);
}
// Prepare function desc
auto desc = MTL::FunctionDescriptor::functionDescriptor();
desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));
desc->setSpecializedName(
NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));
desc->setConstantValues(mtl_func_consts);
// Pull kernel from library
NS::Error* error = nullptr;
auto mtl_function = mtl_lib->newFunction(desc, &error);
// Throw error if unable to build metal function
if (!mtl_function) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load function " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
mtl_func_consts->release();
desc->release();
return mtl_function;
}
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& name,
const MTL::Function* mtl_function) {
// Compile kernel to compute pipeline
NS::Error* error = nullptr;
MTL::ComputePipelineState* kernel;
if (mtl_function) {
kernel = device_->newComputePipelineState(mtl_function, &error);
mtl_function->release();
}
// Throw error if unable to compile metal function
if (!mtl_function || !kernel) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
@ -282,11 +372,170 @@ MTL::ComputePipelineState* Device::get_kernel(
throw std::runtime_error(msg.str());
}
// Add kernel to cache
kernel_map_.insert({name, kernel});
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& name,
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions) {
// Check inputs
if (!linked_functions) {
return get_kernel_(name, mtl_function);
}
if (!mtl_function) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
throw std::runtime_error(msg.str());
}
// Prepare compute pipeline state descriptor
auto desc = MTL::ComputePipelineDescriptor::alloc()->init();
desc->setComputeFunction(mtl_function);
desc->setLinkedFunctions(linked_functions);
// Compile kernel to compute pipeline
NS::Error* error = nullptr;
auto kernel = device_->newComputePipelineState(
desc, MTL::PipelineOptionNone, nullptr, &error);
// Throw error if unable to compile metal function
if (!kernel) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return kernel;
}
MTL::Library* Device::get_library(
const std::string& name,
const std::string& source,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(source);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Library* Device::get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(desc);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Function* Device::get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
}
MTL::Function* Device::get_function(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_function(base_name, mtl_lib, specialized_name, func_consts);
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
return nullptr;
}
auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
std::vector<NS::Object*> objs(funcs.size());
for (int i = 0; i < funcs.size(); i++) {
objs[i] = funcs[i];
}
NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size());
lfuncs->setPrivateFunctions(funcs_arr);
return lfuncs;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
const auto& kname = hash_name.empty() ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
// Pull kernel from library
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Look for cached kernel
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
}
Device& device(mlx::core::Device) {
static Device metal_device;
return metal_device;

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-24 Apple Inc.
#pragma once
@ -31,6 +31,9 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
return mtllib_path;
}
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
class Device {
public:
Device();
@ -59,14 +62,71 @@ class Device {
const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path);
MTL::ComputePipelineState* get_kernel(
MTL::Library* get_library(
const std::string& name,
const std::string& lib_name = "mlx");
const std::string& source_string,
bool cache = true);
MTL::Library* get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache = true);
MTL::Function* get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::Function* get_function(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
private:
MTL::Library* get_library_cache_(const std::string& name);
MTL::Library* get_library_(const std::string& source_string);
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
MTL::Function* get_function_(
const std::string& name,
const std::string& specialized_name,
const MTLFCList& func_consts,
MTL::Library* mtl_lib);
MTL::LinkedFunctions* get_linked_functions_(
const std::vector<MTL::Function*>& funcs);
MTL::ComputePipelineState* get_kernel_(
const std::string& name,
const MTL::Function* mtl_function);
MTL::ComputePipelineState* get_kernel_(
const std::string& name,
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions);
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;