mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-07 09:51:13 +08:00
fix extension metal library finding (#1361)
This commit is contained in:
parent
d1183821a7
commit
860d3a50d7
@ -1,8 +1,6 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <dlfcn.h>
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <filesystem>
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include <sys/sysctl.h>
|
#include <sys/sysctl.h>
|
||||||
@ -16,8 +14,6 @@
|
|||||||
#include "mlx/backend/metal/metal_impl.h"
|
#include "mlx/backend/metal/metal_impl.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
|
||||||
namespace fs = std::filesystem;
|
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -38,20 +34,6 @@ constexpr auto get_metal_version() {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
|
||||||
Dl_info info;
|
|
||||||
std::string mtllib_path;
|
|
||||||
std::string lib_ext = lib_name + ".metallib";
|
|
||||||
|
|
||||||
int success = dladdr((void*)get_colocated_mtllib_path, &info);
|
|
||||||
if (success) {
|
|
||||||
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
|
|
||||||
mtllib_path = mtllib.c_str();
|
|
||||||
}
|
|
||||||
|
|
||||||
return mtllib_path;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto load_device() {
|
auto load_device() {
|
||||||
auto devices = MTL::CopyAllDevices();
|
auto devices = MTL::CopyAllDevices();
|
||||||
auto device = static_cast<MTL::Device*>(devices->object(0))
|
auto device = static_cast<MTL::Device*>(devices->object(0))
|
||||||
@ -311,12 +293,6 @@ void Device::register_library(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::register_library(const std::string& lib_name) {
|
|
||||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
|
||||||
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
|
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
|
||||||
// Search for cached metal lib
|
// Search for cached metal lib
|
||||||
MTL::Library* mtl_lib;
|
MTL::Library* mtl_lib;
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <Metal/Metal.hpp>
|
#include <Metal/Metal.hpp>
|
||||||
|
#include <dlfcn.h>
|
||||||
|
#include <filesystem>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -12,8 +14,26 @@
|
|||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
|
|
||||||
|
namespace fs = std::filesystem;
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
|
// Note, this function must be left inline in a header so that it is not
|
||||||
|
// dynamically linked.
|
||||||
|
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||||
|
Dl_info info;
|
||||||
|
std::string mtllib_path;
|
||||||
|
std::string lib_ext = lib_name + ".metallib";
|
||||||
|
|
||||||
|
int success = dladdr((void*)get_colocated_mtllib_path, &info);
|
||||||
|
if (success) {
|
||||||
|
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
|
||||||
|
mtllib_path = mtllib.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
return mtllib_path;
|
||||||
|
}
|
||||||
|
|
||||||
using MTLFCList =
|
using MTLFCList =
|
||||||
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||||
|
|
||||||
@ -86,7 +106,13 @@ class Device {
|
|||||||
const std::string& lib_name,
|
const std::string& lib_name,
|
||||||
const std::string& lib_path);
|
const std::string& lib_path);
|
||||||
|
|
||||||
void register_library(const std::string& lib_name);
|
// Note, this should remain in the header so that it is not dynamically
|
||||||
|
// linked
|
||||||
|
void register_library(const std::string& lib_name) {
|
||||||
|
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||||
|
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
MTL::Library* get_library(const std::string& name);
|
MTL::Library* get_library(const std::string& name);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user