mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Merge 9a5d162ebf
into b3d7b85376
This commit is contained in:
commit
709e3aa875
@ -3,6 +3,7 @@
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
#include <dlfcn.h>
|
||||
|
||||
#define NS_PRIVATE_IMPLEMENTATION
|
||||
#define CA_PRIVATE_IMPLEMENTATION
|
||||
@ -35,6 +36,16 @@ auto get_metal_version() {
|
||||
return metal_version_;
|
||||
}
|
||||
|
||||
static fs::path get_dylib_directory() {
|
||||
Dl_info info{};
|
||||
|
||||
if (dladdr(reinterpret_cast<void const*>(default_mtllib_path), &info) && info.dli_fname) {
|
||||
fs::path libFile(info.dli_fname);
|
||||
return libFile.parent_path();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
auto load_device() {
|
||||
auto devices = MTL::CopyAllDevices();
|
||||
auto device = static_cast<MTL::Device*>(devices->object(0))
|
||||
@ -115,7 +126,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
}
|
||||
|
||||
MTL::Library* load_default_library(MTL::Device* device) {
|
||||
NS::Error* error[4];
|
||||
NS::Error* error[5];
|
||||
MTL::Library* lib;
|
||||
// First try the colocated mlx.metallib
|
||||
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
|
||||
@ -136,10 +147,25 @@ MTL::Library* load_default_library(MTL::Device* device) {
|
||||
|
||||
// Finally try default_mtllib_path
|
||||
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
{
|
||||
auto dir = get_dylib_directory();
|
||||
if (!dir.empty()) {
|
||||
auto dylib_path = (dir / default_mtllib_path).string();
|
||||
std::tie(lib, error[4]) = load_library_from_path(device, dylib_path.c_str());
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the default metallib. ";
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
if (error[i] != nullptr) {
|
||||
msg << error[i]->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user