mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
Update Compute Pipeline Creation API (#581)
* Add option to specialize metal functions on function constants
* Update Compute Pipeline Creation API
* Add options to make libraries from source and stitching
* Update function specialization name options
This commit is contained in:
parent
1895d34c20
commit
375446453e
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-24 Apple Inc.
|
||||||
|
|
||||||
#include <dlfcn.h>
|
#include <dlfcn.h>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
@ -242,37 +242,127 @@ void Device::register_library(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::ComputePipelineState* Device::get_kernel(
|
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
|
||||||
const std::string& name,
|
|
||||||
const std::string& lib_name /* = "mlx" */) {
|
|
||||||
auto pool = new_scoped_memory_pool();
|
|
||||||
// Look for cached kernel
|
|
||||||
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare new kernel
|
|
||||||
|
|
||||||
// Search for cached metal lib
|
// Search for cached metal lib
|
||||||
MTL::Library* mtl_lib;
|
MTL::Library* mtl_lib;
|
||||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
|
||||||
mtl_lib = it->second;
|
mtl_lib = it->second;
|
||||||
} else { // Look for metallib alongside library
|
} else { // Look for metallib alongside library
|
||||||
register_library(lib_name);
|
register_library(lib_name);
|
||||||
mtl_lib = library_map_[lib_name];
|
mtl_lib = library_map_[lib_name];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return mtl_lib;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile a Metal library from source text.
//
// - source_string: Metal source code handed to the device compiler
//
// Returns the library built by the device; this function does not touch
// library_map_. Throws std::runtime_error when the device returns no library,
// appending the Metal error description when one is reported.
MTL::Library* Device::get_library_(const std::string& source_string) {
  auto pool = new_scoped_memory_pool();

  // Decode as UTF-8 instead of strict 7-bit ASCII: ASCII-only sources decode
  // identically, while sources carrying non-ASCII bytes (for example a "©" in
  // a header comment) no longer risk failing the string conversion.
  auto ns_code =
      NS::String::string(source_string.c_str(), NS::UTF8StringEncoding);

  NS::Error* error = nullptr;
  auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);

  // Throw error if unable to compile library
  if (!mtl_lib) {
    std::ostringstream msg;
    msg << "[metal::Device] Unable to load build metal library from source"
        << "\n";
    if (error) {
      msg << error->localizedDescription()->utf8String() << "\n";
    }
    throw std::runtime_error(msg.str());
  }

  return mtl_lib;
}
|
||||||
|
|
||||||
|
// Build a Metal library from a stitched-library descriptor.
//
// Returns whatever library the device produces for `desc`. If the device
// returns no library, a std::runtime_error is thrown whose message includes
// the Metal error description when one was reported.
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
  auto pool = new_scoped_memory_pool();

  NS::Error* build_error = nullptr;
  MTL::Library* stitched_lib = device_->newLibrary(desc, &build_error);

  // Success: hand the library straight back
  if (stitched_lib) {
    return stitched_lib;
  }

  // Failure: assemble a message and report it
  std::ostringstream what;
  what << "[metal::Device] Unable to load build stitched metal library"
       << "\n";
  if (build_error) {
    what << build_error->localizedDescription()->utf8String() << "\n";
  }
  throw std::runtime_error(what.str());
}
|
||||||
|
|
||||||
|
MTL::Function* Device::get_function_(
|
||||||
|
const std::string& name,
|
||||||
|
MTL::Library* mtl_lib) {
|
||||||
// Pull kernel from library
|
// Pull kernel from library
|
||||||
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
|
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
|
||||||
auto mtl_function = mtl_lib->newFunction(ns_name);
|
auto mtl_function = mtl_lib->newFunction(ns_name);
|
||||||
|
|
||||||
|
return mtl_function;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch a function from `mtl_lib`, specialized on function constants.
//
// - name: name of the function inside the library
// - specialized_name: name given to the specialized function
// - func_consts: (value pointer, data type, index) triples applied as
//   function constant values
// - mtl_lib: library to pull the function from
//
// When there are no constants and `specialized_name` equals `name`, this
// defers to the plain two-argument get_function_. Otherwise it throws
// std::runtime_error if the library cannot produce the function, appending
// the Metal error description when one is reported.
MTL::Function* Device::get_function_(
    const std::string& name,
    const std::string& specialized_name,
    const MTLFCList& func_consts,
    MTL::Library* mtl_lib) {
  if (func_consts.empty() && (specialized_name == name)) {
    return get_function_(name, mtl_lib);
  }

  // Prepare function constants
  auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();

  for (const auto& [value, type, index] : func_consts) {
    mtl_func_consts->setConstantValue(value, type, index);
  }

  // Prepare function desc
  // functionDescriptor() is a convenience constructor (not alloc/new/copy),
  // so under metal-cpp's ownership convention the descriptor is autoreleased
  // and must NOT be released by hand here.
  auto desc = MTL::FunctionDescriptor::functionDescriptor();
  desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));
  desc->setSpecializedName(
      NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));
  desc->setConstantValues(mtl_func_consts);

  // Pull kernel from library
  NS::Error* error = nullptr;
  auto mtl_function = mtl_lib->newFunction(desc, &error);

  // The constant values came from alloc/init, so drop our reference on both
  // the success and the failure path (previously leaked when throwing).
  mtl_func_consts->release();

  // Throw error if unable to build metal function
  if (!mtl_function) {
    std::ostringstream msg;
    msg << "[metal::Device] Unable to load function " << name << "\n";
    if (error) {
      msg << error->localizedDescription()->utf8String() << "\n";
    }
    throw std::runtime_error(msg.str());
  }

  return mtl_function;
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* Device::get_kernel_(
|
||||||
|
const std::string& name,
|
||||||
|
const MTL::Function* mtl_function) {
|
||||||
// Compile kernel to compute pipeline
|
// Compile kernel to compute pipeline
|
||||||
NS::Error* error = nullptr;
|
NS::Error* error = nullptr;
|
||||||
MTL::ComputePipelineState* kernel;
|
MTL::ComputePipelineState* kernel;
|
||||||
|
|
||||||
if (mtl_function) {
|
if (mtl_function) {
|
||||||
kernel = device_->newComputePipelineState(mtl_function, &error);
|
kernel = device_->newComputePipelineState(mtl_function, &error);
|
||||||
mtl_function->release();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Throw error if unable to compile metal function
|
||||||
if (!mtl_function || !kernel) {
|
if (!mtl_function || !kernel) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[metal::Device] Unable to load kernel " << name << "\n";
|
msg << "[metal::Device] Unable to load kernel " << name << "\n";
|
||||||
@ -282,11 +372,170 @@ MTL::ComputePipelineState* Device::get_kernel(
|
|||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add kernel to cache
|
|
||||||
kernel_map_.insert({name, kernel});
|
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build a compute pipeline state for `mtl_function`, linking in additional
// functions.
//
// - name: kernel name; used here only in error messages
// - mtl_function: compute function for the pipeline; null raises
//   std::runtime_error
// - linked_functions: functions to link; when null this defers to the
//   two-argument get_kernel_
//
// Returns the pipeline state produced by the device. Throws
// std::runtime_error if the device cannot create the pipeline, appending the
// Metal error description when one is reported.
MTL::ComputePipelineState* Device::get_kernel_(
    const std::string& name,
    const MTL::Function* mtl_function,
    const MTL::LinkedFunctions* linked_functions) {
  // Check inputs
  if (!linked_functions) {
    return get_kernel_(name, mtl_function);
  }

  if (!mtl_function) {
    std::ostringstream msg;
    msg << "[metal::Device] Unable to load kernel " << name << "\n";
    throw std::runtime_error(msg.str());
  }

  // Prepare compute pipeline state descriptor
  auto desc = MTL::ComputePipelineDescriptor::alloc()->init();
  desc->setComputeFunction(mtl_function);
  desc->setLinkedFunctions(linked_functions);

  // Compile kernel to compute pipeline
  NS::Error* error = nullptr;
  auto kernel = device_->newComputePipelineState(
      desc, MTL::PipelineOptionNone, nullptr, &error);

  // The descriptor came from alloc/init and is not needed once the pipeline
  // has been requested; release it on both the success and failure path
  // (it was previously never released).
  desc->release();

  // Throw error if unable to compile metal function
  if (!kernel) {
    std::ostringstream msg;
    msg << "[metal::Device] Unable to load kernel " << name << "\n";
    if (error) {
      msg << error->localizedDescription()->utf8String() << "\n";
    }
    throw std::runtime_error(msg.str());
  }

  return kernel;
}
|
||||||
|
|
||||||
|
// Get a Metal library built from source text.
//
// - name: key under which the library is looked up / stored in library_map_
// - source: Metal source used to build the library on a cache miss
// - cache: when false, library_map_ is neither consulted nor updated
//
// Returns the cached library for `name` if caching is on and an entry exists;
// otherwise the freshly built one. Build failures propagate from
// get_library_.
MTL::Library* Device::get_library(
    const std::string& name,
    const std::string& source,
    bool cache /* = true */) {
  // Uncached request: always build
  if (!cache) {
    return get_library_(source);
  }

  auto cached = library_map_.find(name);
  if (cached != library_map_.end()) {
    return cached->second;
  }

  // Miss: build and remember
  MTL::Library* built = get_library_(source);
  library_map_.insert({name, built});
  return built;
}
|
||||||
|
|
||||||
|
// Get a Metal library built from a stitched-library descriptor.
//
// - name: key under which the library is looked up / stored in library_map_
// - desc: descriptor used to build the library on a cache miss
// - cache: when false, library_map_ is neither consulted nor updated
//
// Returns the cached library for `name` if caching is on and an entry exists;
// otherwise the freshly built one. Build failures propagate from
// get_library_.
MTL::Library* Device::get_library(
    const std::string& name,
    const MTL::StitchedLibraryDescriptor* desc,
    bool cache /* = true */) {
  // Uncached request: always build
  if (!cache) {
    return get_library_(desc);
  }

  auto cached = library_map_.find(name);
  if (cached != library_map_.end()) {
    return cached->second;
  }

  // Miss: build and remember
  MTL::Library* built = get_library_(desc);
  library_map_.insert({name, built});
  return built;
}
|
||||||
|
|
||||||
|
// Fetch a (possibly specialized) function from an already loaded library.
//
// - base_name: name of the function inside `mtl_lib`
// - mtl_lib: library to pull the function from
// - specialized_name: name for the specialized function; an empty string
//   (the default) means "use base_name", mirroring how get_kernel treats an
//   empty hash_name
// - func_consts: function constant values to specialize on
//
// Errors from get_function_ propagate unchanged.
MTL::Function* Device::get_function(
    const std::string& base_name,
    MTL::Library* mtl_lib,
    const std::string& specialized_name /* = "" */,
    const MTLFCList& func_consts /* = {} */) {
  // Without this mapping a default call would request a specialization with
  // an empty name instead of the plain function.
  const auto& fname = specialized_name.empty() ? base_name : specialized_name;
  return get_function_(base_name, fname, func_consts, mtl_lib);
}
|
||||||
|
|
||||||
|
// Fetch a (possibly specialized) function from a library identified by name.
//
// The library is resolved through get_library_cache_(lib_name) and the
// request is then forwarded to the overload taking an MTL::Library*.
MTL::Function* Device::get_function(
    const std::string& base_name,
    const std::string& lib_name /* = "mlx" */,
    const std::string& specialized_name /* = "" */,
    const MTLFCList& func_consts /* = {} */) {
  return get_function(
      base_name, get_library_cache_(lib_name), specialized_name, func_consts);
}
|
||||||
|
|
||||||
|
// Wrap a list of functions in an MTL::LinkedFunctions object, registering
// them as its private functions.
//
// Returns nullptr when `funcs` is empty.
MTL::LinkedFunctions* Device::get_linked_functions_(
    const std::vector<MTL::Function*>& funcs) {
  if (funcs.empty()) {
    return nullptr;
  }

  // NS::Array is built from NS::Object pointers, so widen each function
  // pointer into a temporary contiguous buffer first
  std::vector<NS::Object*> as_objects(funcs.begin(), funcs.end());
  NS::Array* private_funcs =
      NS::Array::array(as_objects.data(), funcs.size());

  auto linked = MTL::LinkedFunctions::linkedFunctions();
  linked->setPrivateFunctions(private_funcs);
  return linked;
}
|
||||||
|
|
||||||
|
// Get the compute pipeline state for a kernel from an already loaded library.
//
// - base_name: name of the function inside `mtl_lib`
// - mtl_lib: library to pull the function from
// - hash_name: cache key and specialized function name; when empty,
//   base_name is used instead
// - func_consts: function constant values to specialize on
// - linked_functions: extra functions to link into the pipeline
//
// Pipelines are cached in kernel_map_ under the resolved key, so a repeated
// request returns the stored pipeline without rebuilding. Errors from
// get_function_ / get_kernel_ propagate as std::runtime_error.
MTL::ComputePipelineState* Device::get_kernel(
    const std::string& base_name,
    MTL::Library* mtl_lib,
    const std::string& hash_name /* = "" */,
    const MTLFCList& func_consts /* = {} */,
    const std::vector<MTL::Function*>& linked_functions /* = {} */) {
  auto pool = new_scoped_memory_pool();

  // Look for cached kernel
  const auto& kname = hash_name.empty() ? base_name : hash_name;
  if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
    return it->second;
  }

  // Pull kernel from library
  auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);

  // Compile kernel to compute pipeline
  // get_linked_functions_ returns either nullptr (nothing to link) or an
  // object from the linkedFunctions() convenience constructor, which under
  // metal-cpp's ownership convention is autoreleased. It is therefore left to
  // the scoped pool above and must not be released by hand.
  auto mtl_linked_funcs = get_linked_functions_(linked_functions);

  MTL::ComputePipelineState* kernel = nullptr;
  try {
    kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
  } catch (...) {
    // Do not leak the function when pipeline creation fails
    if (mtl_function) {
      mtl_function->release();
    }
    throw;
  }

  // get_kernel_ rejects a null function, so it is non-null here
  mtl_function->release();

  // Add kernel to cache
  kernel_map_.insert({kname, kernel});
  return kernel;
}
|
||||||
|
|
||||||
|
// Get the compute pipeline state for a kernel from a library identified by
// name.
//
// The cache key is hash_name, or base_name when hash_name is empty. On a
// kernel_map_ hit the stored pipeline is returned without resolving the
// library; otherwise the library is resolved through get_library_cache_ and
// the request is forwarded to the overload taking an MTL::Library*.
MTL::ComputePipelineState* Device::get_kernel(
    const std::string& base_name,
    const std::string& lib_name /* = "mlx" */,
    const std::string& hash_name /* = "" */,
    const MTLFCList& func_consts /* = {} */,
    const std::vector<MTL::Function*>& linked_functions /* = {} */) {
  const std::string& cache_key = hash_name.empty() ? base_name : hash_name;

  // Fast path: kernel already built
  auto cached = kernel_map_.find(cache_key);
  if (cached != kernel_map_.end()) {
    return cached->second;
  }

  // Slow path: resolve the library, then build through the other overload
  return get_kernel(
      base_name,
      get_library_cache_(lib_name),
      cache_key,
      func_consts,
      linked_functions);
}
|
||||||
|
|
||||||
Device& device(mlx::core::Device) {
|
Device& device(mlx::core::Device) {
|
||||||
static Device metal_device;
|
static Device metal_device;
|
||||||
return metal_device;
|
return metal_device;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-24 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
@ -31,6 +31,9 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
|||||||
return mtllib_path;
|
return mtllib_path;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// List of function constants to specialize a Metal function on. Each entry is
// (pointer to the value, Metal data type of the value, function constant
// index), in the order the three are handed to setConstantValue.
using MTLFCList =
|
||||||
|
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||||
|
|
||||||
class Device {
|
class Device {
|
||||||
public:
|
public:
|
||||||
Device();
|
Device();
|
||||||
@ -59,14 +62,71 @@ class Device {
|
|||||||
const std::function<std::string(const std::string&)>& lib_path_func =
|
const std::function<std::string(const std::string&)>& lib_path_func =
|
||||||
get_colocated_mtllib_path);
|
get_colocated_mtllib_path);
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_kernel(
|
MTL::Library* get_library(
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
const std::string& lib_name = "mlx");
|
const std::string& source_string,
|
||||||
|
bool cache = true);
|
||||||
|
|
||||||
|
MTL::Library* get_library(
|
||||||
|
const std::string& name,
|
||||||
|
const MTL::StitchedLibraryDescriptor* desc,
|
||||||
|
bool cache = true);
|
||||||
|
|
||||||
|
MTL::Function* get_function(
|
||||||
|
const std::string& base_name,
|
||||||
|
MTL::Library* mtl_lib,
|
||||||
|
const std::string& specialized_name = "",
|
||||||
|
const MTLFCList& func_consts = {});
|
||||||
|
|
||||||
|
MTL::Function* get_function(
|
||||||
|
const std::string& base_name,
|
||||||
|
const std::string& lib_name = "mlx",
|
||||||
|
const std::string& specialized_name = "",
|
||||||
|
const MTLFCList& func_consts = {});
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_kernel(
|
||||||
|
const std::string& base_name,
|
||||||
|
MTL::Library* mtl_lib,
|
||||||
|
const std::string& hash_name = "",
|
||||||
|
const MTLFCList& func_consts = {},
|
||||||
|
const std::vector<MTL::Function*>& linked_functions = {});
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_kernel(
|
||||||
|
const std::string& base_name,
|
||||||
|
const std::string& lib_name = "mlx",
|
||||||
|
const std::string& hash_name = "",
|
||||||
|
const MTLFCList& func_consts = {},
|
||||||
|
const std::vector<MTL::Function*>& linked_functions = {});
|
||||||
|
|
||||||
MTL::ArgumentEncoder* argument_encoder(
|
MTL::ArgumentEncoder* argument_encoder(
|
||||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
MTL::Library* get_library_cache_(const std::string& name);
|
||||||
|
|
||||||
|
MTL::Library* get_library_(const std::string& source_string);
|
||||||
|
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
|
||||||
|
|
||||||
|
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
|
||||||
|
|
||||||
|
MTL::Function* get_function_(
|
||||||
|
const std::string& name,
|
||||||
|
const std::string& specialized_name,
|
||||||
|
const MTLFCList& func_consts,
|
||||||
|
MTL::Library* mtl_lib);
|
||||||
|
|
||||||
|
MTL::LinkedFunctions* get_linked_functions_(
|
||||||
|
const std::vector<MTL::Function*>& funcs);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_kernel_(
|
||||||
|
const std::string& name,
|
||||||
|
const MTL::Function* mtl_function);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_kernel_(
|
||||||
|
const std::string& name,
|
||||||
|
const MTL::Function* mtl_function,
|
||||||
|
const MTL::LinkedFunctions* linked_functions);
|
||||||
|
|
||||||
MTL::Device* device_;
|
MTL::Device* device_;
|
||||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||||
|
Loading…
Reference in New Issue
Block a user