CircleCI Docs
2024-11-05 19:54:16 +00:00
parent 0e14896450
commit 1b73eb0da4
51 changed files with 2277 additions and 1802 deletions

@@ -864,7 +864,7 @@
<article class="bd-article">
<section id="custom-metal-kernels">
<h1>Custom Metal Kernels<a class="headerlink" href="#custom-metal-kernels" title="Link to this heading">#</a></h1>
<span id="id1"></span><h1>Custom Metal Kernels<a class="headerlink" href="#custom-metal-kernels" title="Link to this heading">#</a></h1>
<p>MLX supports writing custom Metal kernels through the Python and C++ APIs.</p>
<section id="simple-example">
<h2>Simple Example<a class="headerlink" href="#simple-example" title="Link to this heading">#</a></h2>
@@ -947,6 +947,9 @@ All the attributes defined in Table 5.8 of the <a class="reference external" hre
<span class="k">template</span><span class="w"> </span><span class="p">[[</span><span class="n">host_name</span><span class="p">(</span><span class="s">&quot;custom_kernel_myexp_float&quot;</span><span class="p">)]]</span><span class="w"> </span><span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="k">decltype</span><span class="p">(</span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">)</span><span class="w"> </span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">;</span>
</pre></div>
</div>
<p>Note: <code class="docutils literal notranslate"><span class="pre">grid</span></code> and <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> are parameters to the Metal <a class="reference external" href="https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads">dispatchThreads</a> function.
This means we will launch <code class="docutils literal notranslate"><span class="pre">mx.prod(grid)</span></code> threads, subdivided into threadgroups of size <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code>.
For optimal performance, each threadgroup dimension should be less than or equal to the corresponding grid dimension.</p>
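<p>As a rough illustration of the note above, the launch geometry can be worked out with plain Python arithmetic. This is only a sketch: <code class="docutils literal notranslate"><span class="pre">launch_geometry</span></code> is a hypothetical helper, not part of MLX, and it assumes that Metal covers a grid that is not a multiple of the threadgroup size with smaller threadgroups at the edges.</p>

```python
import math


def launch_geometry(grid, threadgroup):
    """Hypothetical helper (not an MLX API): describe a dispatchThreads-style launch.

    grid        -- total number of threads per dimension, e.g. (1024, 1, 1)
    threadgroup -- threads per threadgroup per dimension, e.g. (256, 1, 1)
    """
    # Total threads launched is the product of the grid dimensions.
    total_threads = math.prod(grid)
    # Threadgroups needed per dimension, rounding up (integer ceiling division).
    # Assumption: a partial threadgroup at the edge covers any remainder.
    groups_per_dim = tuple(-(-g // t) for g, t in zip(grid, threadgroup))
    return total_threads, groups_per_dim


# A grid that divides evenly: 1024 threads in 4 threadgroups of 256.
print(launch_geometry((1024, 1, 1), (256, 1, 1)))

# A grid that does not divide evenly: 1000 threads, still 4 threadgroups,
# the last of which would hold only 1000 - 3 * 256 = 232 threads.
print(launch_geometry((1000, 1, 1), (256, 1, 1)))
```

<p>Choosing a threadgroup dimension larger than the matching grid dimension would make the helper report a single threadgroup that is mostly empty in that dimension, which is one way to read the sizing advice above.</p>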
<p>Passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel.__call__</span></code> will print the generated code for debugging purposes.</p>
</section>
<section id="using-shape-strides">