This commit is contained in:
CircleCI Docs
2025-07-01 22:14:26 +00:00
parent 35c20e6c56
commit cfe36c4c52
533 changed files with 2735 additions and 2574 deletions

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Custom Metal Kernels &#8212; MLX 0.26.1 documentation</title>
<title>Custom Metal Kernels &#8212; MLX 0.26.2 documentation</title>
@@ -36,7 +36,7 @@
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/documentation_options.js?v=3724ff34"></script>
<script src="../_static/documentation_options.js?v=20507f52"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -137,8 +137,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.26.1 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.26.1 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.26.2 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.26.2 documentation - Home"/>`);</script>
</a></div>
@@ -926,19 +926,20 @@ document.write(`
<section id="simple-example">
<h2>Simple Example<a class="headerlink" href="#simple-example" title="Link to this heading">#</a></h2>
<p>Let’s write a custom kernel that computes <code class="docutils literal notranslate"><span class="pre">exp</span></code> elementwise:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
<span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> T tmp = inp[elem];</span>
<span class="s2"> out[elem] = metal::exp(tmp);</span>
<span class="s2"> &quot;&quot;&quot;</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> T tmp = inp[elem];</span>
<span class="s2"> out[elem] = metal::exp(tmp);</span>
<span class="s2">&quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;myexp&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;inp&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;myexp&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;inp&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="p">],</span>
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">float32</span><span class="p">)],</span>
@@ -954,9 +955,13 @@ document.write(`
<span class="k">assert</span> <span class="n">mx</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">a</span><span class="p">))</span>
</pre></div>
</div>
<p>Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
<a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html#mlx.core.fast.metal_kernel" title="mlx.core.fast.metal_kernel"><code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel()</span></code></a> and then use it many times.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>We are only required to pass the body of the Metal kernel in <code class="docutils literal notranslate"><span class="pre">source</span></code>.</p>
<p>Only pass the body of the Metal kernel in <code class="docutils literal notranslate"><span class="pre">source</span></code>. The function
signature is generated automatically.</p>
</div>
<p>The full function signature will be generated using:</p>
<ul class="simple">
@@ -1004,37 +1009,43 @@ All the attributes defined in Table 5.8 of the <a class="reference external" hre
<span class="k">template</span><span class="w"> </span><span class="p">[[</span><span class="n">host_name</span><span class="p">(</span><span class="s">&quot;custom_kernel_myexp_float&quot;</span><span class="p">)]]</span><span class="w"> </span><span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="k">decltype</span><span class="p">(</span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">)</span><span class="w"> </span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">;</span>
</pre></div>
</div>
<p>Note: <code class="docutils literal notranslate"><span class="pre">grid</span></code> and <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> are parameters to the Metal <a class="reference external" href="https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads">dispatchThreads</a> function.
This means we will launch <code class="docutils literal notranslate"><span class="pre">mx.prod(grid)</span></code> threads, subdivided into <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.</p>
<p>Passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel.__call__</span></code> will print the generated code for debugging purposes.</p>
<p>Note: <code class="docutils literal notranslate"><span class="pre">grid</span></code> and <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> are parameters to the Metal <a class="reference external" href="https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads">dispatchThreads</a>
function. This means we will launch <code class="docutils literal notranslate"><span class="pre">mx.prod(grid)</span></code> threads, subdivided into
<code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.</p>
<p>Passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel.__call__()</span></code> will print the
generated code for debugging purposes.</p>
</section>
<section id="using-shape-strides">
<h2>Using Shape/Strides<a class="headerlink" href="#using-shape-strides" title="Link to this heading">#</a></h2>
<p><code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code> supports an argument <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code> which is <code class="docutils literal notranslate"><span class="pre">True</span></code> by default.
This will copy the <code class="docutils literal notranslate"><span class="pre">mx.array</span></code> inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don’t have to worry about gaps or the ordering of the dims
when indexing.</p>
<p>If we want to avoid this copy, <code class="docutils literal notranslate"><span class="pre">metal_kernel</span></code> automatically passes <code class="docutils literal notranslate"><span class="pre">a_shape</span></code>, <code class="docutils literal notranslate"><span class="pre">a_strides</span></code> and <code class="docutils literal notranslate"><span class="pre">a_ndim</span></code> for each
input array <code class="docutils literal notranslate"><span class="pre">a</span></code> if any are present in <code class="docutils literal notranslate"><span class="pre">source</span></code>.
We can then use MLX’s built in indexing utils to fetch the right elements for each thread.</p>
<p>Let’s convert <code class="docutils literal notranslate"><span class="pre">myexp</span></code> above to support arbitrarily strided arrays without relying on a copy from <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
<span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included</span>
<span class="s2"> uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);</span>
<span class="s2"> T tmp = inp[loc];</span>
<span class="s2"> // Output arrays are always row contiguous</span>
<span class="s2"> out[elem] = metal::exp(tmp);</span>
<span class="s2"> &quot;&quot;&quot;</span>
<p><a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html#mlx.core.fast.metal_kernel" title="mlx.core.fast.metal_kernel"><code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel()</span></code></a> supports an argument <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code> which
is <code class="docutils literal notranslate"><span class="pre">True</span></code> by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don’t
have to worry about gaps or the ordering of the dims when indexing.</p>
<p>If we want to avoid this copy, <a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html#mlx.core.fast.metal_kernel" title="mlx.core.fast.metal_kernel"><code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel()</span></code></a> automatically passes
<code class="docutils literal notranslate"><span class="pre">a_shape</span></code>, <code class="docutils literal notranslate"><span class="pre">a_strides</span></code> and <code class="docutils literal notranslate"><span class="pre">a_ndim</span></code> for each input array <code class="docutils literal notranslate"><span class="pre">a</span></code> if any are
present in <code class="docutils literal notranslate"><span class="pre">source</span></code>. We can then use MLX’s built in indexing utils to fetch
the right elements for each thread.</p>
<p>Let’s convert <code class="docutils literal notranslate"><span class="pre">myexp</span></code> above to support arbitrarily strided arrays without
relying on a copy from <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included</span>
<span class="s2"> uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);</span>
<span class="s2"> T tmp = inp[loc];</span>
<span class="s2"> // Output arrays are always row contiguous</span>
<span class="s2"> out[elem] = metal::exp(tmp);</span>
<span class="s2">&quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;myexp_strided&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;inp&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span>
<span class="p">)</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;myexp_strided&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;inp&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span>
<span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="p">],</span>
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">float32</span><span class="p">)],</span>
@@ -1100,10 +1111,67 @@ We can then use MLXs built in indexing utils to fetch the right elements for
<span class="k">return</span> <span class="n">output</span>
</pre></div>
</div>
<p>Now let’s use <code class="docutils literal notranslate"><span class="pre">mx.custom_function</span></code> together with <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code>
<p>Now let’s use <a class="reference internal" href="../python/_autosummary/mlx.core.custom_function.html#mlx.core.custom_function" title="mlx.core.custom_function"><code class="xref py py-func docutils literal notranslate"><span class="pre">custom_function()</span></code></a> together with <a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html#mlx.core.fast.metal_kernel" title="mlx.core.fast.metal_kernel"><code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel()</span></code></a>
to write a fast GPU kernel for both the forward and backward passes.</p>
<p>First we’ll implement the forward pass as a fused kernel:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">custom_function</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> int H = x_shape[1];</span>
<span class="s2"> int W = x_shape[2];</span>
<span class="s2"> int C = x_shape[3];</span>
<span class="s2"> int gH = grid_shape[1];</span>
<span class="s2"> int gW = grid_shape[2];</span>
<span class="s2"> int w_stride = C;</span>
<span class="s2"> int h_stride = W * w_stride;</span>
<span class="s2"> int b_stride = H * h_stride;</span>
<span class="s2"> uint grid_idx = elem / C * 2;</span>
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
<span class="s2"> int ix_nw = floor(ix);</span>
<span class="s2"> int iy_nw = floor(iy);</span>
<span class="s2"> int ix_ne = ix_nw + 1;</span>
<span class="s2"> int iy_ne = iy_nw;</span>
<span class="s2"> int ix_sw = ix_nw;</span>
<span class="s2"> int iy_sw = iy_nw + 1;</span>
<span class="s2"> int ix_se = ix_nw + 1;</span>
<span class="s2"> int iy_se = iy_nw + 1;</span>
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
<span class="s2"> int batch_idx = elem / C / gH / gW * b_stride;</span>
<span class="s2"> int channel_idx = elem % C;</span>
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
<span class="s2"> T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];</span>
<span class="s2"> T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];</span>
<span class="s2"> T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];</span>
<span class="s2"> T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];</span>
<span class="s2"> I_nw = iy_nw &gt;= 0 &amp;&amp; iy_nw &lt;= H - 1 &amp;&amp; ix_nw &gt;= 0 &amp;&amp; ix_nw &lt;= W - 1 ? I_nw : 0;</span>
<span class="s2"> I_ne = iy_ne &gt;= 0 &amp;&amp; iy_ne &lt;= H - 1 &amp;&amp; ix_ne &gt;= 0 &amp;&amp; ix_ne &lt;= W - 1 ? I_ne : 0;</span>
<span class="s2"> I_sw = iy_sw &gt;= 0 &amp;&amp; iy_sw &lt;= H - 1 &amp;&amp; ix_sw &gt;= 0 &amp;&amp; ix_sw &lt;= W - 1 ? I_sw : 0;</span>
<span class="s2"> I_se = iy_se &gt;= 0 &amp;&amp; iy_se &lt;= H - 1 &amp;&amp; ix_se &gt;= 0 &amp;&amp; ix_se &lt;= W - 1 ? I_se : 0;</span>
<span class="s2"> out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;</span>
<span class="s2">&quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;grid_sample&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="s2">&quot;grid&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="p">)</span>
<span class="nd">@mx</span><span class="o">.</span><span class="n">custom_function</span>
<span class="k">def</span><span class="w"> </span><span class="nf">grid_sample</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">&quot;`x` must be 4D.&quot;</span>
@@ -1115,61 +1183,6 @@ to write a fast GPU kernel for both the forward and backward passes.</p>
<span class="k">assert</span> <span class="n">D</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;Last dim of `grid` must be size 2.&quot;</span>
<span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> int H = x_shape[1];</span>
<span class="s2"> int W = x_shape[2];</span>
<span class="s2"> int C = x_shape[3];</span>
<span class="s2"> int gH = grid_shape[1];</span>
<span class="s2"> int gW = grid_shape[2];</span>
<span class="s2"> int w_stride = C;</span>
<span class="s2"> int h_stride = W * w_stride;</span>
<span class="s2"> int b_stride = H * h_stride;</span>
<span class="s2"> uint grid_idx = elem / C * 2;</span>
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
<span class="s2"> int ix_nw = floor(ix);</span>
<span class="s2"> int iy_nw = floor(iy);</span>
<span class="s2"> int ix_ne = ix_nw + 1;</span>
<span class="s2"> int iy_ne = iy_nw;</span>
<span class="s2"> int ix_sw = ix_nw;</span>
<span class="s2"> int iy_sw = iy_nw + 1;</span>
<span class="s2"> int ix_se = ix_nw + 1;</span>
<span class="s2"> int iy_se = iy_nw + 1;</span>
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
<span class="s2"> int batch_idx = elem / C / gH / gW * b_stride;</span>
<span class="s2"> int channel_idx = elem % C;</span>
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
<span class="s2"> T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];</span>
<span class="s2"> T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];</span>
<span class="s2"> T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];</span>
<span class="s2"> T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];</span>
<span class="s2"> I_nw = iy_nw &gt;= 0 &amp;&amp; iy_nw &lt;= H - 1 &amp;&amp; ix_nw &gt;= 0 &amp;&amp; ix_nw &lt;= W - 1 ? I_nw : 0;</span>
<span class="s2"> I_ne = iy_ne &gt;= 0 &amp;&amp; iy_ne &lt;= H - 1 &amp;&amp; ix_ne &gt;= 0 &amp;&amp; ix_ne &lt;= W - 1 ? I_ne : 0;</span>
<span class="s2"> I_sw = iy_sw &gt;= 0 &amp;&amp; iy_sw &lt;= H - 1 &amp;&amp; ix_sw &gt;= 0 &amp;&amp; ix_sw &lt;= W - 1 ? I_sw : 0;</span>
<span class="s2"> I_se = iy_se &gt;= 0 &amp;&amp; iy_se &lt;= H - 1 &amp;&amp; ix_se &gt;= 0 &amp;&amp; ix_se &lt;= W - 1 ? I_se : 0;</span>
<span class="s2"> out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;grid_sample&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="s2">&quot;grid&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">],</span>
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span>
@@ -1191,10 +1204,10 @@ to write a fast GPU kernel for both the forward and backward passes.</p>
</section>
<section id="grid-sample-vjp">
<h2>Grid Sample VJP<a class="headerlink" href="#grid-sample-vjp" title="Link to this heading">#</a></h2>
<p>Since we decorated <code class="docutils literal notranslate"><span class="pre">grid_sample</span></code> with <code class="docutils literal notranslate"><span class="pre">mx.custom_function</span></code>, we can now define
its custom vjp transform so MLX can differentiate it.</p>
<p>Since we decorated <code class="docutils literal notranslate"><span class="pre">grid_sample</span></code> with <a class="reference internal" href="../python/_autosummary/mlx.core.custom_function.html#mlx.core.custom_function" title="mlx.core.custom_function"><code class="xref py py-func docutils literal notranslate"><span class="pre">custom_function()</span></code></a>, we can now
define its custom vjp transform so MLX can differentiate it.</p>
<p>The backwards pass requires atomically updating <code class="docutils literal notranslate"><span class="pre">x_grad</span></code>/<code class="docutils literal notranslate"><span class="pre">grid_grad</span></code> and so
requires a few extra <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code> features:</p>
requires a few extra <a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html#mlx.core.fast.metal_kernel" title="mlx.core.fast.metal_kernel"><code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel()</span></code></a> features:</p>
<ul class="simple">
<li><dl class="simple">
<dt><code class="docutils literal notranslate"><span class="pre">init_value=0</span></code></dt><dd><p>Initialize all of the kernel’s outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.</p>
@@ -1210,7 +1223,107 @@ See section 6.15 of the <a class="reference external" href="https://developer.ap
</li>
</ul>
<p>We can then implement the backwards pass as follows:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@grid_sample</span><span class="o">.</span><span class="n">vjp</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> int H = x_shape[1];</span>
<span class="s2"> int W = x_shape[2];</span>
<span class="s2"> int C = x_shape[3];</span>
<span class="s2"> // Pad C to the nearest larger simdgroup size multiple</span>
<span class="s2"> int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;</span>
<span class="s2"> int gH = grid_shape[1];</span>
<span class="s2"> int gW = grid_shape[2];</span>
<span class="s2"> int w_stride = C;</span>
<span class="s2"> int h_stride = W * w_stride;</span>
<span class="s2"> int b_stride = H * h_stride;</span>
<span class="s2"> uint grid_idx = elem / C_padded * 2;</span>
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
<span class="s2"> int ix_nw = floor(ix);</span>
<span class="s2"> int iy_nw = floor(iy);</span>
<span class="s2"> int ix_ne = ix_nw + 1;</span>
<span class="s2"> int iy_ne = iy_nw;</span>
<span class="s2"> int ix_sw = ix_nw;</span>
<span class="s2"> int iy_sw = iy_nw + 1;</span>
<span class="s2"> int ix_se = ix_nw + 1;</span>
<span class="s2"> int iy_se = iy_nw + 1;</span>
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
<span class="s2"> int batch_idx = elem / C_padded / gH / gW * b_stride;</span>
<span class="s2"> int channel_idx = elem % C_padded;</span>
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
<span class="s2"> T gix = T(0);</span>
<span class="s2"> T giy = T(0);</span>
<span class="s2"> if (channel_idx &lt; C) {</span>
<span class="s2"> int cot_index = elem / C_padded * C + channel_idx;</span>
<span class="s2"> T cot = cotangent[cot_index];</span>
<span class="s2"> if (iy_nw &gt;= 0 &amp;&amp; iy_nw &lt;= H - 1 &amp;&amp; ix_nw &gt;= 0 &amp;&amp; ix_nw &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], nw * cot, memory_order_relaxed);</span>
<span class="s2"> T I_nw = x[offset];</span>
<span class="s2"> gix -= I_nw * (iy_se - iy) * cot;</span>
<span class="s2"> giy -= I_nw * (ix_se - ix) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_ne &gt;= 0 &amp;&amp; iy_ne &lt;= H - 1 &amp;&amp; ix_ne &gt;= 0 &amp;&amp; ix_ne &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], ne * cot, memory_order_relaxed);</span>
<span class="s2"> T I_ne = x[offset];</span>
<span class="s2"> gix += I_ne * (iy_sw - iy) * cot;</span>
<span class="s2"> giy -= I_ne * (ix - ix_sw) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_sw &gt;= 0 &amp;&amp; iy_sw &lt;= H - 1 &amp;&amp; ix_sw &gt;= 0 &amp;&amp; ix_sw &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], sw * cot, memory_order_relaxed);</span>
<span class="s2"> T I_sw = x[offset];</span>
<span class="s2"> gix -= I_sw * (iy - iy_ne) * cot;</span>
<span class="s2"> giy += I_sw * (ix_ne - ix) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_se &gt;= 0 &amp;&amp; iy_se &lt;= H - 1 &amp;&amp; ix_se &gt;= 0 &amp;&amp; ix_se &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_se * h_stride + ix_se * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], se * cot, memory_order_relaxed);</span>
<span class="s2"> T I_se = x[offset];</span>
<span class="s2"> gix += I_se * (iy - iy_nw) * cot;</span>
<span class="s2"> giy += I_se * (ix - ix_nw) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> }</span>
<span class="s2"> T gix_mult = W / 2;</span>
<span class="s2"> T giy_mult = H / 2;</span>
<span class="s2"> // Reduce across each simdgroup first.</span>
<span class="s2"> // This is much faster than relying purely on atomics.</span>
<span class="s2"> gix = simd_sum(gix);</span>
<span class="s2"> giy = simd_sum(giy);</span>
<span class="s2"> if (thread_index_in_simdgroup == 0) {</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);</span>
<span class="s2"> }</span>
<span class="s2">&quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;grid_sample_grad&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="s2">&quot;grid&quot;</span><span class="p">,</span> <span class="s2">&quot;cotangent&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x_grad&quot;</span><span class="p">,</span> <span class="s2">&quot;grid_grad&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="n">atomic_outputs</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
<span class="nd">@grid_sample</span><span class="o">.</span><span class="n">vjp</span>
<span class="k">def</span><span class="w"> </span><span class="nf">grid_sample_vjp</span><span class="p">(</span><span class="n">primals</span><span class="p">,</span> <span class="n">cotangent</span><span class="p">,</span> <span class="n">_</span><span class="p">):</span>
<span class="n">x</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">primals</span>
<span class="n">B</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
@@ -1218,105 +1331,6 @@ See section 6.15 of the <a class="reference external" href="https://developer.ap
<span class="k">assert</span> <span class="n">D</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;Last dim of `grid` must be size 2.&quot;</span>
<span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> int H = x_shape[1];</span>
<span class="s2"> int W = x_shape[2];</span>
<span class="s2"> int C = x_shape[3];</span>
<span class="s2"> // Pad C to the nearest larger simdgroup size multiple</span>
<span class="s2"> int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;</span>
<span class="s2"> int gH = grid_shape[1];</span>
<span class="s2"> int gW = grid_shape[2];</span>
<span class="s2"> int w_stride = C;</span>
<span class="s2"> int h_stride = W * w_stride;</span>
<span class="s2"> int b_stride = H * h_stride;</span>
<span class="s2"> uint grid_idx = elem / C_padded * 2;</span>
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
<span class="s2"> int ix_nw = floor(ix);</span>
<span class="s2"> int iy_nw = floor(iy);</span>
<span class="s2"> int ix_ne = ix_nw + 1;</span>
<span class="s2"> int iy_ne = iy_nw;</span>
<span class="s2"> int ix_sw = ix_nw;</span>
<span class="s2"> int iy_sw = iy_nw + 1;</span>
<span class="s2"> int ix_se = ix_nw + 1;</span>
<span class="s2"> int iy_se = iy_nw + 1;</span>
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
<span class="s2"> int batch_idx = elem / C_padded / gH / gW * b_stride;</span>
<span class="s2"> int channel_idx = elem % C_padded;</span>
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
<span class="s2"> T gix = T(0);</span>
<span class="s2"> T giy = T(0);</span>
<span class="s2"> if (channel_idx &lt; C) {</span>
<span class="s2"> int cot_index = elem / C_padded * C + channel_idx;</span>
<span class="s2"> T cot = cotangent[cot_index];</span>
<span class="s2"> if (iy_nw &gt;= 0 &amp;&amp; iy_nw &lt;= H - 1 &amp;&amp; ix_nw &gt;= 0 &amp;&amp; ix_nw &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], nw * cot, memory_order_relaxed);</span>
<span class="s2"> T I_nw = x[offset];</span>
<span class="s2"> gix -= I_nw * (iy_se - iy) * cot;</span>
<span class="s2"> giy -= I_nw * (ix_se - ix) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_ne &gt;= 0 &amp;&amp; iy_ne &lt;= H - 1 &amp;&amp; ix_ne &gt;= 0 &amp;&amp; ix_ne &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], ne * cot, memory_order_relaxed);</span>
<span class="s2"> T I_ne = x[offset];</span>
<span class="s2"> gix += I_ne * (iy_sw - iy) * cot;</span>
<span class="s2"> giy -= I_ne * (ix - ix_sw) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_sw &gt;= 0 &amp;&amp; iy_sw &lt;= H - 1 &amp;&amp; ix_sw &gt;= 0 &amp;&amp; ix_sw &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], sw * cot, memory_order_relaxed);</span>
<span class="s2"> T I_sw = x[offset];</span>
<span class="s2"> gix -= I_sw * (iy - iy_ne) * cot;</span>
<span class="s2"> giy += I_sw * (ix_ne - ix) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_se &gt;= 0 &amp;&amp; iy_se &lt;= H - 1 &amp;&amp; ix_se &gt;= 0 &amp;&amp; ix_se &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_se * h_stride + ix_se * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], se * cot, memory_order_relaxed);</span>
<span class="s2"> T I_se = x[offset];</span>
<span class="s2"> gix += I_se * (iy - iy_nw) * cot;</span>
<span class="s2"> giy += I_se * (ix - ix_nw) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> }</span>
<span class="s2"> T gix_mult = W / 2;</span>
<span class="s2"> T giy_mult = H / 2;</span>
<span class="s2"> // Reduce across each simdgroup first.</span>
<span class="s2"> // This is much faster than relying purely on atomics.</span>
<span class="s2"> gix = simd_sum(gix);</span>
<span class="s2"> giy = simd_sum(giy);</span>
<span class="s2"> if (thread_index_in_simdgroup == 0) {</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);</span>
<span class="s2"> }</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;grid_sample_grad&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="s2">&quot;grid&quot;</span><span class="p">,</span> <span class="s2">&quot;cotangent&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x_grad&quot;</span><span class="p">,</span> <span class="s2">&quot;grid_grad&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="n">atomic_outputs</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># pad the output channels to simd group size</span>
<span class="c1"># so that our `simd_sum`s don&#39;t overlap.</span>
<span class="n">simdgroup_size</span> <span class="o">=</span> <span class="mi">32</span>