mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:37:50 +08:00
Add grid_sample
example to metal_kernel
docs (#1352)
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel` * add grid sample to docs * zero_outputs -> init_value * add missing header for linux
This commit is contained in:
@@ -200,10 +200,11 @@ void init_fast(nb::module_& parent_module) {
|
||||
A jit-compiled custom Metal kernel defined from a source string.
|
||||
)pbdoc")
|
||||
.def(
|
||||
nb::init<const std::string&, const std::string&, bool>(),
|
||||
nb::init<const std::string&, const std::string&, bool, bool>(),
|
||||
"name"_a,
|
||||
"source"_a,
|
||||
"ensure_row_contiguous"_a = true,
|
||||
"atomic_outputs"_a = false,
|
||||
R"pbdoc(
|
||||
Initialize a metal_kernel.
|
||||
|
||||
@@ -215,6 +216,8 @@ void init_fast(nb::module_& parent_module) {
|
||||
used when the kernel is called.
|
||||
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
|
||||
before the kernel runs. Default: ``True``.
|
||||
atomic_outputs (bool): Whether to use atomic outputs in the function signature
|
||||
e.g. ``device atomic<float>``. Default: ``False``.
|
||||
Returns:
|
||||
Callable ``metal_kernel``.
|
||||
|
||||
@@ -256,6 +259,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, nb::handle>> template_args_,
|
||||
std::optional<float> init_value,
|
||||
bool verbose,
|
||||
StreamOrDevice s) {
|
||||
std::map<std::string, array> inputs;
|
||||
@@ -289,6 +293,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
grid,
|
||||
threadgroup,
|
||||
template_args,
|
||||
init_value,
|
||||
verbose,
|
||||
s);
|
||||
},
|
||||
@@ -299,10 +304,11 @@ void init_fast(nb::module_& parent_module) {
|
||||
"grid"_a,
|
||||
"threadgroup"_a,
|
||||
"template"_a = nb::none(),
|
||||
"init_value"_a = nb::none(),
|
||||
"verbose"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
|
||||
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
|
||||
R"pbdoc(
|
||||
Run the kernel.
|
||||
|
||||
@@ -316,9 +322,11 @@ void init_fast(nb::module_& parent_module) {
|
||||
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
|
||||
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
|
||||
template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
|
||||
These will be added as template arguments to the kernel definition.
|
||||
These will be added as template arguments to the kernel definition. Default: ``None``.
|
||||
init_value (float, optional): Optional value to use to initialize all of the output arrays.
|
||||
By default, output arrays are uninitialized. Default: ``None``.
|
||||
verbose (bool, optional): Whether to print the full generated source code of the kernel
|
||||
when it is run.
|
||||
when it is run. Default: ``False``.
|
||||
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
|
Reference in New Issue
Block a user