diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst
index c4c1b0aff..3e92f2814 100644
--- a/docs/src/dev/custom_metal_kernels.rst
+++ b/docs/src/dev/custom_metal_kernels.rst
@@ -1,3 +1,5 @@
+.. _custom_metal_kernels:
+
 Custom Metal Kernels
 ====================
 
@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as
 
   template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float) custom_kernel_myexp_float;
 
+Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads `_ function.
+This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
+For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
+
 Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
 
 Using Shape/Strides
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 758a27530..829301ab5 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -278,7 +278,9 @@ void init_fast(nb::module_& parent_module) {
         output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
         output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
         grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
+          This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``.
         threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
+          This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``.
         template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
           These will be added as template arguments to the kernel definition. Default: ``None``.
         init_value (float, optional): Optional value to use to initialize all of the output arrays.
@@ -300,6 +302,8 @@ void init_fast(nb::module_& parent_module) {
       R"pbdoc(
       A jit-compiled custom Metal kernel defined from a source string.
 
+      Full documentation: :ref:`custom_metal_kernels`.
+
       Args:
         name (str): Name for the kernel.
         input_names (List[str]): The parameter names of the inputs in the