From 1d94ac3f90d39ed8b10074518e18c4483b9d5651 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Mon, 26 Aug 2024 21:45:45 -0700 Subject: [PATCH] Add optional headers to ``mx.fast.metal_kernel`` (#1358) --- mlx/fast.cpp | 4 +++- mlx/fast.h | 11 +++++++---- python/src/fast.cpp | 10 +++++++++- python/tests/test_fast.py | 35 +++++++++++++++++++++++++++++++---- 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 186d2081f..c8c12af69 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1112,7 +1112,6 @@ std::map MetalKernel::operator()( "[metal_kernel] MetalKernel only works on GPU."); } - std::ostringstream kernel_source; std::ostringstream func_name; std::string template_def = ""; @@ -1128,6 +1127,9 @@ std::map MetalKernel::operator()( func_name << "custom_kernel_" << name_ << hash_key; std::string kernel_name = func_name.str(); + std::ostringstream kernel_source; + kernel_source << header_ << std::endl; + std::vector shape_infos; write_signature( func_name.str(), diff --git a/mlx/fast.h b/mlx/fast.h index 75ac8759a..874aa529a 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -71,10 +71,12 @@ class MetalKernel { MetalKernel( const std::string& name, const std::string& source, - bool ensure_row_contiguous, - bool atomic_outputs) + const std::string& header = "", + bool ensure_row_contiguous = true, + bool atomic_outputs = false) : name_(name), source_(source), + header_(header), ensure_row_contiguous_(ensure_row_contiguous), atomic_outputs_(atomic_outputs) {} @@ -93,7 +95,8 @@ class MetalKernel { private: std::string name_; std::string source_; - bool ensure_row_contiguous_ = true; - bool atomic_outputs_ = false; + std::string header_; + bool ensure_row_contiguous_; + bool atomic_outputs_; }; } // namespace mlx::core::fast diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 863a65ec1..f389cc2c5 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -200,9 +200,15 @@ void init_fast(nb::module_& parent_module) { A jit-compiled custom Metal kernel defined from a source string. )pbdoc") .def( - nb::init(), + nb::init< + const std::string&, + const std::string&, + const std::string&, + bool, + bool>(), "name"_a, "source"_a, + "header"_a = "", "ensure_row_contiguous"_a = true, "atomic_outputs"_a = false, R"pbdoc( @@ -214,6 +220,8 @@ void init_fast(nb::module_& parent_module) { the function signature will be generated for you. The names of the inputs/outputs are determined by the ``inputs`` and ``output_shapes``/``output_dtypes`` used when the kernel is called. + header (str): Header source code to include before the main function. + Useful for helper functions or includes that should live outside of the main function body. ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous before the kernel runs. Default: ``True``. atomic_outputs (bool): Whether to use atomic outputs in the function signature diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 369ac2086..93dc2f261 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -551,7 +551,7 @@ class TestFast(mlx_tests.MLXTestCase): @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_custom_kernel_basic(self): mx.random.seed(7) - a = mx.random.normal(shape=(3, 6)) + a = mx.random.normal(shape=(2, 2)) kernel = mx.fast.metal_kernel( name="basic", source=""" @@ -567,7 +567,7 @@ class TestFast(mlx_tests.MLXTestCase): output_dtypes={"out1": mx.float32}, stream=mx.gpu, ) - mx.allclose(out["out1"], a[:2, :2]) + self.assertTrue(mx.allclose(out["out1"], a)) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_custom_kernel_args(self): @@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase): uint elem = thread_position_in_grid.x; uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); T tmp = inp[loc]; - out[elem] = metal::exp(tmp) * threads_per_simdgroup; + out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup; """ source_contig = """ uint elem = thread_position_in_grid.x; T tmp = inp[elem]; - out[elem] = metal::exp(tmp) * threads_per_simdgroup; + out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup; """ # non contiguous @@ -646,6 +646,33 @@ class TestFast(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"])) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_kernel_helper(self): + mx.random.seed(7) + a = mx.random.normal(shape=(2, 2)) + kernel = mx.fast.metal_kernel( + name="helper", + header=""" + template + T do_exp(T x) { + return metal::precise::exp(x); + } + """, + source=""" + uint elem = thread_position_in_grid.x; + out1[elem] = do_exp(a[elem]); + """, + ) + out = kernel( + inputs={"a": a}, + grid=(4, 1, 1), + threadgroup=(2, 1, 1), + output_shapes={"out1": (2, 2)}, + output_dtypes={"out1": mx.float32}, + stream=mx.gpu, + ) + self.assertTrue(mx.allclose(out["out1"], mx.exp(a))) + if __name__ == "__main__": unittest.main()