Add optional headers to `mx.fast.metal_kernel` (#1358)

This commit is contained in:
Alex Barron
2024-08-26 21:45:45 -07:00
committed by GitHub
parent 5f7d19d1f5
commit 1d94ac3f90
4 changed files with 50 additions and 10 deletions

View File

@@ -200,9 +200,15 @@ void init_fast(nb::module_& parent_module) {
A jit-compiled custom Metal kernel defined from a source string.
)pbdoc")
.def(
nb::init<const std::string&, const std::string&, bool, bool>(),
nb::init<
const std::string&,
const std::string&,
const std::string&,
bool,
bool>(),
"name"_a,
"source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true,
"atomic_outputs"_a = false,
R"pbdoc(
@@ -214,6 +220,8 @@ void init_fast(nb::module_& parent_module) {
the function signature will be generated for you. The names of the inputs/outputs
are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
used when the kernel is called.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature