mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
rebase
This commit is contained in:
170
docs/build/html/usage/compile.html
vendored
170
docs/build/html/usage/compile.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
|
||||
<title>Compilation — MLX 0.21.1 documentation</title>
|
||||
<title>Compilation — MLX 0.22.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
|
||||
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
|
||||
</script>
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
-->
|
||||
<noscript>
|
||||
<style>
|
||||
@@ -27,19 +27,19 @@
|
||||
</noscript>
|
||||
|
||||
<!-- Loaded before other Sphinx assets -->
|
||||
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
|
||||
|
||||
<!-- So that users can add custom icons -->
|
||||
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
|
||||
<!-- Pre-loaded scripts that we'll load fully later -->
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
|
||||
|
||||
<script src="../_static/documentation_options.js?v=acb17c73"></script>
|
||||
<script src="../_static/documentation_options.js?v=c952a61b"></script>
|
||||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||||
@@ -51,7 +51,7 @@
|
||||
<link rel="prev" title="Function Transforms" href="function_transforms.html" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<meta name="docsearch:language" content="en"/>
|
||||
<meta name="docsearch:version" content="0.21.1" />
|
||||
<meta name="docsearch:version" content="0.22.0" />
|
||||
</head>
|
||||
|
||||
|
||||
@@ -130,8 +130,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -160,6 +160,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="numpy.html">Conversion to NumPy and Other Frameworks</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="distributed.html">Distributed Communication</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="using_streams.html">Using Streams</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
|
||||
</ul>
|
||||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
||||
<ul class="nav bd-sidenav">
|
||||
@@ -228,6 +229,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
@@ -242,6 +244,13 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
|
||||
@@ -324,6 +333,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
|
||||
@@ -379,6 +389,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
|
||||
@@ -403,6 +415,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
|
||||
@@ -695,6 +708,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
@@ -703,9 +717,14 @@
|
||||
|
||||
|
||||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||||
<div class="sidebar-primary-item">
|
||||
<div id="ethical-ad-placement"
|
||||
class="flat"
|
||||
data-ea-publisher="readthedocs"
|
||||
data-ea-type="readthedocs-sidebar"
|
||||
data-ea-manual="true">
|
||||
</div></div>
|
||||
</div>
|
||||
|
||||
<div id="rtd-footer-container"></div>
|
||||
|
||||
|
||||
</div>
|
||||
@@ -856,6 +875,7 @@
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#pure-functions">Pure Functions</a></li>
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#compiling-training-graphs">Compiling Training Graphs</a></li>
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#transformations-with-compile">Transformations with Compile</a></li>
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#shapeless-compilation">Shapeless Compilation</a></li>
|
||||
</ul>
|
||||
</nav>
|
||||
</div>
|
||||
@@ -878,7 +898,7 @@ that are good to be aware of for more complex graphs and advanced usage.</p>
|
||||
<section id="basics-of-compile">
|
||||
<h2>Basics of Compile<a class="headerlink" href="#basics-of-compile" title="Link to this heading">#</a></h2>
|
||||
<p>Let’s start with a simple example:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">y</span>
|
||||
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>
|
||||
@@ -902,7 +922,7 @@ graph, optimize it, and generate and compile code. This can be relatively
|
||||
slow. However, MLX will cache compiled functions, so calling a compiled
|
||||
function multiple times will not initiate a new compilation. This means you
|
||||
should typically compile functions that you plan to use more than once.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">y</span>
|
||||
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>
|
||||
@@ -946,7 +966,7 @@ function in a loop:</p>
|
||||
<p>The <a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu.html#mlx.nn.gelu" title="mlx.nn.gelu"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.nn.gelu()</span></code></a> is a nonlinear activation function commonly used with
|
||||
Transformer-based models. The implementation involves several unary and binary
|
||||
element-wise operations:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">mx</span><span class="o">.</span><span class="n">erf</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="p">)))</span> <span class="o">/</span> <span class="mi">2</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
@@ -957,9 +977,9 @@ the operations in the <code class="docutils literal notranslate"><span class="pr
|
||||
<p>Let’s compare the runtime of the regular function versus the compiled
|
||||
function. We’ll use the following timing helper which does a warm up and
|
||||
handles synchronization:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">time</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">timeit</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">timeit</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="c1"># warm up</span>
|
||||
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
|
||||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">fun</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
|
||||
@@ -987,7 +1007,7 @@ five times faster.</p>
|
||||
inputs. This means you can’t evaluate arrays (for example to print their
|
||||
contents) inside compiled functions.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
|
||||
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="o">-</span><span class="n">x</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="c1"># Crash</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||||
@@ -1000,7 +1020,7 @@ globally disable compilation using the <a class="reference internal" href="../py
|
||||
<code class="docutils literal notranslate"><span class="pre">MLX_DISABLE_COMPILE</span></code> flag. For example the following is okay even though
|
||||
<code class="docutils literal notranslate"><span class="pre">fun</span></code> is compiled:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
|
||||
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="o">-</span><span class="n">x</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="c1"># Okay</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||||
@@ -1017,7 +1037,7 @@ effects. For example:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
|
||||
<span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
|
||||
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||||
<span class="n">state</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||||
@@ -1035,7 +1055,7 @@ computation graph. Printing such an array results in a crash.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
|
||||
<span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
|
||||
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||||
<span class="n">state</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">),</span> <span class="n">state</span>
|
||||
@@ -1047,13 +1067,13 @@ computation graph. Printing such an array results in a crash.</p>
|
||||
</div>
|
||||
<p>In some cases returning updated state can be pretty inconvenient. Hence,
|
||||
<a class="reference internal" href="../python/_autosummary/mlx.core.compile.html#mlx.core.compile" title="mlx.core.compile"><code class="xref py py-func docutils literal notranslate"><span class="pre">compile()</span></code></a> has a parameter to capture implicit outputs:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">functools</span><span class="w"> </span><span class="kn">import</span> <span class="n">partial</span>
|
||||
|
||||
<span class="n">state</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
|
||||
<span class="c1"># Tell compile to capture state as an output</span>
|
||||
<span class="nd">@partial</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">state</span><span class="p">)</span>
|
||||
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||||
<span class="n">state</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">),</span> <span class="n">state</span>
|
||||
@@ -1071,7 +1091,7 @@ constants. For example:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)]</span>
|
||||
|
||||
<span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
|
||||
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">state</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># Prints array(2, dtype=float32)</span>
|
||||
@@ -1088,12 +1108,12 @@ constants. For example:</p>
|
||||
again have two options. The first option is to simply pass <code class="docutils literal notranslate"><span class="pre">state</span></code> as input
|
||||
to the function. In some cases this can be pretty inconvenient. Hence,
|
||||
<a class="reference internal" href="../python/_autosummary/mlx.core.compile.html#mlx.core.compile" title="mlx.core.compile"><code class="xref py py-func docutils literal notranslate"><span class="pre">compile()</span></code></a> also has a parameter to capture implicit inputs:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">functools</span><span class="w"> </span><span class="kn">import</span> <span class="n">partial</span>
|
||||
<span class="n">state</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)]</span>
|
||||
|
||||
<span class="c1"># Tell compile to capture state as an input</span>
|
||||
<span class="nd">@partial</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">,</span> <span class="n">inputs</span><span class="o">=</span><span class="n">state</span><span class="p">)</span>
|
||||
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">state</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># Prints array(2, dtype=float32)</span>
|
||||
@@ -1114,9 +1134,9 @@ of a common setup: training a model with <a class="reference internal" href="../
|
||||
<a class="reference internal" href="../python/optimizers/optimizer.html#mlx.optimizers.Optimizer" title="mlx.optimizers.Optimizer"><code class="xref py py-obj docutils literal notranslate"><span class="pre">mlx.optimizers.Optimizer</span></code></a> with state. We will show how to compile the
|
||||
full forward, backward, and update with <a class="reference internal" href="../python/_autosummary/mlx.core.compile.html#mlx.core.compile" title="mlx.core.compile"><code class="xref py py-func docutils literal notranslate"><span class="pre">compile()</span></code></a>.</p>
|
||||
<p>To start, here is the simple example without any compilation:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
<span class="kn">import</span> <span class="nn">mlx.optimizers</span> <span class="k">as</span> <span class="nn">optim</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.optimizers</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">optim</span>
|
||||
|
||||
<span class="c1"># 4 examples with 10 features each</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
|
||||
@@ -1130,7 +1150,7 @@ full forward, backward, and update with <a class="reference internal" href="../p
|
||||
<span class="c1"># SGD with momentum</span>
|
||||
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
<span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
|
||||
@@ -1145,10 +1165,10 @@ full forward, backward, and update with <a class="reference internal" href="../p
|
||||
</div>
|
||||
<p>To compile the update we can put it all in a function and compile it with the
|
||||
appropriate input and output captures. Here’s the same example but compiled:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
<span class="kn">import</span> <span class="nn">mlx.optimizers</span> <span class="k">as</span> <span class="nn">optim</span>
|
||||
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.optimizers</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">optim</span>
|
||||
<span class="kn">from</span><span class="w"> </span><span class="nn">functools</span><span class="w"> </span><span class="kn">import</span> <span class="n">partial</span>
|
||||
|
||||
<span class="c1"># 4 examples with 10 features each</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
|
||||
@@ -1162,7 +1182,7 @@ appropriate input and output captures. Here’s the same example but compiled:</
|
||||
<span class="c1"># SGD with momentum</span>
|
||||
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
<span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
|
||||
@@ -1170,7 +1190,7 @@ appropriate input and output captures. Here’s the same example but compiled:</
|
||||
<span class="n">state</span> <span class="o">=</span> <span class="p">[</span><span class="n">model</span><span class="o">.</span><span class="n">state</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state</span><span class="p">]</span>
|
||||
|
||||
<span class="nd">@partial</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">,</span> <span class="n">inputs</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">state</span><span class="p">)</span>
|
||||
<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">step</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="n">loss_and_grad_fn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">)</span>
|
||||
<span class="n">loss</span><span class="p">,</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">loss_and_grad_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
<span class="n">optimizer</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span>
|
||||
@@ -1224,10 +1244,10 @@ function simply pass it through <a class="reference internal" href="../python/_a
|
||||
good practice is to compile the outer most function to give <a class="reference internal" href="../python/_autosummary/mlx.core.compile.html#mlx.core.compile" title="mlx.core.compile"><code class="xref py py-func docutils literal notranslate"><span class="pre">compile()</span></code></a>
|
||||
the most opportunity to optimize the computation graph:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
|
||||
<span class="k">def</span> <span class="nf">inner</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">inner</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">mx</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">outer</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">outer</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="n">inner</span><span class="p">(</span><span class="n">inner</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># Compiling the outer function is good to do as it will likely</span>
|
||||
@@ -1236,6 +1256,69 @@ the most opportunity to optimize the computation graph:</p>
|
||||
</pre></div>
|
||||
</div>
|
||||
</section>
|
||||
<section id="shapeless-compilation">
|
||||
<span id="shapeless-compile"></span><h2>Shapeless Compilation<a class="headerlink" href="#shapeless-compilation" title="Link to this heading">#</a></h2>
|
||||
<p>When the shape of an input to a compiled function changes, the function is
|
||||
recompiled. You can compile a function once and run it on inputs with
|
||||
variable shapes by specifying <code class="docutils literal notranslate"><span class="pre">shapeless=True</span></code> to <a class="reference internal" href="../python/_autosummary/mlx.core.compile.html#mlx.core.compile" title="mlx.core.compile"><code class="xref py py-func docutils literal notranslate"><span class="pre">compile()</span></code></a>. In this
|
||||
case changes to the shapes of the inputs do not cause the function to be
|
||||
recompiled.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span>
|
||||
|
||||
<span class="n">compiled_fun</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">shapeless</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="o">-</span><span class="mf">2.0</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Firt call compiles the function</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># Second call with different shapes</span>
|
||||
<span class="c1"># does not recompile the function</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">1.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">6.0</span><span class="p">])</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mf">2.0</span><span class="p">,</span> <span class="mf">3.0</span><span class="p">])</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Use shapeless compilations carefully. Since compilation is not triggered when
|
||||
shapes change, any graphs which are conditional on the input shapes will not
|
||||
work as expected. Shape-dependent computations are common and sometimes subtle
|
||||
to detect. For example:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="n">compiled_fun</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">shapeless</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
|
||||
|
||||
<span class="n">out</span> <span class="o">=</span> <span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># Error, can't reshape (5, 5, 3) to (6, -1)</span>
|
||||
<span class="n">out</span> <span class="o">=</span> <span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>The second call to the <code class="docutils literal notranslate"><span class="pre">compiled_fun</span></code> fails because of the call to
|
||||
<a class="reference internal" href="../python/_autosummary/mlx.core.reshape.html#mlx.core.reshape" title="mlx.core.reshape"><code class="xref py py-func docutils literal notranslate"><span class="pre">reshape()</span></code></a> which uses the static shape of <code class="docutils literal notranslate"><span class="pre">x</span></code> in the first call. We can
|
||||
fix this by using <a class="reference internal" href="../python/_autosummary/mlx.core.flatten.html#mlx.core.flatten" title="mlx.core.flatten"><code class="xref py py-func docutils literal notranslate"><span class="pre">flatten()</span></code></a> to avoid hardcoding the shape of <code class="docutils literal notranslate"><span class="pre">x</span></code>:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="n">compiled_fun</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">shapeless</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
|
||||
|
||||
<span class="n">out</span> <span class="o">=</span> <span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># Ok</span>
|
||||
<span class="n">out</span> <span class="o">=</span> <span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</section>
|
||||
</section>
|
||||
|
||||
|
||||
@@ -1290,6 +1373,7 @@ the most opportunity to optimize the computation graph:</p>
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#pure-functions">Pure Functions</a></li>
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#compiling-training-graphs">Compiling Training Graphs</a></li>
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#transformations-with-compile">Transformations with Compile</a></li>
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#shapeless-compilation">Shapeless Compilation</a></li>
|
||||
</ul>
|
||||
</nav></div>
|
||||
|
||||
@@ -1338,8 +1422,8 @@ By MLX Contributors
|
||||
</div>
|
||||
|
||||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
|
||||
|
||||
<footer class="bd-footer">
|
||||
</footer>
|
||||
|
Reference in New Issue
Block a user