From da5912e4f2cd6b6e1a75d1100e383b21ace5eac4 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 31 Jul 2025 06:25:36 -0700 Subject: [PATCH] fix custom metal extension (#2446) --- .circleci/config.yml | 4 +++- docs/src/dev/extensions.rst | 8 ++++---- examples/extensions/axpby/axpby.cpp | 25 +++++++++++++++++++------ examples/extensions/requirements.txt | 2 +- examples/extensions/test.py | 10 ++++++---- 5 files changed, 33 insertions(+), 16 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1de308c587..7472c58f17 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -164,9 +164,11 @@ jobs: - run: name: Build example extension command: | + source .venv/bin/activate cd examples/extensions uv pip install -r requirements.txt - uv run --no-project setup.py build_ext -j8 + uv run --no-project setup.py build_ext --inplace + uv run --no-project python test.py - store_test_results: path: test-results - run: diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 5a4de81238..8b5bd41e5f 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -394,14 +394,14 @@ below. out.set_data(allocator::malloc(out.nbytes())); // Resolve name of kernel - std::ostringstream kname; - kname << "axpby_" << "general_" << type_to_name(out); + std::string kname; + kname = "axpby_general_" + type_to_name(out); // Load the metal library - auto lib = d.get_library("mlx_ext"); + auto lib = d.get_library("mlx_ext", current_binary_dir()); // Make a kernel from this metal library - auto kernel = d.get_kernel(kname.str(), lib); + auto kernel = d.get_kernel(kname, lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index 9ba9334837..31badbbda9 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -1,5 +1,6 @@ // Copyright © 2023-2025 Apple Inc. 
+#include <dlfcn.h> #include <iostream> #include <sstream> @@ -16,6 +17,19 @@ namespace my_ext { +// A helper function to find the location of the current binary on disk. +// The Metal library ("mlx_ext.mtllib"), should be in the same directory. +std::string current_binary_dir() { + static std::string binary_dir = []() { + Dl_info info; + if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) { + throw std::runtime_error("Unable to get current binary dir."); + } + return std::filesystem::path(info.dli_fname).parent_path().string(); + }(); + return binary_dir; +} + /////////////////////////////////////////////////////////////////////////////// // Operation Implementation /////////////////////////////////////////////////////////////////////////////// @@ -167,16 +181,15 @@ void Axpby::eval_gpu( } // Resolve name of kernel (corresponds to axpby.metal) - std::ostringstream kname; - kname << "axpby_"; - kname << (contiguous_kernel ? "contiguous_" : "general_"); - kname << type_to_name(out); + std::string kname = "axpby_"; + kname += (contiguous_kernel ? 
"contiguous_" : "general_"); + kname += type_to_name(out); // Load the metal library - auto lib = d.get_library("mlx_ext"); + auto lib = d.get_library("mlx_ext", current_binary_dir()); // Make a kernel from this metal library - auto kernel = d.get_kernel(kname.str(), lib); + auto kernel = d.get_kernel(kname, lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/examples/extensions/requirements.txt b/examples/extensions/requirements.txt index 084b049ce4..a4591df8cc 100644 --- a/examples/extensions/requirements.txt +++ b/examples/extensions/requirements.txt @@ -1,4 +1,4 @@ setuptools>=42 cmake>=3.25 mlx>=0.21.0 -nanobind==2.2.0 +nanobind==2.4.0 diff --git a/examples/extensions/test.py b/examples/extensions/test.py index f00a72e857..4190fe9962 100644 --- a/examples/extensions/test.py +++ b/examples/extensions/test.py @@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby a = mx.ones((3, 4)) b = mx.ones((3, 4)) -c = axpby(a, b, 4.0, 2.0, stream=mx.cpu) +c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu) +c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu) -print(f"c shape: {c.shape}") -print(f"c dtype: {c.dtype}") -print(f"c correct: {mx.all(c == 6.0).item()}") +print(f"c shape: {c_cpu.shape}") +print(f"c dtype: {c_cpu.dtype}") +print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}") +print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")