2024-04-11 12:45:31 +08:00
|
|
|
// Copyright © 2023-2024 Apple Inc.
|
2023-12-01 03:12:53 +08:00
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
#include <dlfcn.h>
|
|
|
|
#include <cstdlib>
|
|
|
|
#include <filesystem>
|
|
|
|
#include <sstream>
|
|
|
|
|
2024-05-03 21:50:15 +08:00
|
|
|
#include <sys/sysctl.h>
|
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
#define NS_PRIVATE_IMPLEMENTATION
|
|
|
|
#define CA_PRIVATE_IMPLEMENTATION
|
|
|
|
#define MTL_PRIVATE_IMPLEMENTATION
|
|
|
|
|
|
|
|
#include "mlx/backend/metal/device.h"
|
|
|
|
#include "mlx/backend/metal/metal.h"
|
2024-04-08 12:47:43 +08:00
|
|
|
#include "mlx/backend/metal/metal_impl.h"
|
2023-11-30 02:52:08 +08:00
|
|
|
#include "mlx/backend/metal/mps/gemm.h"
|
2024-03-29 00:40:31 +08:00
|
|
|
#include "mlx/backend/metal/utils.h"
|
2023-11-30 02:52:08 +08:00
|
|
|
|
|
|
|
namespace fs = std::filesystem;
|
|
|
|
|
|
|
|
namespace mlx::core::metal {
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
// TODO nicer way to set this or possibly expose as an environment variable
|
2024-03-19 08:04:10 +08:00
|
|
|
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
2024-05-10 07:21:02 +08:00
|
|
|
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
2023-11-30 02:52:08 +08:00
|
|
|
|
2024-03-19 08:04:10 +08:00
|
|
|
constexpr const char* default_mtllib_path = METAL_PATH;
|
2023-11-30 02:52:08 +08:00
|
|
|
|
|
|
|
auto load_device() {
|
2024-01-05 08:12:00 +08:00
|
|
|
auto devices = MTL::CopyAllDevices();
|
2024-02-05 04:29:17 +08:00
|
|
|
auto device = static_cast<MTL::Device*>(devices->object(0))
|
|
|
|
?: MTL::CreateSystemDefaultDevice();
|
2023-11-30 02:52:08 +08:00
|
|
|
if (!device) {
|
|
|
|
throw std::runtime_error("Failed to load device");
|
|
|
|
}
|
|
|
|
return device;
|
|
|
|
}
|
|
|
|
// Load a precompiled metallib located at `path` onto `device`.
//
// Returns {library, error}. The library is nullptr on failure, and error then
// may describe the cause. error is nullptr whenever Metal did not set it.
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
    MTL::Device* device,
    const char* path) {
  auto library = NS::String::string(path, NS::UTF8StringEncoding);
  // Must start as nullptr: the pointer is handed back to callers whether or
  // not newLibrary wrote to it, and copying an uninitialized pointer is UB.
  NS::Error* error = nullptr;
  auto lib = device->newLibrary(library, &error);

  return std::make_pair(lib, error);
}
|
|
|
|
|
2023-12-20 08:22:10 +08:00
|
|
|
#ifdef SWIFTPM_BUNDLE
// Try to load "default.metallib" from the SwiftPM resource bundle named
// SWIFTPM_BUNDLE that sits directly inside the directory `url`.
//
// Returns the loaded library (owned by the caller). Returns nullptr when `url`
// is null, the bundle or its resource directory cannot be resolved, or the
// metallib fails to load.
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
  // Callers pass bundle->resourceURL(), which may be nil. A null path must
  // never reach std::string's const char* constructor.
  if (url == nullptr) {
    return nullptr;
  }
  const char* url_path = url->fileSystemRepresentation();
  if (url_path == nullptr) {
    return nullptr;
  }

  std::string bundle_path =
      std::string(url_path) + "/" + SWIFTPM_BUNDLE + ".bundle";
  auto bundle = NS::Bundle::alloc()->init(
      NS::String::string(bundle_path.c_str(), NS::UTF8StringEncoding));
  if (bundle == nullptr) {
    return nullptr;
  }

  MTL::Library* result = nullptr;
  NS::URL* resource_url = bundle->resourceURL();
  const char* resource_dir =
      resource_url ? resource_url->fileSystemRepresentation() : nullptr;
  if (resource_dir != nullptr) {
    std::string resource_path =
        std::string(resource_dir) + "/" + "default.metallib";
    result = load_library_from_path(device, resource_path.c_str()).first;
  }

  // The bundle came from alloc/init, so this function owns it.
  bundle->release();
  return result;
}
#endif
|
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
// Locate and load the metallib for `lib_name`.
//
// Search order:
//   1. the path returned by get_colocated_mtllib_path(lib_name), if non-empty;
//   2. (SWIFTPM_BUNDLE builds only) the SwiftPM resource bundle, looked up
//      under the main bundle and then under every loaded bundle;
//   3. `lib_path` (defaults to default_mtllib_path).
//
// Returns the first library that loads (owned by the caller).
// Throws std::runtime_error naming the attempted paths if every attempt fails.
MTL::Library* load_library(
    MTL::Device* device,
    const std::string& lib_name = "mlx",
    const char* lib_path = default_mtllib_path) {
  // Firstly, search for the metallib in the same path as this binary
  std::string first_path = get_colocated_mtllib_path(lib_name);
  if (!first_path.empty()) {
    auto [lib, error] = load_library_from_path(device, first_path.c_str());
    if (lib) {
      return lib;
    }
  }

#ifdef SWIFTPM_BUNDLE
  // try to load from a swiftpm resource bundle -- scan the available bundles to
  // find one that contains the named bundle
  {
    MTL::Library* library =
        try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());
    if (library != nullptr) {
      return library;
    }
    auto bundles = NS::Bundle::allBundles();
    for (NS::UInteger i = 0, c = bundles->count(); i < c; ++i) {
      auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
      library = try_load_bundle(device, bundle->resourceURL());
      if (library != nullptr) {
        return library;
      }
    }
  }
