This commit is contained in:
Awni Hannun
2024-09-17 12:06:14 -07:00
committed by CircleCI Docs
parent 9da49a07a4
commit d44f06ae79
739 changed files with 28107 additions and 8524 deletions

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<title>Custom Metal Kernels &#8212; MLX 0.17.0 documentation</title>
<title>Custom Metal Kernels &#8212; MLX 0.17.3 documentation</title>
@@ -36,7 +36,7 @@
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=5b4479735964841361fd" />
<script src="../_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=5b4479735964841361fd"></script>
<script src="../_static/documentation_options.js?v=50b3d22e"></script>
<script src="../_static/documentation_options.js?v=47c75248"></script>
<script src="../_static/doctools.js?v=888ff710"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=efea14e4"></script>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.17.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.17.0 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.17.3 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.17.3 documentation - Home"/>`);</script>
</a></div>
@@ -219,8 +219,9 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sqrt.html">mlx.core.array.sqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.square.html">mlx.core.array.square</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.squeeze.html">mlx.core.array.squeeze</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.swapaxes.html">mlx.core.array.swapaxes</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.std.html">mlx.core.array.std</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sum.html">mlx.core.array.sum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.swapaxes.html">mlx.core.array.swapaxes</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.transpose.html">mlx.core.array.transpose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.T.html">mlx.core.array.T</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.var.html">mlx.core.array.var</a></li>
@@ -282,6 +283,10 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.convolve.html">mlx.core.convolve</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv1d.html">mlx.core.conv1d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv2d.html">mlx.core.conv2d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv3d.html">mlx.core.conv3d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose1d.html">mlx.core.conv_transpose1d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose2d.html">mlx.core.conv_transpose2d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose3d.html">mlx.core.conv_transpose3d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_general.html">mlx.core.conv_general</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cos.html">mlx.core.cos</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cosh.html">mlx.core.cosh</a></li>
@@ -315,6 +320,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.hadamard_transform.html">mlx.core.hadamard_transform</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.identity.html">mlx.core.identity</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.inner.html">mlx.core.inner</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isfinite.html">mlx.core.isfinite</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isclose.html">mlx.core.isclose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isinf.html">mlx.core.isinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isnan.html">mlx.core.isnan</a></li>
@@ -514,6 +520,9 @@
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv3d.html">mlx.nn.Conv3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose1d.html">mlx.nn.ConvTranspose1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose2d.html">mlx.nn.ConvTranspose2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose3d.html">mlx.nn.ConvTranspose3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout.html">mlx.nn.Dropout</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout2d.html">mlx.nn.Dropout2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout3d.html">mlx.nn.Dropout3d</a></li>
@@ -651,6 +660,9 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.init.html">mlx.core.distributed.init</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.all_sum.html">mlx.core.distributed.all_sum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.all_gather.html">mlx.core.distributed.all_gather</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.send.html">mlx.core.distributed.send</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.recv.html">mlx.core.distributed.recv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.recv_like.html">mlx.core.distributed.recv_like</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-22" name="toctree-checkbox-22" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-22"><i class="fa-solid fa-chevron-down"></i></label><ul>
@@ -836,6 +848,8 @@ document.write(`
<ul class="visible nav section-nav flex-column">
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#simple-example">Simple Example</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#using-shape-strides">Using Shape/Strides</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#complex-example">Complex Example</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#grid-sample-vjp">Grid Sample VJP</a></li>
</ul>
</nav>
</div>
@@ -862,17 +876,19 @@ document.write(`
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;myexp&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;inp&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;inp&quot;</span><span class="p">:</span> <span class="n">a</span><span class="p">},</span>
<span class="n">template</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;T&quot;</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">float32</span><span class="p">},</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="p">],</span>
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">float32</span><span class="p">)],</span>
<span class="n">grid</span><span class="o">=</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">threadgroup</span><span class="o">=</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">output_shapes</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;out&quot;</span><span class="p">:</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">},</span>
<span class="n">output_dtypes</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;out&quot;</span><span class="p">:</span> <span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">},</span>
<span class="n">output_shapes</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">],</span>
<span class="n">output_dtypes</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">outputs</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">]</span>
<span class="k">return</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">16</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
@@ -886,20 +902,21 @@ document.write(`
<p>The full function signature will be generated using:</p>
<ul class="simple">
<li><dl class="simple">
<dt>The keys and shapes/dtypes of <code class="docutils literal notranslate"><span class="pre">inputs</span></code></dt><dd><p>In the above, <code class="docutils literal notranslate"><span class="pre">a</span></code> is an <code class="docutils literal notranslate"><span class="pre">mx.array</span></code> of type <code class="docutils literal notranslate"><span class="pre">mx.float16</span></code> and we pass it with the key <code class="docutils literal notranslate"><span class="pre">inp</span></code>
<dt>The shapes/dtypes of <code class="docutils literal notranslate"><span class="pre">inputs</span></code></dt><dd><p>In the above, <code class="docutils literal notranslate"><span class="pre">a</span></code> is an <code class="docutils literal notranslate"><span class="pre">mx.array</span></code> of type <code class="docutils literal notranslate"><span class="pre">mx.float16</span></code> and we pass it with the key <code class="docutils literal notranslate"><span class="pre">inp</span></code>
so we will add <code class="docutils literal notranslate"><span class="pre">const</span> <span class="pre">device</span> <span class="pre">float16_t*</span> <span class="pre">inp</span></code> to the signature.
<code class="docutils literal notranslate"><span class="pre">inp_shape</span></code>, <code class="docutils literal notranslate"><span class="pre">inp_strides</span></code> and <code class="docutils literal notranslate"><span class="pre">inp_ndim</span></code> are also added for convenience.</p>
<code class="docutils literal notranslate"><span class="pre">inp_shape</span></code>, <code class="docutils literal notranslate"><span class="pre">inp_strides</span></code> and <code class="docutils literal notranslate"><span class="pre">inp_ndim</span></code> are also added for convenience if they are present
in <code class="docutils literal notranslate"><span class="pre">source</span></code>.</p>
</dd>
</dl>
</li>
<li><dl class="simple">
<dt>The keys and values of <code class="docutils literal notranslate"><span class="pre">output_shapes</span></code> and <code class="docutils literal notranslate"><span class="pre">output_dtypes</span></code></dt><dd><p>In the above, <code class="docutils literal notranslate"><span class="pre">out</span></code> is an <code class="docutils literal notranslate"><span class="pre">mx.array</span></code> of type <code class="docutils literal notranslate"><span class="pre">mx.float16</span></code>
<dt>The list of <code class="docutils literal notranslate"><span class="pre">output_dtypes</span></code></dt><dd><p>In the above, <code class="docutils literal notranslate"><span class="pre">out</span></code> is an <code class="docutils literal notranslate"><span class="pre">mx.array</span></code> of type <code class="docutils literal notranslate"><span class="pre">mx.float16</span></code>
so we add <code class="docutils literal notranslate"><span class="pre">device</span> <span class="pre">float16_t*</span> <span class="pre">out</span></code>.</p>
</dd>
</dl>
</li>
<li><dl class="simple">
<dt>Template parameters passed using <code class="docutils literal notranslate"><span class="pre">template</span></code></dt><dd><p>In the above, <code class="docutils literal notranslate"><span class="pre">template={&quot;T&quot;:</span> <span class="pre">mx.float32}</span></code> adds a template of <code class="docutils literal notranslate"><span class="pre">template</span> <span class="pre">&lt;typename</span> <span class="pre">T&gt;</span></code> to the function
<dt>Template parameters passed using <code class="docutils literal notranslate"><span class="pre">template</span></code></dt><dd><p>In the above, <code class="docutils literal notranslate"><span class="pre">template=[(&quot;T&quot;,</span> <span class="pre">mx.float32)]</span></code> adds a template of <code class="docutils literal notranslate"><span class="pre">template</span> <span class="pre">&lt;typename</span> <span class="pre">T&gt;</span></code> to the function
and instantiates the template with <code class="docutils literal notranslate"><span class="pre">custom_kernel_myexp_float&lt;float&gt;</span></code>.
Template parameters can be <code class="docutils literal notranslate"><span class="pre">mx.core.Dtype</span></code>, <code class="docutils literal notranslate"><span class="pre">int</span></code> or <code class="docutils literal notranslate"><span class="pre">bool</span></code>.</p>
</dd>
@@ -928,7 +945,7 @@ All the attributes defined in Table 5.8 of the <a class="reference external" hre
<span class="k">template</span><span class="w"> </span><span class="p">[[</span><span class="n">host_name</span><span class="p">(</span><span class="s">&quot;custom_kernel_myexp_float&quot;</span><span class="p">)]]</span><span class="w"> </span><span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="k">decltype</span><span class="p">(</span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">)</span><span class="w"> </span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">;</span>
</pre></div>
</div>
<p>You can print the generated code for a <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code> by passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> when you call it.</p>
<p>Passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel.__call__</span></code> will print the generated code for debugging purposes.</p>
</section>
<section id="using-shape-strides">
<h2>Using Shape/Strides<a class="headerlink" href="#using-shape-strides" title="Link to this heading">#</a></h2>
@@ -952,18 +969,20 @@ We can then use MLXs built in indexing utils to fetch the right elements for
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;myexp_strided&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;inp&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span>
<span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;inp&quot;</span><span class="p">:</span> <span class="n">a</span><span class="p">},</span>
<span class="n">template</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;T&quot;</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">float32</span><span class="p">},</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="p">],</span>
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">float32</span><span class="p">)],</span>
<span class="n">grid</span><span class="o">=</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">threadgroup</span><span class="o">=</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">output_shapes</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;out&quot;</span><span class="p">:</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">},</span>
<span class="n">output_dtypes</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;out&quot;</span><span class="p">:</span> <span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">},</span>
<span class="n">output_shapes</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">],</span>
<span class="n">output_dtypes</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span>
<span class="n">ensure_row_contiguous</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">outputs</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">]</span>
<span class="k">return</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">16</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="c1"># make non-contiguous</span>
@@ -973,6 +992,289 @@ We can then use MLXs built in indexing utils to fetch the right elements for
</pre></div>
</div>
</section>
<section id="complex-example">
<h2>Complex Example<a class="headerlink" href="#complex-example" title="Link to this heading">#</a></h2>
<p>Lets implement a more complex example: <code class="docutils literal notranslate"><span class="pre">grid_sample</span></code> in <code class="docutils literal notranslate"><span class="pre">&quot;bilinear&quot;</span></code> mode.</p>
<p>Well start with the following MLX implementation using standard ops:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">grid_sample_ref</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">):</span>
<span class="n">N</span><span class="p">,</span> <span class="n">H_in</span><span class="p">,</span> <span class="n">W_in</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ix</span> <span class="o">=</span> <span class="p">((</span><span class="n">grid</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
<span class="n">iy</span> <span class="o">=</span> <span class="p">((</span><span class="n">grid</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
<span class="n">ix_nw</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">ix</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">iy_nw</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">iy</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">ix_ne</span> <span class="o">=</span> <span class="n">ix_nw</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">iy_ne</span> <span class="o">=</span> <span class="n">iy_nw</span>
<span class="n">ix_sw</span> <span class="o">=</span> <span class="n">ix_nw</span>
<span class="n">iy_sw</span> <span class="o">=</span> <span class="n">iy_nw</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">ix_se</span> <span class="o">=</span> <span class="n">ix_nw</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">iy_se</span> <span class="o">=</span> <span class="n">iy_nw</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">nw</span> <span class="o">=</span> <span class="p">(</span><span class="n">ix_se</span> <span class="o">-</span> <span class="n">ix</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">iy_se</span> <span class="o">-</span> <span class="n">iy</span><span class="p">)</span>
<span class="n">ne</span> <span class="o">=</span> <span class="p">(</span><span class="n">ix</span> <span class="o">-</span> <span class="n">ix_sw</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">iy_sw</span> <span class="o">-</span> <span class="n">iy</span><span class="p">)</span>
<span class="n">sw</span> <span class="o">=</span> <span class="p">(</span><span class="n">ix_ne</span> <span class="o">-</span> <span class="n">ix</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">iy</span> <span class="o">-</span> <span class="n">iy_ne</span><span class="p">)</span>
<span class="n">se</span> <span class="o">=</span> <span class="p">(</span><span class="n">ix</span> <span class="o">-</span> <span class="n">ix_nw</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">iy</span> <span class="o">-</span> <span class="n">iy_nw</span><span class="p">)</span>
<span class="n">I_nw</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">N</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">iy_nw</span><span class="p">,</span> <span class="n">ix_nw</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">I_ne</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">N</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">iy_ne</span><span class="p">,</span> <span class="n">ix_ne</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">I_sw</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">N</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">iy_sw</span><span class="p">,</span> <span class="n">ix_sw</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">I_se</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">N</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">iy_se</span><span class="p">,</span> <span class="n">ix_se</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">mask_nw</span> <span class="o">=</span> <span class="p">(</span><span class="n">iy_nw</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">iy_nw</span> <span class="o">&lt;=</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_nw</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_nw</span> <span class="o">&lt;=</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">mask_ne</span> <span class="o">=</span> <span class="p">(</span><span class="n">iy_ne</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">iy_ne</span> <span class="o">&lt;=</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_ne</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_ne</span> <span class="o">&lt;=</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">mask_sw</span> <span class="o">=</span> <span class="p">(</span><span class="n">iy_sw</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">iy_sw</span> <span class="o">&lt;=</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_sw</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_sw</span> <span class="o">&lt;=</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">mask_se</span> <span class="o">=</span> <span class="p">(</span><span class="n">iy_se</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">iy_se</span> <span class="o">&lt;=</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_se</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_se</span> <span class="o">&lt;=</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">I_nw</span> <span class="o">*=</span> <span class="n">mask_nw</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">I_ne</span> <span class="o">*=</span> <span class="n">mask_ne</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">I_sw</span> <span class="o">*=</span> <span class="n">mask_sw</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">I_se</span> <span class="o">*=</span> <span class="n">mask_se</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">nw</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">I_nw</span> <span class="o">+</span> <span class="n">ne</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">I_ne</span> <span class="o">+</span> <span class="n">sw</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">I_sw</span> <span class="o">+</span> <span class="n">se</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">I_se</span>
<span class="k">return</span> <span class="n">output</span>
</pre></div>
</div>
<p>Now lets use <code class="docutils literal notranslate"><span class="pre">mx.custom_function</span></code> together with <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code>
to write a fast GPU kernel for both the forward and backward passes.</p>
<p>First well implement the forward pass as a fused kernel:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">custom_function</span>
<span class="k">def</span> <span class="nf">grid_sample</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">&quot;`x` must be 4D.&quot;</span>
<span class="k">assert</span> <span class="n">grid</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">&quot;`grid` must be 4D.&quot;</span>
<span class="n">B</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
<span class="n">_</span><span class="p">,</span> <span class="n">gN</span><span class="p">,</span> <span class="n">gM</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">grid</span><span class="o">.</span><span class="n">shape</span>
<span class="n">out_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">gN</span><span class="p">,</span> <span class="n">gM</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">D</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;Last dim of `grid` must be size 2.&quot;</span>
<span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> int H = x_shape[1];</span>
<span class="s2"> int W = x_shape[2];</span>
<span class="s2"> int C = x_shape[3];</span>
<span class="s2"> int gH = grid_shape[1];</span>
<span class="s2"> int gW = grid_shape[2];</span>
<span class="s2"> int w_stride = C;</span>
<span class="s2"> int h_stride = W * w_stride;</span>
<span class="s2"> int b_stride = H * h_stride;</span>
<span class="s2"> uint grid_idx = elem / C * 2;</span>
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
<span class="s2"> int ix_nw = floor(ix);</span>
<span class="s2"> int iy_nw = floor(iy);</span>
<span class="s2"> int ix_ne = ix_nw + 1;</span>
<span class="s2"> int iy_ne = iy_nw;</span>
<span class="s2"> int ix_sw = ix_nw;</span>
<span class="s2"> int iy_sw = iy_nw + 1;</span>
<span class="s2"> int ix_se = ix_nw + 1;</span>
<span class="s2"> int iy_se = iy_nw + 1;</span>
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
<span class="s2"> int batch_idx = elem / C / gH / gW * b_stride;</span>
<span class="s2"> int channel_idx = elem % C;</span>
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
<span class="s2"> T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];</span>
<span class="s2"> T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];</span>
<span class="s2"> T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];</span>
<span class="s2"> T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];</span>
<span class="s2"> I_nw = iy_nw &gt;= 0 &amp;&amp; iy_nw &lt;= H - 1 &amp;&amp; ix_nw &gt;= 0 &amp;&amp; ix_nw &lt;= W - 1 ? I_nw : 0;</span>
<span class="s2"> I_ne = iy_ne &gt;= 0 &amp;&amp; iy_ne &lt;= H - 1 &amp;&amp; ix_ne &gt;= 0 &amp;&amp; ix_ne &lt;= W - 1 ? I_ne : 0;</span>
<span class="s2"> I_sw = iy_sw &gt;= 0 &amp;&amp; iy_sw &lt;= H - 1 &amp;&amp; ix_sw &gt;= 0 &amp;&amp; ix_sw &lt;= W - 1 ? I_sw : 0;</span>
<span class="s2"> I_se = iy_se &gt;= 0 &amp;&amp; iy_se &lt;= H - 1 &amp;&amp; ix_se &gt;= 0 &amp;&amp; ix_se &lt;= W - 1 ? I_se : 0;</span>
<span class="s2"> out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;grid_sample&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="s2">&quot;grid&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">],</span>
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span>
<span class="n">output_shapes</span><span class="o">=</span><span class="p">[</span><span class="n">out_shape</span><span class="p">],</span>
<span class="n">output_dtypes</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span>
<span class="n">grid</span><span class="o">=</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">out_shape</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">threadgroup</span><span class="o">=</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</pre></div>
</div>
<p>For a reasonably sized input such as:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">x</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">64</span><span class="p">)</span>
<span class="n">grid</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
</pre></div>
</div>
<p>On an M1 Max, we see a big performance improvement:</p>
<p><code class="docutils literal notranslate"><span class="pre">55.7ms</span> <span class="pre">-&gt;</span> <span class="pre">6.7ms</span> <span class="pre">=&gt;</span> <span class="pre">8x</span> <span class="pre">speed</span> <span class="pre">up</span></code></p>
</section>
<section id="grid-sample-vjp">
<h2>Grid Sample VJP<a class="headerlink" href="#grid-sample-vjp" title="Link to this heading">#</a></h2>
<p>Since we decorated <code class="docutils literal notranslate"><span class="pre">grid_sample</span></code> with <code class="docutils literal notranslate"><span class="pre">mx.custom_function</span></code>, we can now define
its custom vjp transform so MLX can differentiate it.</p>
<p>The backwards pass requires atomically updating <code class="docutils literal notranslate"><span class="pre">x_grad</span></code>/<code class="docutils literal notranslate"><span class="pre">grid_grad</span></code> and so
requires a few extra <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code> features:</p>
<ul class="simple">
<li><dl class="simple">
<dt><code class="docutils literal notranslate"><span class="pre">init_value=0</span></code></dt><dd><p>Initialize all of the kernels outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.</p>
</dd>
</dl>
</li>
<li><dl class="simple">
<dt><code class="docutils literal notranslate"><span class="pre">atomic_outputs=True</span></code></dt><dd><p>Designate all of the kernel outputs as <code class="docutils literal notranslate"><span class="pre">atomic</span></code> in the function signature.
This means we can use Metals <code class="docutils literal notranslate"><span class="pre">atomic</span></code> features to simultaneously update the <code class="docutils literal notranslate"><span class="pre">x_grad</span></code> and <code class="docutils literal notranslate"><span class="pre">grid_grad</span></code> arrays from multiple threadgroups.
See section 6.15 of the <a class="reference external" href="https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf">Metal Shading Language Specification</a> for more details.</p>
</dd>
</dl>
</li>
</ul>
<p>We can then implement the backwards pass as follows:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@grid_sample</span><span class="o">.</span><span class="n">vjp</span>
<span class="k">def</span> <span class="nf">grid_sample_vjp</span><span class="p">(</span><span class="n">primals</span><span class="p">,</span> <span class="n">cotangent</span><span class="p">,</span> <span class="n">_</span><span class="p">):</span>
<span class="n">x</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">primals</span>
<span class="n">B</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
<span class="n">_</span><span class="p">,</span> <span class="n">gN</span><span class="p">,</span> <span class="n">gM</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">grid</span><span class="o">.</span><span class="n">shape</span>
<span class="k">assert</span> <span class="n">D</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;Last dim of `grid` must be size 2.&quot;</span>
<span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> int H = x_shape[1];</span>
<span class="s2"> int W = x_shape[2];</span>
<span class="s2"> int C = x_shape[3];</span>
<span class="s2"> // Pad C to the nearest larger simdgroup size multiple</span>
<span class="s2"> int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;</span>
<span class="s2"> int gH = grid_shape[1];</span>
<span class="s2"> int gW = grid_shape[2];</span>
<span class="s2"> int w_stride = C;</span>
<span class="s2"> int h_stride = W * w_stride;</span>
<span class="s2"> int b_stride = H * h_stride;</span>
<span class="s2"> uint grid_idx = elem / C_padded * 2;</span>
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
<span class="s2"> int ix_nw = floor(ix);</span>
<span class="s2"> int iy_nw = floor(iy);</span>
<span class="s2"> int ix_ne = ix_nw + 1;</span>
<span class="s2"> int iy_ne = iy_nw;</span>
<span class="s2"> int ix_sw = ix_nw;</span>
<span class="s2"> int iy_sw = iy_nw + 1;</span>
<span class="s2"> int ix_se = ix_nw + 1;</span>
<span class="s2"> int iy_se = iy_nw + 1;</span>
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
<span class="s2"> int batch_idx = elem / C_padded / gH / gW * b_stride;</span>
<span class="s2"> int channel_idx = elem % C_padded;</span>
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
<span class="s2"> T gix = T(0);</span>
<span class="s2"> T giy = T(0);</span>
<span class="s2"> if (channel_idx &lt; C) {</span>
<span class="s2"> int cot_index = elem / C_padded * C + channel_idx;</span>
<span class="s2"> T cot = cotangent[cot_index];</span>
<span class="s2"> if (iy_nw &gt;= 0 &amp;&amp; iy_nw &lt;= H - 1 &amp;&amp; ix_nw &gt;= 0 &amp;&amp; ix_nw &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], nw * cot, memory_order_relaxed);</span>
<span class="s2"> T I_nw = x[offset];</span>
<span class="s2"> gix -= I_nw * (iy_se - iy) * cot;</span>
<span class="s2"> giy -= I_nw * (ix_se - ix) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_ne &gt;= 0 &amp;&amp; iy_ne &lt;= H - 1 &amp;&amp; ix_ne &gt;= 0 &amp;&amp; ix_ne &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], ne * cot, memory_order_relaxed);</span>
<span class="s2"> T I_ne = x[offset];</span>
<span class="s2"> gix += I_ne * (iy_sw - iy) * cot;</span>
<span class="s2"> giy -= I_ne * (ix - ix_sw) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_sw &gt;= 0 &amp;&amp; iy_sw &lt;= H - 1 &amp;&amp; ix_sw &gt;= 0 &amp;&amp; ix_sw &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], sw * cot, memory_order_relaxed);</span>
<span class="s2"> T I_sw = x[offset];</span>
<span class="s2"> gix -= I_sw * (iy - iy_ne) * cot;</span>
<span class="s2"> giy += I_sw * (ix_ne - ix) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_se &gt;= 0 &amp;&amp; iy_se &lt;= H - 1 &amp;&amp; ix_se &gt;= 0 &amp;&amp; ix_se &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_se * h_stride + ix_se * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], se * cot, memory_order_relaxed);</span>
<span class="s2"> T I_se = x[offset];</span>
<span class="s2"> gix += I_se * (iy - iy_nw) * cot;</span>
<span class="s2"> giy += I_se * (ix - ix_nw) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> }</span>
<span class="s2"> T gix_mult = W / 2;</span>
<span class="s2"> T giy_mult = H / 2;</span>
<span class="s2"> // Reduce across each simdgroup first.</span>
<span class="s2"> // This is much faster than relying purely on atomics.</span>
<span class="s2"> gix = simd_sum(gix);</span>
<span class="s2"> giy = simd_sum(giy);</span>
<span class="s2"> if (thread_index_in_simdgroup == 0) {</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);</span>
<span class="s2"> }</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;grid_sample_grad&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="s2">&quot;grid&quot;</span><span class="p">,</span> <span class="s2">&quot;cotangent&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x_grad&quot;</span><span class="p">,</span> <span class="s2">&quot;grid_grad&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="n">atomic_outputs</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># pad the output channels to simd group size</span>
<span class="c1"># so that our `simd_sum`s don&#39;t overlap.</span>
<span class="n">simdgroup_size</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">C_padded</span> <span class="o">=</span> <span class="p">(</span><span class="n">C</span> <span class="o">+</span> <span class="n">simdgroup_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">simdgroup_size</span> <span class="o">*</span> <span class="n">simdgroup_size</span>
<span class="n">grid_size</span> <span class="o">=</span> <span class="n">B</span> <span class="o">*</span> <span class="n">gN</span> <span class="o">*</span> <span class="n">gM</span> <span class="o">*</span> <span class="n">C_padded</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">,</span> <span class="n">cotangent</span><span class="p">],</span>
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span>
<span class="n">output_shapes</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">grid</span><span class="o">.</span><span class="n">shape</span><span class="p">],</span>
<span class="n">output_dtypes</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span>
<span class="n">grid</span><span class="o">=</span><span class="p">(</span><span class="n">grid_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">threadgroup</span><span class="o">=</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">init_value</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
</pre></div>
</div>
<p>Theres an even larger speed up for the vjp:</p>
<p><code class="docutils literal notranslate"><span class="pre">676.4ms</span> <span class="pre">-&gt;</span> <span class="pre">16.7ms</span> <span class="pre">=&gt;</span> <span class="pre">40x</span> <span class="pre">speed</span> <span class="pre">up</span></code></p>
</section>
</section>
@@ -1012,6 +1314,8 @@ We can then use MLXs built in indexing utils to fetch the right elements for
<ul class="visible nav section-nav flex-column">
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#simple-example">Simple Example</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#using-shape-strides">Using Shape/Strides</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#complex-example">Complex Example</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#grid-sample-vjp">Grid Sample VJP</a></li>
</ul>
</nav></div>