mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 11:28:12 +08:00
Add optional headers to `mx.fast.metal_kernel` (#1358)
This commit is contained in:
@@ -200,9 +200,15 @@ void init_fast(nb::module_& parent_module) {
|
||||
A jit-compiled custom Metal kernel defined from a source string.
|
||||
)pbdoc")
|
||||
.def(
|
||||
nb::init<const std::string&, const std::string&, bool, bool>(),
|
||||
nb::init<
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool,
|
||||
bool>(),
|
||||
"name"_a,
|
||||
"source"_a,
|
||||
"header"_a = "",
|
||||
"ensure_row_contiguous"_a = true,
|
||||
"atomic_outputs"_a = false,
|
||||
R"pbdoc(
|
||||
@@ -214,6 +220,8 @@ void init_fast(nb::module_& parent_module) {
|
||||
the function signature will be generated for you. The names of the inputs/outputs
|
||||
are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
|
||||
used when the kernel is called.
|
||||
header (str): Header source code to include before the main function.
|
||||
Useful for helper functions or includes that should live outside of the main function body.
|
||||
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
|
||||
before the kernel runs. Default: ``True``.
|
||||
atomic_outputs (bool): Whether to use atomic outputs in the function signature
|
||||
|
||||
Reference in New Issue
Block a user