fix extension metal library finding (#1361)

This commit is contained in:
Awni Hannun 2024-08-26 09:18:50 -07:00 committed by GitHub
parent d1183821a7
commit 860d3a50d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 25 deletions

View File

@ -1,8 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include <cstdlib>
#include <filesystem>
#include <sstream>
#include <sys/sysctl.h>
@ -16,8 +14,6 @@
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
namespace {
@ -38,20 +34,6 @@ constexpr auto get_metal_version() {
#endif
}
std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0))
@ -311,12 +293,6 @@ void Device::register_library(
}
}
void Device::register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
// Search for cached metal lib
MTL::Library* mtl_lib;

View File

@ -3,6 +3,8 @@
#pragma once
#include <Metal/Metal.hpp>
#include <dlfcn.h>
#include <filesystem>
#include <functional>
#include <mutex>
#include <string>
@ -12,8 +14,26 @@
#include "mlx/array.h"
#include "mlx/device.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
// Note, this function must be left inline in a header so that it is not
// dynamically linked.
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
@ -86,7 +106,13 @@ class Device {
const std::string& lib_name,
const std::string& lib_path);
void register_library(const std::string& lib_name);
// Note, this should remain in the header so that it is not dynamically
// linked
void register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
MTL::Library* get_library(const std::string& name);