This commit is contained in:
CircleCI Docs
2025-04-03 20:25:24 +00:00
parent e87bbad179
commit a4d0492a1d
528 changed files with 8356 additions and 2106 deletions

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Custom Extensions in MLX &#8212; MLX 0.24.1 documentation</title>
<title>Custom Extensions in MLX &#8212; MLX 0.24.2 documentation</title>
@@ -36,7 +36,7 @@
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/documentation_options.js?v=15e192ad"></script>
<script src="../_static/documentation_options.js?v=029512cb"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -137,8 +137,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.1 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.1 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.2 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.2 documentation - Home"/>`);</script>
</a></div>
@@ -291,10 +291,12 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_or.html">mlx.core.bitwise_or</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_xor.html">mlx.core.bitwise_xor</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.block_masked_mm.html">mlx.core.block_masked_mm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.broadcast_arrays.html">mlx.core.broadcast_arrays</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.broadcast_to.html">mlx.core.broadcast_to</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ceil.html">mlx.core.ceil</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.clip.html">mlx.core.clip</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.concatenate.html">mlx.core.concatenate</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.contiguous.html">mlx.core.contiguous</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conj.html">mlx.core.conj</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conjugate.html">mlx.core.conjugate</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.convolve.html">mlx.core.convolve</a></li>
@@ -453,6 +455,7 @@
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/transforms.html">Transforms</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eval.html">mlx.core.eval</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.async_eval.html">mlx.core.async_eval</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.compile.html">mlx.core.compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.custom_function.html">mlx.core.custom_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.disable_compile.html">mlx.core.disable_compile</a></li>
@@ -500,6 +503,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigh.html">mlx.core.linalg.eigh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.lu.html">mlx.core.linalg.lu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.lu_factor.html">mlx.core.linalg.lu_factor</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.pinv.html">mlx.core.linalg.pinv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.solve.html">mlx.core.linalg.solve</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.solve_triangular.html">mlx.core.linalg.solve_triangular</a></li>
</ul>
@@ -1008,9 +1012,9 @@ easy to use interface that use <code class="xref py py-class docutils literal no
<section id="primitives">
<h3>Primitives<a class="headerlink" href="#primitives" title="Link to this heading">#</a></h3>
<p>A <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> is part of the computation graph of an <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code>. It
defines how to create outputs arrays given a input arrays. Further, a
defines how to create output arrays given input arrays. Further, a
<code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> has methods to run on the CPU or GPU and for function
transformations such as <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code>. Lets go back to our example to be
transformations such as <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code>. Lets go back to our example to be
more concrete:</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Axpby</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="k">public</span><span class="w"> </span><span class="n">Primitive</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">public</span><span class="o">:</span>
@@ -1040,7 +1044,7 @@ more concrete:</p>
<span class="w"> </span><span class="cm">/** The vector-Jacobian product. */</span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;</span><span class="w"> </span><span class="n">vjp</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">cotangents</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">argnums</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
@@ -1360,7 +1364,7 @@ one we just defined:</p>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Forward mode diff that pushes along the tangents</span>
<span class="w"> </span><span class="c1">// The jvp transform on the primitive can built with ops</span>
<span class="w"> </span><span class="c1">// The jvp transform on the primitive can be built with ops</span>
<span class="w"> </span><span class="c1">// that are scheduled on the same stream as the primitive</span>
<span class="w"> </span><span class="c1">// If argnums = {0}, we only push along x in which case the</span>
@@ -1372,7 +1376,7 @@ one we just defined:</p>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dtype</span><span class="p">());</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="p">{</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">())};</span>
<span class="w"> </span><span class="p">}</span>
<span class="w"> </span><span class="c1">// If, argnums = {0, 1}, we take contributions from both</span>
<span class="w"> </span><span class="c1">// If argnums = {0, 1}, we take contributions from both</span>
<span class="w"> </span><span class="c1">// which gives us jvp = tangent_x * alpha + tangent_y * beta</span>
<span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="p">{</span><span class="n">axpby</span><span class="p">(</span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">())};</span>
@@ -1608,13 +1612,13 @@ import the Python package and play with it as you would any other MLX operation.
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;c shape: </span><span class="si">{</span><span class="n">c</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;c dtype: </span><span class="si">{</span><span class="n">c</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;c correct: </span><span class="si">{</span><span class="n">mx</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">c</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mf">6.0</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;c is correct: </span><span class="si">{</span><span class="n">mx</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">c</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mf">6.0</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</pre></div>
</div>
<p>Output:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">c</span> <span class="n">shape</span><span class="p">:</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span>
<span class="n">c</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">float32</span>
<span class="n">c</span> <span class="n">correctness</span><span class="p">:</span> <span class="kc">True</span>
<span class="n">c</span> <span class="ow">is</span> <span class="n">correct</span><span class="p">:</span> <span class="kc">True</span>
</pre></div>
</div>
<section id="results">