mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-27 03:21:19 +08:00
Fix unintuitive metal kernel caching
This commit is contained in:
parent
c6a20b427a
commit
ac1117b224
@ -8,11 +8,12 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
|||||||
Simple Example
|
Simple Example
|
||||||
--------------
|
--------------
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
T tmp = inp[elem];
|
T tmp = inp[elem];
|
||||||
@ -25,6 +26,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
|||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source,
|
source=source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def exp_elementwise(a: mx.array):
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
template=[("T", mx.float32)],
|
template=[("T", mx.float32)],
|
||||||
@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
|||||||
b = exp_elementwise(a)
|
b = exp_elementwise(a)
|
||||||
assert mx.allclose(b, mx.exp(a))
|
assert mx.allclose(b, mx.exp(a))
|
||||||
|
|
||||||
|
Every time you make a kernel, a new Metal library is created and possibly
|
||||||
|
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||||
|
:func:`fast.metal_kernel` and then use it many times.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
We are only required to pass the body of the Metal kernel in ``source``.
|
Only pass the body of the Metal kernel in ``source``. The function
|
||||||
|
signature is generated automatically.
|
||||||
|
|
||||||
The full function signature will be generated using:
|
The full function signature will be generated using:
|
||||||
|
|
||||||
@ -78,29 +86,34 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
|||||||
|
|
||||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||||
|
|
||||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
||||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
||||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
||||||
|
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
||||||
|
dimension should be less than or equal to the corresponding grid dimension.
|
||||||
|
|
||||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
||||||
|
generated code for debugging purposes.
|
||||||
|
|
||||||
Using Shape/Strides
|
Using Shape/Strides
|
||||||
-------------------
|
-------------------
|
||||||
|
|
||||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
||||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
is ``True`` by default. This will copy the array inputs if needed
|
||||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
before the kernel is launched to ensure that the memory layout is row
|
||||||
when indexing.
|
contiguous. Generally this makes writing the kernel easier, since we don't
|
||||||
|
have to worry about gaps or the ordering of the dims when indexing.
|
||||||
|
|
||||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
||||||
input array ``a`` if any are present in ``source``.
|
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
||||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
||||||
|
the right elements for each thread.
|
||||||
|
|
||||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
||||||
|
relying on a copy from ``ensure_row_contiguous``:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||||
@ -116,6 +129,8 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
|
|||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source
|
source=source
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def exp_elementwise(a: mx.array):
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
template=[("T", mx.float32)],
|
template=[("T", mx.float32)],
|
||||||
@ -183,25 +198,13 @@ We'll start with the following MLX implementation using standard ops:
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
|
||||||
to write a fast GPU kernel for both the forward and backward passes.
|
to write a fast GPU kernel for both the forward and backward passes.
|
||||||
|
|
||||||
First we'll implement the forward pass as a fused kernel:
|
First we'll implement the forward pass as a fused kernel:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@mx.custom_function
|
|
||||||
def grid_sample(x, grid):
|
|
||||||
|
|
||||||
assert x.ndim == 4, "`x` must be 4D."
|
|
||||||
assert grid.ndim == 4, "`grid` must be 4D."
|
|
||||||
|
|
||||||
B, _, _, C = x.shape
|
|
||||||
_, gN, gM, D = grid.shape
|
|
||||||
out_shape = (B, gN, gM, C)
|
|
||||||
|
|
||||||
assert D == 2, "Last dim of `grid` must be size 2."
|
|
||||||
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
int H = x_shape[1];
|
int H = x_shape[1];
|
||||||
@ -251,12 +254,26 @@ First we'll implement the forward pass as a fused kernel:
|
|||||||
|
|
||||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kernel = mx.fast.metal_kernel(
|
kernel = mx.fast.metal_kernel(
|
||||||
name="grid_sample",
|
name="grid_sample",
|
||||||
input_names=["x", "grid"],
|
input_names=["x", "grid"],
|
||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source,
|
source=source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@mx.custom_function
|
||||||
|
def grid_sample(x, grid):
|
||||||
|
|
||||||
|
assert x.ndim == 4, "`x` must be 4D."
|
||||||
|
assert grid.ndim == 4, "`grid` must be 4D."
|
||||||
|
|
||||||
|
B, _, _, C = x.shape
|
||||||
|
_, gN, gM, D = grid.shape
|
||||||
|
out_shape = (B, gN, gM, C)
|
||||||
|
|
||||||
|
assert D == 2, "Last dim of `grid` must be size 2."
|
||||||
|
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[x, grid],
|
inputs=[x, grid],
|
||||||
template=[("T", x.dtype)],
|
template=[("T", x.dtype)],
|
||||||
@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
|
|||||||
Grid Sample VJP
|
Grid Sample VJP
|
||||||
---------------
|
---------------
|
||||||
|
|
||||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||||
its custom vjp transform so MLX can differentiate it.
|
define its custom vjp transform so MLX can differentiate it.
|
||||||
|
|
||||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
requires a few extra :func:`fast.metal_kernel` features:
|
||||||
|
|
||||||
* ``init_value=0``
|
* ``init_value=0``
|
||||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||||
@ -299,14 +316,6 @@ We can then implement the backwards pass as follows:
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@grid_sample.vjp
|
|
||||||
def grid_sample_vjp(primals, cotangent, _):
|
|
||||||
x, grid = primals
|
|
||||||
B, _, _, C = x.shape
|
|
||||||
_, gN, gM, D = grid.shape
|
|
||||||
|
|
||||||
assert D == 2, "Last dim of `grid` must be size 2."
|
|
||||||
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
int H = x_shape[1];
|
int H = x_shape[1];
|
||||||
@ -406,6 +415,15 @@ We can then implement the backwards pass as follows:
|
|||||||
source=source,
|
source=source,
|
||||||
atomic_outputs=True,
|
atomic_outputs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@grid_sample.vjp
|
||||||
|
def grid_sample_vjp(primals, cotangent, _):
|
||||||
|
x, grid = primals
|
||||||
|
B, _, _, C = x.shape
|
||||||
|
_, gN, gM, D = grid.shape
|
||||||
|
|
||||||
|
assert D == 2, "Last dim of `grid` must be size 2."
|
||||||
|
|
||||||
# pad the output channels to simd group size
|
# pad the output channels to simd group size
|
||||||
# so that our `simd_sum`s don't overlap.
|
# so that our `simd_sum`s don't overlap.
|
||||||
simdgroup_size = 32
|
simdgroup_size = 32
|
||||||
|
@ -73,6 +73,16 @@ void CustomKernel::eval_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const auto [tx, ty, tz] = threadgroup_;
|
const auto [tx, ty, tz] = threadgroup_;
|
||||||
|
auto tg_size = tx * ty * tz;
|
||||||
|
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (tg_size > max_tg_size) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Thread group size (" << tg_size << ") is greater than "
|
||||||
|
<< " the maximum allowed threads per threadgroup (" << max_tg_size
|
||||||
|
<< ").";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
const auto [gx, gy, gz] = grid_;
|
const auto [gx, gy, gz] = grid_;
|
||||||
MTL::Size group_dims =
|
MTL::Size group_dims =
|
||||||
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||||
|
14
mlx/fast.cpp
14
mlx/fast.cpp
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <chrono>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
@ -1228,6 +1229,10 @@ MetalKernelFunction metal_kernel(
|
|||||||
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
|
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
auto now = std::chrono::system_clock::now();
|
||||||
|
int64_t timestamp = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||||
|
now.time_since_epoch())
|
||||||
|
.count();
|
||||||
|
|
||||||
return [=,
|
return [=,
|
||||||
shape_infos = std::move(shape_infos),
|
shape_infos = std::move(shape_infos),
|
||||||
@ -1271,14 +1276,15 @@ MetalKernelFunction metal_kernel(
|
|||||||
|
|
||||||
std::ostringstream func_name;
|
std::ostringstream func_name;
|
||||||
std::string template_def = "";
|
std::string template_def = "";
|
||||||
std::string hash_key = "";
|
std::string template_hash = "";
|
||||||
if (!template_args.empty()) {
|
if (!template_args.empty()) {
|
||||||
std::regex disallowed_chars("\\<|\\>|(, )");
|
std::regex disallowed_chars("\\<|\\>|(, )");
|
||||||
template_def = write_template(template_args);
|
template_def = write_template(template_args);
|
||||||
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
|
template_hash = std::regex_replace(template_def, disallowed_chars, "_");
|
||||||
hash_key.pop_back();
|
template_hash.pop_back();
|
||||||
}
|
}
|
||||||
func_name << "custom_kernel_" << name << hash_key;
|
func_name << "custom_kernel_" << name << "_" << template_hash << "_"
|
||||||
|
<< timestamp;
|
||||||
std::string kernel_name = func_name.str();
|
std::string kernel_name = func_name.str();
|
||||||
|
|
||||||
std::string kernel_source = write_signature(
|
std::string kernel_source = write_signature(
|
||||||
|
@ -735,6 +735,41 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
)[0]
|
)[0]
|
||||||
self.assertEqual(out.item(), 2)
|
self.assertEqual(out.item(), 2)
|
||||||
|
|
||||||
|
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||||
|
def test_custom_kernel_caching(self):
|
||||||
|
def call_kernel(a: mx.array, source):
|
||||||
|
kernel = mx.fast.metal_kernel(
|
||||||
|
name="my_kernel",
|
||||||
|
input_names=["inp"],
|
||||||
|
output_names=["out"],
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
return kernel(
|
||||||
|
inputs=[a],
|
||||||
|
grid=(a.size, 1, 1),
|
||||||
|
threadgroup=(a.size, 1, 1),
|
||||||
|
output_shapes=[a.shape],
|
||||||
|
output_dtypes=[a.dtype],
|
||||||
|
stream=mx.gpu,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
a = mx.random.normal(shape=(32,))
|
||||||
|
|
||||||
|
source = """
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
out[elem] = 0.0;
|
||||||
|
"""
|
||||||
|
|
||||||
|
out = call_kernel(a, source)
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.zeros_like(out)))
|
||||||
|
|
||||||
|
source = """
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
out[elem] = 1.0;
|
||||||
|
"""
|
||||||
|
out = call_kernel(a, source)
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.ones_like(out)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user