fix custom metal extension (#2446)

This commit is contained in:
Awni Hannun 2025-07-31 06:25:36 -07:00 committed by GitHub
parent daafee676f
commit da5912e4f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 33 additions and 16 deletions

View File

@ -164,9 +164,11 @@ jobs:
- run: - run:
name: Build example extension name: Build example extension
command: | command: |
source .venv/bin/activate
cd examples/extensions cd examples/extensions
uv pip install -r requirements.txt uv pip install -r requirements.txt
uv run --no-project setup.py build_ext -j8 uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
- store_test_results: - store_test_results:
path: test-results path: test-results
- run: - run:

View File

@ -394,14 +394,14 @@ below.
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel // Resolve name of kernel
std::ostringstream kname; std::stream kname;
kname << "axpby_" << "general_" << type_to_name(out); kname = "axpby_general_" + type_to_name(out);
// Load the metal library // Load the metal library
auto lib = d.get_library("mlx_ext"); auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), lib); auto kernel = d.get_kernel(kname, lib);
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);

View File

@ -1,5 +1,6 @@
// Copyright © 2023-2025 Apple Inc. // Copyright © 2023-2025 Apple Inc.
#include <dlfcn.h>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
@ -16,6 +17,19 @@
namespace my_ext { namespace my_ext {
// A helper function to find the location of the current binary on disk.
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
std::string current_binary_dir() {
static std::string binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path().string();
}();
return binary_dir;
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Operation Implementation // Operation Implementation
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -167,16 +181,15 @@ void Axpby::eval_gpu(
} }
// Resolve name of kernel (corresponds to axpby.metal) // Resolve name of kernel (corresponds to axpby.metal)
std::ostringstream kname; std::string kname = "axpby_";
kname << "axpby_"; kname += (contiguous_kernel ? "contiguous_" : "general_");
kname << (contiguous_kernel ? "contiguous_" : "general_"); kname += type_to_name(out);
kname << type_to_name(out);
// Load the metal library // Load the metal library
auto lib = d.get_library("mlx_ext"); auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), lib); auto kernel = d.get_kernel(kname, lib);
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);

View File

@ -1,4 +1,4 @@
setuptools>=42 setuptools>=42
cmake>=3.25 cmake>=3.25
mlx>=0.21.0 mlx>=0.21.0
nanobind==2.2.0 nanobind==2.4.0

View File

@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
a = mx.ones((3, 4)) a = mx.ones((3, 4))
b = mx.ones((3, 4)) b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu) c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
print(f"c shape: {c.shape}") print(f"c shape: {c_cpu.shape}")
print(f"c dtype: {c.dtype}") print(f"c dtype: {c_cpu.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}") print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")