mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
rebase
This commit is contained in:
@@ -864,7 +864,7 @@
<article class="bd-article">
<section id="custom-metal-kernels">
<h1>Custom Metal Kernels<a class="headerlink" href="#custom-metal-kernels" title="Link to this heading">#</a></h1>
<span id="id1"></span><h1>Custom Metal Kernels<a class="headerlink" href="#custom-metal-kernels" title="Link to this heading">#</a></h1>
<p>MLX supports writing custom Metal kernels through the Python and C++ APIs.</p>
<section id="simple-example">
<h2>Simple Example<a class="headerlink" href="#simple-example" title="Link to this heading">#</a></h2>
@@ -947,6 +947,9 @@ All the attributes defined in Table 5.8 of the <a class="reference external" hre
<span class="k">template</span><span class="w"> </span><span class="p">[[</span><span class="n">host_name</span><span class="p">(</span><span class="s">"custom_kernel_myexp_float"</span><span class="p">)]]</span><span class="w"> </span><span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="k">decltype</span><span class="p">(</span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">)</span><span class="w"> </span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">;</span>
</pre></div>
</div>
<p>Note: <code class="docutils literal notranslate"><span class="pre">grid</span></code> and <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> are parameters to the Metal <a class="reference external" href="https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads">dispatchThreads</a> function.
This means we will launch <code class="docutils literal notranslate"><span class="pre">mx.prod(grid)</span></code> threads, subdivided into threadgroups of size <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code>.
For optimal performance, each threadgroup dimension should be less than or equal to the corresponding grid dimension.</p>
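<p>The thread arithmetic in the note above can be sketched in plain Python. This is only an illustration under stated assumptions: <code class="docutils literal notranslate"><span class="pre">launch_shape</span></code> is a hypothetical helper, not part of MLX or Metal, and the per-dimension ceiling assumes that partial threadgroups are created at the grid edges when a grid dimension is not a multiple of the threadgroup dimension.</p>

```python
import math


def launch_shape(grid, threadgroup):
    """Hypothetical helper (not an MLX API): summarize a dispatchThreads-style launch.

    Returns the total number of threads and the number of threadgroups
    along each of the three dimensions.
    """
    if len(grid) != 3 or len(threadgroup) != 3:
        raise ValueError("grid and threadgroup must each have 3 dimensions")
    if any(d < 1 for d in (*grid, *threadgroup)):
        raise ValueError("all dimensions must be positive integers")

    # Total threads launched: the product of the grid dimensions.
    total_threads = math.prod(grid)

    # Assumption: edge threadgroups may be partial, so round up per dimension.
    # -(-g // t) is an integer ceiling division.
    groups_per_dim = tuple(-(-g // t) for g, t in zip(grid, threadgroup))
    return total_threads, groups_per_dim


# A 1D launch that divides evenly into threadgroups of 256 threads.
print(launch_shape((4096, 1, 1), (256, 1, 1)))

# A launch where the first grid dimension is not a multiple of 256.
print(launch_shape((1000, 3, 1), (256, 1, 1)))
```

<p>In both calls every threadgroup dimension is less than or equal to the corresponding grid dimension, which matches the recommendation in the note.</p>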
<p>Passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel.__call__</span></code> will print the generated code for debugging purposes.</p>
</section>
<section id="using-shape-strides">