Add grid_sample example to metal_kernel docs (#1352)

* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`

* Add `grid_sample` example to the docs

* Rename `zero_outputs` option to `init_value`

* Add missing header for Linux
This commit is contained in:
Alex Barron
2024-08-23 18:24:16 -07:00
committed by GitHub
parent 3b4d5484c7
commit b96e105244
6 changed files with 337 additions and 15 deletions

View File

@@ -200,10 +200,11 @@ void init_fast(nb::module_& parent_module) {
A jit-compiled custom Metal kernel defined from a source string.
)pbdoc")
.def(
nb::init<const std::string&, const std::string&, bool>(),
nb::init<const std::string&, const std::string&, bool, bool>(),
"name"_a,
"source"_a,
"ensure_row_contiguous"_a = true,
"atomic_outputs"_a = false,
R"pbdoc(
Initialize a metal_kernel.
@@ -215,6 +216,8 @@ void init_fast(nb::module_& parent_module) {
used when the kernel is called.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
Returns:
Callable ``metal_kernel``.
@@ -256,6 +259,7 @@ void init_fast(nb::module_& parent_module) {
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, nb::handle>> template_args_,
std::optional<float> init_value,
bool verbose,
StreamOrDevice s) {
std::map<std::string, array> inputs;
@@ -289,6 +293,7 @@ void init_fast(nb::module_& parent_module) {
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
},
@@ -299,10 +304,11 @@ void init_fast(nb::module_& parent_module) {
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"init_value"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
@@ -316,9 +322,11 @@ void init_fast(nb::module_& parent_module) {
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
These will be added as template arguments to the kernel definition.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run.
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns: