diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 3cefbbcb2..8da905c5c 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -1,6 +1,6 @@ // Copyright © 2023 Apple Inc. -#include +#include #include "mlx/array.h" #include "mlx/backend/common/copy.h" diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index eba03c4c6..68662763a 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -42,6 +42,25 @@ std::pair load_library_from_path( return std::make_pair(lib, error); } +#ifdef SWIFTPM_BUNDLE +MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) { + std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" + + SWIFTPM_BUNDLE + ".bundle"; + auto bundle = NS::Bundle::alloc()->init( + NS::String::string(bundle_path.c_str(), NS::UTF8StringEncoding)); + if (bundle != nullptr) { + std::string resource_path = + std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" + + "default.metallib"; + auto [lib, error] = load_library_from_path(device, resource_path.c_str()); + if (lib) { + return lib; + } + } + return nullptr; +} +#endif + MTL::Library* load_library( MTL::Device* device, const std::string& lib_name = "mlx", @@ -55,6 +74,26 @@ MTL::Library* load_library( } } +#ifdef SWIFTPM_BUNDLE + // try to load from a swiftpm resource bundle -- scan the available bundles to + // find one that contains the named bundle + { + MTL::Library* library = + try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL()); + if (library != nullptr) { + return library; + } + auto bundles = NS::Bundle::allBundles(); + for (int i = 0, c = (int)bundles->count(); i < c; i++) { + auto bundle = reinterpret_cast(bundles->object(i)); + library = try_load_bundle(device, bundle->resourceURL()); + if (library != nullptr) { + return library; + } + } + } +#endif + // Couldn't find it so let's load it from default_mtllib_path { auto [lib, error] = load_library_from_path(device, lib_path);