throw for certain cases of non captured inputs in compile (#1401)

This commit is contained in:
Awni Hannun
2024-09-09 14:54:31 -07:00
committed by GitHub
parent dc627dcb5e
commit 3ae6aabe9f
5 changed files with 70 additions and 18 deletions

View File

@@ -302,20 +302,20 @@ void init_fast(nb::module_& parent_module) {
A jit-compiled custom Metal kernel defined from a source string.
Args:
name (str): Name for the kernel.
input_names (List[str]): The parameter names of the inputs in the
function signature.
output_names (List[str]): The parameter names of the outputs in the
name (str): Name for the kernel.
input_names (List[str]): The parameter names of the inputs in the
function signature.
source (str): Source code. This is the body of a function in Metal,
the function signature will be automatically generated.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of
the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
output_names (List[str]): The parameter names of the outputs in the
function signature.
source (str): Source code. This is the body of a function in Metal,
the function signature will be automatically generated.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of
the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
Returns:
Callable ``metal_kernel``.