mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
rebase
This commit is contained in:
26
docs/build/html/dev/extensions.html
vendored
26
docs/build/html/dev/extensions.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
|
||||
<title>Custom Extensions in MLX — MLX 0.24.1 documentation</title>
|
||||
<title>Custom Extensions in MLX — MLX 0.24.2 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
|
||||
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
|
||||
|
||||
<script src="../_static/documentation_options.js?v=15e192ad"></script>
|
||||
<script src="../_static/documentation_options.js?v=029512cb"></script>
|
||||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||||
@@ -137,8 +137,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.1 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.1 documentation - Home"/>`);</script>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.2 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.2 documentation - Home"/>`);</script>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -291,10 +291,12 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_or.html">mlx.core.bitwise_or</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_xor.html">mlx.core.bitwise_xor</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.block_masked_mm.html">mlx.core.block_masked_mm</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.broadcast_arrays.html">mlx.core.broadcast_arrays</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.broadcast_to.html">mlx.core.broadcast_to</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ceil.html">mlx.core.ceil</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.clip.html">mlx.core.clip</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.concatenate.html">mlx.core.concatenate</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.contiguous.html">mlx.core.contiguous</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conj.html">mlx.core.conj</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conjugate.html">mlx.core.conjugate</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.convolve.html">mlx.core.convolve</a></li>
|
||||
@@ -453,6 +455,7 @@
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/transforms.html">Transforms</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eval.html">mlx.core.eval</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.async_eval.html">mlx.core.async_eval</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.compile.html">mlx.core.compile</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.custom_function.html">mlx.core.custom_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.disable_compile.html">mlx.core.disable_compile</a></li>
|
||||
@@ -500,6 +503,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigh.html">mlx.core.linalg.eigh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.lu.html">mlx.core.linalg.lu</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.lu_factor.html">mlx.core.linalg.lu_factor</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.pinv.html">mlx.core.linalg.pinv</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.solve.html">mlx.core.linalg.solve</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.solve_triangular.html">mlx.core.linalg.solve_triangular</a></li>
|
||||
</ul>
|
||||
@@ -1008,9 +1012,9 @@ easy to use interface that use <code class="xref py py-class docutils literal no
|
||||
<section id="primitives">
|
||||
<h3>Primitives<a class="headerlink" href="#primitives" title="Link to this heading">#</a></h3>
|
||||
<p>A <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> is part of the computation graph of an <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code>. It
|
||||
defines how to create outputs arrays given a input arrays. Further, a
|
||||
defines how to create output arrays given input arrays. Further, a
|
||||
<code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> has methods to run on the CPU or GPU and for function
|
||||
transformations such as <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code>. Lets go back to our example to be
|
||||
transformations such as <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code>. Let’s go back to our example to be
|
||||
more concrete:</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Axpby</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="k">public</span><span class="w"> </span><span class="n">Primitive</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">public</span><span class="o">:</span>
|
||||
@@ -1040,7 +1044,7 @@ more concrete:</p>
|
||||
<span class="w"> </span><span class="cm">/** The vector-Jacobian product. */</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">vjp</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">cotangents</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
|
||||
@@ -1360,7 +1364,7 @@ one we just defined:</p>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="c1">// Forward mode diff that pushes along the tangents</span>
|
||||
<span class="w"> </span><span class="c1">// The jvp transform on the primitive can built with ops</span>
|
||||
<span class="w"> </span><span class="c1">// The jvp transform on the primitive can be built with ops</span>
|
||||
<span class="w"> </span><span class="c1">// that are scheduled on the same stream as the primitive</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// If argnums = {0}, we only push along x in which case the</span>
|
||||
@@ -1372,7 +1376,7 @@ one we just defined:</p>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dtype</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="p">{</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">())};</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="w"> </span><span class="c1">// If, argnums = {0, 1}, we take contributions from both</span>
|
||||
<span class="w"> </span><span class="c1">// If argnums = {0, 1}, we take contributions from both</span>
|
||||
<span class="w"> </span><span class="c1">// which gives us jvp = tangent_x * alpha + tangent_y * beta</span>
|
||||
<span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="p">{</span><span class="n">axpby</span><span class="p">(</span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">())};</span>
|
||||
@@ -1608,13 +1612,13 @@ import the Python package and play with it as you would any other MLX operation.
|
||||
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c shape: </span><span class="si">{</span><span class="n">c</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c dtype: </span><span class="si">{</span><span class="n">c</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c correct: </span><span class="si">{</span><span class="n">mx</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">c</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mf">6.0</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c is correct: </span><span class="si">{</span><span class="n">mx</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">c</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mf">6.0</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Output:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">c</span> <span class="n">shape</span><span class="p">:</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span>
|
||||
<span class="n">c</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">float32</span>
|
||||
<span class="n">c</span> <span class="n">correctness</span><span class="p">:</span> <span class="kc">True</span>
|
||||
<span class="n">c</span> <span class="ow">is</span> <span class="n">correct</span><span class="p">:</span> <span class="kc">True</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<section id="results">
|
||||
|
Reference in New Issue
Block a user