#endif

  // Couldn't find it so let's load it from default_mtllib_path
  auto [lib, error] = load_library_from_path(device, lib_path);
  if (!lib) {
    std::ostringstream msg;
    // A failed load is not guaranteed to come with an NS::Error; guard the
    // dereference like the other error paths in this file do.
    if (error) {
      msg << error->localizedDescription()->utf8String() << "\n";
    }
    msg << "Failed to load device library from <" << lib_path << ">"
        << " or <" << first_path << ">.";
    throw std::runtime_error(msg.str());
  }
  return lib;
}
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
2024-05-10 07:21:02 +08:00
|
|
|
// Record a threadgroup-grid dispatch on the active encoder.
// The dispatch is counted first, and maybe_split() then decides whether to
// roll over to a new encoder.
void CommandEncoder::dispatchThreadgroups(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  ++num_dispatches;
  enc->dispatchThreadgroups(grid_dims, group_dims);
  maybe_split();
}
|
|
|
|
|
|
|
|
// Record a thread-grid dispatch on the active encoder.
// The dispatch is counted first, and maybe_split() then decides whether to
// roll over to a new encoder.
void CommandEncoder::dispatchThreads(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  ++num_dispatches;
  enc->dispatchThreads(grid_dims, group_dims);
  maybe_split();
}
|
|
|
|
|
|
|
|
void CommandEncoder::maybe_split() {
|
|
|
|
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
|
|
|
|
enc->endEncoding();
|
|
|
|
enc->release();
|
|
|
|
num_dispatches = 0;
|
|
|
|
outputs.clear();
|
|
|
|
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
|
|
|
enc->retain();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-23 03:01:26 +08:00
|
|
|
// Acquire the Metal device and load the default "mlx" metallib into
// library_map_. An autorelease pool scopes temporaries created during setup.
// Propagates any exception from load_device() / load_library().
Device::Device() {
  auto pool = new_scoped_memory_pool();
  device_ = load_device();
  try {
    library_map_ = {{"mlx", load_library(device_)}};
  } catch (...) {
    // A constructor that throws never runs the destructor, so release the
    // device here instead of leaking it.
    device_->release();
    device_ = nullptr;
    throw;
  }
}
|
2023-11-30 02:52:08 +08:00
|
|
|
|
|
|
|
// Release every Metal object owned by this Device.
// Order: command queues, pending command buffers, cached pipeline states,
// loaded libraries, and finally the device itself.
Device::~Device() {
  auto pool = new_scoped_memory_pool();

  for (auto& [index, queue] : queue_map_) {
    queue->release();
  }
  for (auto& [index, entry] : buffer_map_) {
    // entry is {op count, command buffer}
    entry.second->release();
  }
  for (auto& [kernel_name, pipeline] : kernel_map_) {
    pipeline->release();
  }
  for (auto& [lib_name, library] : library_map_) {
    library->release();
  }

  device_->release();
}
|
|
|
|
|
|
|
|
void Device::new_queue(int index) {
|
2024-01-05 08:12:00 +08:00
|
|
|
auto thread_pool = metal::new_scoped_memory_pool();
|
2024-01-05 08:28:52 +08:00
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
// Multiple threads can ask the device for queues
|
|
|
|
// We lock this as a critical section for safety
|
|
|
|
const std::lock_guard<std::mutex> lock(mtx_);
|
|
|
|
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
2024-03-29 00:40:31 +08:00
|
|
|
debug_set_stream_queue_label(q, index);
|
2023-11-30 02:52:08 +08:00
|
|
|
if (!q) {
|
|
|
|
throw std::runtime_error(
|
|
|
|
"[metal::Device] Failed to make new command queue.");
|
|
|
|
}
|
|
|
|
queue_map_.insert({index, q});
|
|
|
|
}
|
|
|
|
|
|
|
|
int Device::get_command_buffer_ops(int index) {
|
|
|
|
auto bit = buffer_map_.find(index);
|
|
|
|
return bit->second.first;
|
|
|
|
}
|
|
|
|
|
|
|
|
void Device::increment_command_buffer_ops(int index) {
|
|
|
|
auto bit = buffer_map_.find(index);
|
|
|
|
bit->second.first++;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Return the pending command buffer of stream `index`, creating it on demand.
//
// A new buffer is obtained via commandBufferWithUnretainedReferences() on the
// stream's queue. It is retained so it lives in buffer_map_ (with an op count
// of 0) until commit_command_buffer() releases it.
// Throws std::runtime_error if the stream has no queue or Metal cannot
// allocate the buffer.
MTL::CommandBuffer* Device::get_command_buffer(int index) {
  if (auto found = buffer_map_.find(index); found != buffer_map_.end()) {
    return found->second.second;
  }

  auto queue_it = queue_map_.find(index);
  if (queue_it == queue_map_.end()) {
    throw std::runtime_error(
        "[metal::Device] Attempting to get command buffer for invalid queue.");
  }

  MTL::CommandBuffer* buffer =
      queue_it->second->commandBufferWithUnretainedReferences();
  if (buffer == nullptr) {
    throw std::runtime_error(
        "[metal::Device] Unable to create new command buffer");
  }

  // Increment ref count so the buffer is not garbage collected
  buffer->retain();
  buffer_map_.insert({index, {0, buffer}});
  return buffer;
}
|
|
|
|
|
|
|
|
// Commit the pending command buffer of stream `index`. Then drop the
// reference taken in get_command_buffer() and forget the buffer, so the next
// request creates a fresh one.
// Throws std::runtime_error if that stream has no pending command buffer,
// instead of dereferencing buffer_map_.end().
void Device::commit_command_buffer(int index) {
  auto bit = buffer_map_.find(index);
  if (bit == buffer_map_.end()) {
    throw std::runtime_error(
        "[metal::Device] Attempting to commit a command buffer that does not exist.");
  }
  bit->second.second->commit();
  bit->second.second->release();
  buffer_map_.erase(bit);
}
|
|
|
|
|
|
|
|
void Device::end_encoding(int index) {
|
2024-05-10 07:21:02 +08:00
|
|
|
encoder_map_.erase(index);
|
2023-11-30 02:52:08 +08:00
|
|
|
}
|
|
|
|
|
2024-04-11 12:45:31 +08:00
|
|
|
// Return the CommandEncoder of stream `index`. On first use, create it on top
// of the stream's command buffer (obtained via get_command_buffer).
CommandEncoder& Device::get_command_encoder(int index) {
  if (auto found = encoder_map_.find(index); found != encoder_map_.end()) {
    return *found->second;
  }

  MTL::CommandBuffer* buffer = get_command_buffer(index);
  auto created =
      encoder_map_.emplace(index, std::make_unique<CommandEncoder>(buffer))
          .first;
  return *created->second;
}
|
|
|
|
|
|
|
|
void Device::register_library(
|
|
|
|
const std::string& lib_name,
|
|
|
|
const std::string& lib_path) {
|
|
|
|
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
|
|
|
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
|
|
|
|
library_map_.insert({lib_name, new_lib});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void Device::register_library(
|
|
|
|
const std::string& lib_name,
|
|
|
|
const std::function<std::string(const std::string&)>& lib_path_func) {
|
|
|
|
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
|
|
|
std::string new_lib_path = lib_path_func(lib_name);
|
|
|
|
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
|
|
|
|
library_map_.insert({lib_name, new_lib});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-01-31 07:42:36 +08:00
|
|
|
// Return the library registered as `lib_name`. On a miss, register it first
// through register_library(lib_name).
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
  // Search for cached metal lib
  if (auto cached = library_map_.find(lib_name);
      cached != library_map_.end()) {
    return cached->second;
  }

  // Look for metallib alongside library
  register_library(lib_name);
  return library_map_[lib_name];
}
|
|
|
|
|
|
|
|
// Compile a Metal library from the source text `source_string`, using
// default compile options (nullptr). The caller owns the returned library.
// Throws std::runtime_error if compilation fails, including Metal's
// diagnostic when one is reported.
MTL::Library* Device::get_library_(const std::string& source_string) {
  auto pool = new_scoped_memory_pool();

  // Convert as UTF-8. It is a strict superset of ASCII, so ASCII sources
  // convert identically. Sources containing non-ASCII bytes (e.g. in comments
  // or string literals) no longer go through an ASCII-only conversion.
  auto ns_code =
      NS::String::string(source_string.c_str(), NS::UTF8StringEncoding);

  NS::Error* error = nullptr;
  auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);

  // Throw error if unable to compile library
  if (!mtl_lib) {
    std::ostringstream msg;
    msg << "[metal::Device] Unable to build metal library from source" << "\n";
    if (error) {
      msg << error->localizedDescription()->utf8String() << "\n";
    }
    throw std::runtime_error(msg.str());
  }

  return mtl_lib;
}
|
|
|
|
|
|
|
|
// Build a library from the stitched-library descriptor `desc`.
// The caller owns the returned library. Throws std::runtime_error if the
// build fails, including Metal's diagnostic when one is reported.
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
  auto pool = new_scoped_memory_pool();

  NS::Error* build_error = nullptr;
  MTL::Library* stitched = device_->newLibrary(desc, &build_error);
  if (stitched != nullptr) {
    return stitched;
  }

  // Throw error if unable to compile library
  std::ostringstream msg;
  msg << "[metal::Device] Unable to build stitched metal library" << "\n";
  if (build_error != nullptr) {
    msg << build_error->localizedDescription()->utf8String() << "\n";
  }
  throw std::runtime_error(msg.str());
}
|
|
|
|
|
|
|
|
// Look up the unspecialized function `name` in `mtl_lib`.
// The result is a newly created function owned by the caller. It may be
// nullptr; no exception is raised here, and callers such as get_kernel_ check
// for null.
MTL::Function* Device::get_function_(
    const std::string& name,
    MTL::Library* mtl_lib) {
  // Pull kernel from library
  NS::String* fn_name =
      NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
  return mtl_lib->newFunction(fn_name);
}
|
|
|
|
|
|
|
|
// Build function `name` from `mtl_lib`, specialized with `func_consts` and
// registered under `specialized_name`. When there are no constants and the
// names match, the plain lookup is used instead.
// Returns a function owned by the caller. Throws std::runtime_error if the
// specialized function cannot be built, including Metal's diagnostic when
// available.
MTL::Function* Device::get_function_(
    const std::string& name,
    const std::string& specialized_name,
    const MTLFCList& func_consts,
    MTL::Library* mtl_lib) {
  if (func_consts.empty() && (specialized_name == name)) {
    return get_function_(name, mtl_lib);
  }

  // Prepare function constants
  auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();
  for (const auto& [value, type, index] : func_consts) {
    mtl_func_consts->setConstantValue(value, type, index);
  }

  // Prepare function desc
  auto desc = MTL::FunctionDescriptor::functionDescriptor();
  desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));
  desc->setSpecializedName(
      NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));
  desc->setConstantValues(mtl_func_consts);

  // Pull kernel from library
  NS::Error* error = nullptr;
  auto mtl_function = mtl_lib->newFunction(desc, &error);

  // The constant values came from alloc/init and are no longer needed once
  // newFunction has run. Release them now so the error path below does not
  // leak them.
  mtl_func_consts->release();

  // Throw error if unable to build metal function
  if (!mtl_function) {
    std::ostringstream msg;
    msg << "[metal::Device] Unable to load function " << name << "\n";
    if (error) {
      msg << error->localizedDescription()->utf8String() << "\n";
    }
    throw std::runtime_error(msg.str());
  }

  return mtl_function;
}
|
|
|
|
|
|
|
|
// Compile `mtl_function` into a compute pipeline state owned by the caller.
// Throws std::runtime_error naming kernel `name` if `mtl_function` is null or
// compilation fails, appending Metal's diagnostic when available.
MTL::ComputePipelineState* Device::get_kernel_(
    const std::string& name,
    const MTL::Function* mtl_function) {
  // Compile kernel to compute pipeline
  NS::Error* error = nullptr;
  MTL::ComputePipelineState* kernel = nullptr;
  if (mtl_function != nullptr) {
    kernel = device_->newComputePipelineState(mtl_function, &error);
  }

  if (kernel != nullptr) {
    return kernel;
  }

  // Throw error if unable to compile metal function
  std::ostringstream msg;
  msg << "[metal::Device] Unable to load kernel " << name << "\n";
  if (error != nullptr) {
    msg << error->localizedDescription()->utf8String() << "\n";
  }
  throw std::runtime_error(msg.str());
}
|
|
|
|
|
|
|
|
// Compile `mtl_function`, linked against `linked_functions`, into a compute
// pipeline state owned by the caller. Without linked functions this defers to
// the single-function overload.
// Throws std::runtime_error naming kernel `name` if the function is null or
// compilation fails, appending Metal's diagnostic when available.
MTL::ComputePipelineState* Device::get_kernel_(
    const std::string& name,
    const MTL::Function* mtl_function,
    const MTL::LinkedFunctions* linked_functions) {
  // Check inputs
  if (!linked_functions) {
    return get_kernel_(name, mtl_function);
  }

  if (!mtl_function) {
    std::ostringstream msg;
    msg << "[metal::Device] Unable to load kernel " << name << "\n";
    throw std::runtime_error(msg.str());
  }

  // Prepare compute pipeline state descriptor
  auto desc = MTL::ComputePipelineDescriptor::alloc()->init();
  desc->setComputeFunction(mtl_function);
  desc->setLinkedFunctions(linked_functions);

  // Compile kernel to compute pipeline
  NS::Error* error = nullptr;
  auto kernel = device_->newComputePipelineState(
      desc, MTL::PipelineOptionNone, nullptr, &error);

  // The descriptor came from alloc/init, so we own it. It is not used after
  // pipeline creation, so release it before any throw below.
  desc->release();

  // Throw error if unable to compile metal function
  if (!kernel) {
    std::ostringstream msg;
    msg << "[metal::Device] Unable to load kernel " << name << "\n";
    if (error) {
      msg << error->localizedDescription()->utf8String() << "\n";
    }
    throw std::runtime_error(msg.str());
  }

  return kernel;
}
|
|
|
|
|
2024-02-08 05:15:59 +08:00
|
|
|
// Return the library already registered as `name`, or nullptr if there is
// none. This never loads or compiles anything.
MTL::Library* Device::get_library(const std::string& name) {
  if (auto it = library_map_.find(name); it != library_map_.end()) {
    return it->second;
  }
  return nullptr;
}
|
|
|
|
|
2024-01-31 07:42:36 +08:00
|
|
|
// Return library `name`, compiled from `source` when needed.
// cache == true: reuse a library already registered as `name`; otherwise
// compile one and register it under `name`.
// cache == false: always compile, and do not register the result.
MTL::Library* Device::get_library(
    const std::string& name,
    const std::string& source,
    bool cache /* = true */) {
  if (!cache) {
    return get_library_(source);
  }
  if (auto hit = library_map_.find(name); hit != library_map_.end()) {
    return hit->second;
  }
  MTL::Library* compiled = get_library_(source);
  library_map_.insert({name, compiled});
  return compiled;
}
|
|
|
|
|
|
|
|
// Return library `name`, built from the stitched descriptor `desc` when
// needed.
// cache == true: reuse a library already registered as `name`; otherwise
// build one and register it under `name`.
// cache == false: always build, and do not register the result.
MTL::Library* Device::get_library(
    const std::string& name,
    const MTL::StitchedLibraryDescriptor* desc,
    bool cache /* = true */) {
  if (!cache) {
    return get_library_(desc);
  }
  if (auto hit = library_map_.find(name); hit != library_map_.end()) {
    return hit->second;
  }
  MTL::Library* built = get_library_(desc);
  library_map_.insert({name, built});
  return built;
}
|
|
|
|
|
|
|
|
// Fetch function `base_name` from `mtl_lib`, optionally specialized with
// `func_consts` under `specialized_name`. The caller owns the result.
//
// An empty `specialized_name` means "use base_name", matching get_kernel's
// treatment of an empty hash_name. Without this, the default "" never equals
// base_name, so get_function_ would skip its plain-lookup shortcut and build
// a descriptor with an empty specialized name.
// Errors propagate from get_function_.
MTL::Function* Device::get_function(
    const std::string& base_name,
    MTL::Library* mtl_lib,
    const std::string& specialized_name /* = "" */,
    const MTLFCList& func_consts /* = {} */) {
  const std::string& fn_name =
      specialized_name.empty() ? base_name : specialized_name;
  return get_function_(base_name, fn_name, func_consts, mtl_lib);
}
|
|
|
|
|
|
|
|
// Fetch function `base_name` from the library registered as `lib_name`.
// The library is loaded on first use via get_library_cache_. Specialization
// arguments are forwarded unchanged to the library-pointer overload.
MTL::Function* Device::get_function(
    const std::string& base_name,
    const std::string& lib_name /* = "mlx" */,
    const std::string& specialized_name /* = "" */,
    const MTLFCList& func_consts /* = {} */) {
  // Search for cached metal lib
  return get_function(
      base_name, get_library_cache_(lib_name), specialized_name, func_consts);
}
|
|
|
|
|
|
|
|
// Wrap `funcs` in an MTL::LinkedFunctions object as private functions.
//
// Returns nullptr when `funcs` is empty. Otherwise the returned object carries
// a +1 reference owned by the caller, which must release() it. get_kernel
// already does this once the pipeline is built.
MTL::LinkedFunctions* Device::get_linked_functions_(
    const std::vector<MTL::Function*>& funcs) {
  if (funcs.empty()) {
    return nullptr;
  }

  // linkedFunctions() is a convenience constructor, so its result is
  // autoreleased. Retain it so the caller's release() is balanced and the
  // enclosing autorelease pool does not release it a second time when it
  // drains.
  auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
  lfuncs->retain();

  // NS::Array::array expects base-class NS::Object pointers.
  std::vector<NS::Object*> objs(funcs.begin(), funcs.end());
  NS::Array* funcs_arr = NS::Array::array(objs.data(), objs.size());

  lfuncs->setPrivateFunctions(funcs_arr);

  return lfuncs;
}
|
|
|
|
|
|
|
|
// Return the compute pipeline for `base_name` from `mtl_lib`.
//
// Pipelines are cached under `hash_name`, or under `base_name` when
// hash_name is empty. On a cache miss the function is built (specialized with
// `func_consts`), linked with `linked_functions`, compiled, and cached.
// Cached pipelines are returned as-is.
// Throws std::runtime_error if the function or pipeline cannot be built.
MTL::ComputePipelineState* Device::get_kernel(
    const std::string& base_name,
    MTL::Library* mtl_lib,
    const std::string& hash_name /* = "" */,
    const MTLFCList& func_consts /* = {} */,
    const std::vector<MTL::Function*>& linked_functions /* = {} */) {
  auto pool = new_scoped_memory_pool();

  // Look for cached kernel
  const auto& kname = hash_name.empty() ? base_name : hash_name;
  if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
    return it->second;
  }

  // Pull kernel from library
  auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);

  // Compile kernel to compute pipeline
  auto mtl_linked_funcs = get_linked_functions_(linked_functions);

  // Drop our references to the inputs. Either may be null: the function
  // lookup can fail, and get_linked_functions_ returns nullptr for an empty
  // list. This must also run when get_kernel_ throws, or both inputs leak.
  auto release_inputs = [&]() {
    if (mtl_function) {
      mtl_function->release();
    }
    if (mtl_linked_funcs) {
      mtl_linked_funcs->release();
    }
  };

  MTL::ComputePipelineState* kernel = nullptr;
  try {
    kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
  } catch (...) {
    release_inputs();
    throw;
  }
  release_inputs();

  // Add kernel to cache
  kernel_map_.insert({kname, kernel});

  return kernel;
}
|
|
|
|
|
2024-01-31 07:42:36 +08:00
|
|
|
// Return the compute pipeline for `base_name` from the library registered as
// `lib_name`. The cache is checked before the library is touched, so a hit
// never loads the library. On a miss the library is resolved with
// get_library_cache_, and the work is delegated to the library-pointer
// overload.
MTL::ComputePipelineState* Device::get_kernel(
    const std::string& base_name,
    const std::string& lib_name /* = "mlx" */,
    const std::string& hash_name /* = "" */,
    const MTLFCList& func_consts /* = {} */,
    const std::vector<MTL::Function*>& linked_functions /* = {} */) {
  const std::string& kname = hash_name.empty() ? base_name : hash_name;

  // Look for cached kernel
  auto cached = kernel_map_.find(kname);
  if (cached != kernel_map_.end()) {
    return cached->second;
  }

  // Search for cached metal lib
  MTL::Library* mtl_lib = get_library_cache_(lib_name);
  return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
}
|
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
// Return the single metal::Device instance. The mlx::core::Device argument is
// ignored: every call shares one function-local static, constructed on the
// first call.
Device& device(mlx::core::Device) {
  static Device shared_device;
  return shared_device;
}
|
|
|
|
|
2024-04-17 21:16:02 +08:00
|
|
|
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
2023-12-23 03:01:26 +08:00
|
|
|
auto dtor = [](void* ptr) {
|
|
|
|
static_cast<NS::AutoreleasePool*>(ptr)->release();
|
|
|
|
};
|
2024-04-17 21:16:02 +08:00
|
|
|
return std::unique_ptr<void, std::function<void(void*)>>(
|
|
|
|
NS::AutoreleasePool::alloc()->init(), dtor);
|
2023-11-30 02:52:08 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Give a GPU stream its own Metal command queue. Streams on any other device
// are ignored here.
void new_stream(Stream stream) {
  if (!(stream.device == mlx::core::Device::gpu)) {
    return;
  }
  device(stream.device).new_queue(stream.index);
}
|
|
|
|
|
2024-05-01 06:47:27 +08:00
|
|
|
std::unordered_map<std::string, std::variant<std::string, size_t>>
|
|
|
|
device_info() {
|
|
|
|
auto raw_device = device(default_device()).mtl_device();
|
|
|
|
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
2024-05-03 21:50:15 +08:00
|
|
|
|
|
|
|
int mib[] = {CTL_HW, HW_MEMSIZE};
|
|
|
|
size_t memsize = 0;
|
|
|
|
size_t length = sizeof(memsize);
|
|
|
|
|
|
|
|
sysctl(mib, 2, &memsize, &length, NULL, 0);
|
|
|
|
|
2024-05-01 06:47:27 +08:00
|
|
|
return {
|
|
|
|
{"architecture", arch},
|
|
|
|
{"max_buffer_length", raw_device->maxBufferLength()},
|
|
|
|
{"max_recommended_working_set_size",
|
2024-05-03 21:50:15 +08:00
|
|
|
raw_device->recommendedMaxWorkingSetSize()},
|
|
|
|
{"memory_size", memsize}};
|
2024-05-01 06:47:27 +08:00
|
|
|
}
|
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
} // namespace mlx::core::metal
|