mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
Update Compute Pipeline Creation API (#581)
* Add option to specialize metal functions on function constants * Update Compute Pipeline Creation API * Add options to make libraries from source and stitching * Update function specialization name options
This commit is contained in:
parent
1895d34c20
commit
375446453e
@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <cstdlib>
|
||||
@ -242,37 +242,127 @@ void Device::register_library(
|
||||
}
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& name,
|
||||
const std::string& lib_name /* = "mlx" */) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Prepare new kernel
|
||||
|
||||
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib;
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
|
||||
mtl_lib = it->second;
|
||||
} else { // Look for metallib alongside library
|
||||
register_library(lib_name);
|
||||
mtl_lib = library_map_[lib_name];
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_(const std::string& source_string) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
auto ns_code =
|
||||
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
|
||||
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
|
||||
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load build metal library from source"
|
||||
<< "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_lib = device_->newLibrary(desc, &error);
|
||||
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load build stitched metal library"
|
||||
<< "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function_(
|
||||
const std::string& name,
|
||||
MTL::Library* mtl_lib) {
|
||||
// Pull kernel from library
|
||||
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
|
||||
auto mtl_function = mtl_lib->newFunction(ns_name);
|
||||
|
||||
return mtl_function;
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function_(
|
||||
const std::string& name,
|
||||
const std::string& specialized_name,
|
||||
const MTLFCList& func_consts,
|
||||
MTL::Library* mtl_lib) {
|
||||
if (func_consts.empty() && (specialized_name == name)) {
|
||||
return get_function_(name, mtl_lib);
|
||||
}
|
||||
|
||||
// Prepare function constants
|
||||
auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();
|
||||
|
||||
for (auto [value, type, index] : func_consts) {
|
||||
mtl_func_consts->setConstantValue(value, type, index);
|
||||
}
|
||||
|
||||
// Prepare function desc
|
||||
auto desc = MTL::FunctionDescriptor::functionDescriptor();
|
||||
desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));
|
||||
desc->setSpecializedName(
|
||||
NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));
|
||||
desc->setConstantValues(mtl_func_consts);
|
||||
|
||||
// Pull kernel from library
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_function = mtl_lib->newFunction(desc, &error);
|
||||
|
||||
// Throw error if unable to build metal function
|
||||
if (!mtl_function) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load function " << name << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
mtl_func_consts->release();
|
||||
desc->release();
|
||||
|
||||
return mtl_function;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel_(
|
||||
const std::string& name,
|
||||
const MTL::Function* mtl_function) {
|
||||
// Compile kernel to compute pipeline
|
||||
NS::Error* error = nullptr;
|
||||
MTL::ComputePipelineState* kernel;
|
||||
|
||||
if (mtl_function) {
|
||||
kernel = device_->newComputePipelineState(mtl_function, &error);
|
||||
mtl_function->release();
|
||||
}
|
||||
|
||||
// Throw error if unable to compile metal function
|
||||
if (!mtl_function || !kernel) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load kernel " << name << "\n";
|
||||
@ -282,11 +372,170 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Add kernel to cache
|
||||
kernel_map_.insert({name, kernel});
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel_(
|
||||
const std::string& name,
|
||||
const MTL::Function* mtl_function,
|
||||
const MTL::LinkedFunctions* linked_functions) {
|
||||
// Check inputs
|
||||
if (!linked_functions) {
|
||||
return get_kernel_(name, mtl_function);
|
||||
}
|
||||
|
||||
if (!mtl_function) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load kernel " << name << "\n";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Prepare compute pipeline state descriptor
|
||||
auto desc = MTL::ComputePipelineDescriptor::alloc()->init();
|
||||
desc->setComputeFunction(mtl_function);
|
||||
desc->setLinkedFunctions(linked_functions);
|
||||
|
||||
// Compile kernel to compute pipeline
|
||||
NS::Error* error = nullptr;
|
||||
auto kernel = device_->newComputePipelineState(
|
||||
desc, MTL::PipelineOptionNone, nullptr, &error);
|
||||
|
||||
// Throw error if unable to compile metal function
|
||||
if (!kernel) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load kernel " << name << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const std::string& source,
|
||||
bool cache /* = true */) {
|
||||
if (cache) {
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
auto mtl_lib = get_library_(source);
|
||||
|
||||
if (cache) {
|
||||
library_map_.insert({name, mtl_lib});
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const MTL::StitchedLibraryDescriptor* desc,
|
||||
bool cache /* = true */) {
|
||||
if (cache) {
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
auto mtl_lib = get_library_(desc);
|
||||
|
||||
if (cache) {
|
||||
library_map_.insert({name, mtl_lib});
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& specialized_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */) {
|
||||
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name /* = "mlx" */,
|
||||
const std::string& specialized_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */) {
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_cache_(lib_name);
|
||||
|
||||
return get_function(base_name, mtl_lib, specialized_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
const std::vector<MTL::Function*>& funcs) {
|
||||
if (funcs.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
|
||||
|
||||
std::vector<NS::Object*> objs(funcs.size());
|
||||
for (int i = 0; i < funcs.size(); i++) {
|
||||
objs[i] = funcs[i];
|
||||
}
|
||||
|
||||
NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size());
|
||||
|
||||
lfuncs->setPrivateFunctions(funcs_arr);
|
||||
|
||||
return lfuncs;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
// Look for cached kernel
|
||||
const auto& kname = hash_name.empty() ? base_name : hash_name;
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Pull kernel from library
|
||||
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
|
||||
|
||||
// Compile kernel to compute pipeline
|
||||
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
|
||||
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
|
||||
mtl_function->release();
|
||||
mtl_linked_funcs->release();
|
||||
|
||||
// Add kernel to cache
|
||||
kernel_map_.insert({kname, kernel});
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name /* = "mlx" */,
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
// Look for cached kernel
|
||||
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_cache_(lib_name);
|
||||
|
||||
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device) {
|
||||
static Device metal_device;
|
||||
return metal_device;
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@ -31,6 +31,9 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
return mtllib_path;
|
||||
}
|
||||
|
||||
using MTLFCList =
|
||||
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||
|
||||
class Device {
|
||||
public:
|
||||
Device();
|
||||
@ -59,14 +62,71 @@ class Device {
|
||||
const std::function<std::string(const std::string&)>& lib_path_func =
|
||||
get_colocated_mtllib_path);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::string& lib_name = "mlx");
|
||||
const std::string& source_string,
|
||||
bool cache = true);
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const MTL::StitchedLibraryDescriptor* desc,
|
||||
bool cache = true);
|
||||
|
||||
MTL::Function* get_function(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& specialized_name = "",
|
||||
const MTLFCList& func_consts = {});
|
||||
|
||||
MTL::Function* get_function(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name = "mlx",
|
||||
const std::string& specialized_name = "",
|
||||
const MTLFCList& func_consts = {});
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& hash_name = "",
|
||||
const MTLFCList& func_consts = {},
|
||||
const std::vector<MTL::Function*>& linked_functions = {});
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name = "mlx",
|
||||
const std::string& hash_name = "",
|
||||
const MTLFCList& func_consts = {},
|
||||
const std::vector<MTL::Function*>& linked_functions = {});
|
||||
|
||||
MTL::ArgumentEncoder* argument_encoder(
|
||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||
|
||||
private:
|
||||
MTL::Library* get_library_cache_(const std::string& name);
|
||||
|
||||
MTL::Library* get_library_(const std::string& source_string);
|
||||
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
|
||||
|
||||
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
|
||||
|
||||
MTL::Function* get_function_(
|
||||
const std::string& name,
|
||||
const std::string& specialized_name,
|
||||
const MTLFCList& func_consts,
|
||||
MTL::Library* mtl_lib);
|
||||
|
||||
MTL::LinkedFunctions* get_linked_functions_(
|
||||
const std::vector<MTL::Function*>& funcs);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel_(
|
||||
const std::string& name,
|
||||
const MTL::Function* mtl_function);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel_(
|
||||
const std::string& name,
|
||||
const MTL::Function* mtl_function,
|
||||
const MTL::LinkedFunctions* linked_functions);
|
||||
|
||||
MTL::Device* device_;
|
||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||
|
Loading…
Reference in New Issue
Block a user