This commit is contained in:
CircleCI Docs
2024-11-22 20:24:16 +00:00
parent a84697024f
commit 379b7b4027
905 changed files with 30035 additions and 16934 deletions

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Custom Extensions in MLX &#8212; MLX 0.20.0 documentation</title>
<title>Custom Extensions in MLX &#8212; MLX 0.21.0 documentation</title>
@@ -39,7 +39,7 @@
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<script src="../_static/documentation_options.js?v=eb97cb82"></script>
<script src="../_static/documentation_options.js?v=174dfe6e"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -51,7 +51,7 @@
<link rel="prev" title="Operations" href="../cpp/ops.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.20.0" />
<meta name="docsearch:version" content="0.21.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.20.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.20.0 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.0 documentation - Home"/>
</a></div>
@@ -444,7 +444,6 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.layer_norm.html">mlx.core.fast.layer_norm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.rope.html">mlx.core.fast.rope</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html">mlx.core.fast.scaled_dot_product_attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.affine_quantize.html">mlx.core.fast.affine_quantize</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html">mlx.core.fast.metal_kernel</a></li>
</ul>
</details></li>
@@ -521,6 +520,7 @@
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ALiBi.html">mlx.nn.ALiBi</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool1d.html">mlx.nn.AvgPool1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool2d.html">mlx.nn.AvgPool2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool3d.html">mlx.nn.AvgPool3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.BatchNorm.html">mlx.nn.BatchNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.CELU.html">mlx.nn.CELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
@@ -550,6 +550,7 @@
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LSTM.html">mlx.nn.LSTM</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool1d.html">mlx.nn.MaxPool1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool2d.html">mlx.nn.MaxPool2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool3d.html">mlx.nn.MaxPool3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
@@ -1342,7 +1343,7 @@ below.</p>
<span class="w"> </span><span class="c1">// Prepare to encode kernel</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">compute_encoder</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">d</span><span class="p">.</span><span class="n">get_command_encoder</span><span class="p">(</span><span class="n">s</span><span class="p">.</span><span class="n">index</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-&gt;</span><span class="n">setComputePipelineState</span><span class="p">(</span><span class="n">kernel</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_compute_pipeline_state</span><span class="p">(</span><span class="n">kernel</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Kernel parameters are registered with buffer indices corresponding to</span>
<span class="w"> </span><span class="c1">// those in the kernel declaration at axpby.metal</span>
@@ -1357,14 +1358,14 @@ below.</p>
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_output_array</span><span class="p">(</span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Encode alpha and beta</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-&gt;</span><span class="n">setBytes</span><span class="p">(</span><span class="o">&amp;</span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">),</span><span class="w"> </span><span class="mi">3</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-&gt;</span><span class="n">setBytes</span><span class="p">(</span><span class="o">&amp;</span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">),</span><span class="w"> </span><span class="mi">4</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_bytes</span><span class="p">(</span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="mi">3</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_bytes</span><span class="p">(</span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="mi">4</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Encode shape, strides and ndim</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-&gt;</span><span class="n">setBytes</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">().</span><span class="n">data</span><span class="p">(),</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">int</span><span class="p">),</span><span class="w"> </span><span class="mi">5</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-&gt;</span><span class="n">setBytes</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">().</span><span class="n">data</span><span class="p">(),</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">size_t</span><span class="p">),</span><span class="w"> </span><span class="mi">6</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-&gt;</span><span class="n">setBytes</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">().</span><span class="n">data</span><span class="p">(),</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">size_t</span><span class="p">),</span><span class="w"> </span><span class="mi">7</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-&gt;</span><span class="n">setBytes</span><span class="p">(</span><span class="o">&amp;</span><span class="n">ndim</span><span class="p">,</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">int</span><span class="p">),</span><span class="w"> </span><span class="mi">8</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_vector_bytes</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="mi">5</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_vector_bytes</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span><span class="w"> </span><span class="mi">6</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_bytes</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span><span class="w"> </span><span class="mi">7</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_bytes</span><span class="p">(</span><span class="n">ndim</span><span class="p">,</span><span class="w"> </span><span class="mi">8</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// We launch 1 thread for each input and make sure that the number of</span>
<span class="w"> </span><span class="c1">// threads in any given threadgroup is not higher than the max allowed</span>
@@ -1378,7 +1379,7 @@ below.</p>
<span class="w"> </span><span class="c1">// Launch the grid with the given number of threads divided among</span>
<span class="w"> </span><span class="c1">// the given threadgroups</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">dispatchThreads</span><span class="p">(</span><span class="n">grid_dims</span><span class="p">,</span><span class="w"> </span><span class="n">group_dims</span><span class="p">);</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">dispatch_threads</span><span class="p">(</span><span class="n">grid_dims</span><span class="p">,</span><span class="w"> </span><span class="n">group_dims</span><span class="p">);</span>
<span class="p">}</span>
</pre></div>
</div>