mlx/mlx/backend/metal/device.cpp
davidkoski 37024d899c
fixes for building with swiftpm (#225)
- clbas is part of veclib (compile failure)
- add SWIFTPM_BUNDLE #define to allow loading the metallib from a swiftpm resource bundle
2023-12-19 16:22:10 -08:00

298 lines
8.5 KiB
C++

// Copyright © 2023 Apple Inc.
#include <dlfcn.h>
#include <cstdlib>
#include <filesystem>
#include <sstream>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/mps/gemm.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
namespace {
// TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
static constexpr const char* default_mtllib_path = METAL_PATH;
auto load_device() {
MTL::Device* device = MTL::CreateSystemDefaultDevice();
if (!device) {
throw std::runtime_error("Failed to load device");
}
return device;
}
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
MTL::Device* device,
const char* path) {
auto library = NS::String::string(path, NS::UTF8StringEncoding);
NS::Error* error;
auto lib = device->newLibrary(library, &error);
return std::make_pair(lib, error);
}
#ifdef SWIFTPM_BUNDLE
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
SWIFTPM_BUNDLE + ".bundle";
auto bundle = NS::Bundle::alloc()->init(
NS::String::string(bundle_path.c_str(), NS::UTF8StringEncoding));
if (bundle != nullptr) {
std::string resource_path =
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
"default.metallib";
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
if (lib) {
return lib;
}
}
return nullptr;
}
#endif
MTL::Library* load_library(
MTL::Device* device,
const std::string& lib_name = "mlx",
const char* lib_path = default_mtllib_path) {
// Firstly, search for the metallib in the same path as this binary
std::string first_path = get_colocated_mtllib_path(lib_name);
if (first_path.size() != 0) {
auto [lib, error] = load_library_from_path(device, first_path.c_str());
if (lib) {
return lib;
}
}
#ifdef SWIFTPM_BUNDLE
// try to load from a swiftpm resource bundle -- scan the available bundles to
// find one that contains the named bundle
{
MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());
if (library != nullptr) {
return library;
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return library;
}
}
}
#endif
// Couldn't find it so let's load it from default_mtllib_path
{
auto [lib, error] = load_library_from_path(device, lib_path);
if (!lib) {
std::ostringstream msg;
msg << error->localizedDescription()->utf8String() << "\n"
<< "Failed to load device library from <" << lib_path << ">"
<< " or <" << first_path << ">.";
throw std::runtime_error(msg.str());
}
return lib;
}
}
} // namespace
Device::Device()
: pool_(NS::AutoreleasePool::alloc()->init()),
device_(load_device()),
library_map_({{"mlx", load_library(device_)}}) {}
Device::~Device() {
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
}
device_->release();
pool_->release();
}
void Device::new_queue(int index) {
// Multiple threads can ask the device for queues
// We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
if (!q) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
}
queue_map_.insert({index, q});
}
int Device::get_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
return bit->second.first;
}
void Device::increment_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
bit->second.first++;
}
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
}
MTL::CommandBuffer* Device::new_command_buffer(int index) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
return buffer_map_.insert({index, {0, cb}}).first->second.second;
}
void Device::commit_command_buffer(int index) {
auto bit = buffer_map_.find(index);
bit->second.second->commit();
bit->second.second->release();
buffer_map_.erase(bit);
}
void Device::end_encoding(int index) {
auto eit = encoder_map_.find(index);
if (eit != encoder_map_.end()) {
eit->second->endEncoding();
eit->second->release();
encoder_map_.erase(eit);
}
}
MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
auto compute_encoder = cb->computeCommandEncoder();
// Increment ref count so the buffer is not garbage collected
compute_encoder->retain();
eit = encoder_map_.insert({index, compute_encoder}).first;
}
return eit->second;
}
MTL::ArgumentEncoder* Device::argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
// NB array here is already autoreleased but the returned argument
// encoder is owned by the caller and must be released/autoreleased
NS::Array* arg_desc_arr = NS::Array::array(
reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
return device_->newArgumentEncoder(arg_desc_arr);
}
void Device::register_library(
const std::string& lib_name,
const std::string& lib_path) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
library_map_.insert({lib_name, new_lib});
}
}
void Device::register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
std::string new_lib_path = lib_path_func(lib_name);
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
library_map_.insert({lib_name, new_lib});
}
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& name,
const std::string& lib_name /* = "mlx" */) {
// Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second;
}
// Prepare new kernel
// Search for cached metal lib
MTL::Library* mtl_lib;
if (auto it = library_map_.find(name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name);
mtl_lib = library_map_[lib_name];
}
// Pull kernel from library
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
auto mtl_function = mtl_lib->newFunction(ns_name);
// Compile kernel to compute pipeline
NS::Error* error = nullptr;
MTL::ComputePipelineState* kernel;
if (mtl_function) {
kernel = device_->newComputePipelineState(mtl_function, &error);
mtl_function->release();
}
if (!mtl_function || !kernel) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
// Add kernel to cache
kernel_map_.insert({name, kernel});
return kernel;
}
Device& device(mlx::core::Device) {
static Device metal_device_;
return metal_device_;
}
NS::AutoreleasePool*& thread_autorelease_pool() {
static thread_local NS::AutoreleasePool* p =
NS::AutoreleasePool::alloc()->init();
return p;
}
void new_stream(Stream stream) {
thread_autorelease_pool();
if (stream.device == mlx::core::Device::gpu) {
device(stream.device).new_queue(stream.index);
}
}
} // namespace mlx::core::metal