mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 09:51:17 +08:00)
Add dispatchThreads to custom kernel doc (#1551)

* add dispatchThreads info
* update
* add link
parent eac961ddb1
commit 9e516b71ea
@@ -1,3 +1,5 @@
+.. _custom_metal_kernels:
+
 Custom Metal Kernels
 ====================
 
@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as
 
 template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
 
+Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
+This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
+For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
+
 Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
 
 Using Shape/Strides
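The note added in this hunk describes the launch geometry in words. As a rough illustration only, the arithmetic can be sketched in plain Python without MLX or a GPU. The helper names ``launch_geometry`` and ``clamp_threadgroup`` are hypothetical and not part of MLX, and the per-dimension threadgroup count assumes the usual ``dispatchThreads`` behaviour in which edge threadgroups may be partially filled.

```python
import math


def launch_geometry(grid, threadgroup):
    """Estimate launch geometry for a dispatchThreads-style launch.

    Hypothetical helper, not an MLX API. ``grid`` is the total number of
    threads per dimension, ``threadgroup`` the threadgroup size per dimension.
    """
    if len(grid) != 3 or len(threadgroup) != 3:
        raise ValueError("grid and threadgroup must both be 3-tuples")
    if any(v < 1 for v in (*grid, *threadgroup)):
        raise ValueError("all dimensions must be positive")
    # Total threads launched: the product of the grid dimensions.
    total_threads = math.prod(grid)
    # Threadgroups per dimension, rounding up (assumes partially filled
    # threadgroups at the grid edge are allowed).
    groups = tuple(-(-g // t) for g, t in zip(grid, threadgroup))
    return total_threads, groups


def clamp_threadgroup(grid, threadgroup):
    """Shrink each threadgroup dimension so it does not exceed the grid."""
    return tuple(min(g, t) for g, t in zip(grid, threadgroup))


if __name__ == "__main__":
    print(launch_geometry((1000, 1, 1), (256, 1, 1)))
    print(clamp_threadgroup((100, 1, 1), (256, 4, 1)))
```

Clamping follows the guidance in the note that each threadgroup dimension should be less than or equal to the corresponding grid dimension. It does not check device limits such as the maximum number of threads per threadgroup.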
@@ -278,7 +278,9 @@ void init_fast(nb::module_& parent_module) {
 output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
 output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
 grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
+This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``.
 threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
+This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``.
 template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
 These will be added as template arguments to the kernel definition. Default: ``None``.
 init_value (float, optional): Optional value to use to initialize all of the output arrays.
@@ -300,6 +302,8 @@ void init_fast(nb::module_& parent_module) {
 R"pbdoc(
 A jit-compiled custom Metal kernel defined from a source string.
 
+Full documentation: :ref:`custom_metal_kernels`.
+
 Args:
 name (str): Name for the kernel.
 input_names (List[str]): The parameter names of the inputs in the