This commit is contained in:
CircleCI Docs
2025-01-09 21:56:20 +00:00
parent 04b749a588
commit d8d647015b
2642 changed files with 137687 additions and 70861 deletions

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Compilation &#8212; MLX 0.21.1 documentation</title>
<title>Compilation &#8212; MLX 0.22.0 documentation</title>
@@ -16,8 +16,8 @@
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this give us a css class that will be invisible only if js is disabled
<!--
this give us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
@@ -27,19 +27,19 @@
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
<!-- So that users can add custom icons -->
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../_static/documentation_options.js?v=acb17c73"></script>
<script src="../_static/documentation_options.js?v=c952a61b"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -51,7 +51,7 @@
<link rel="prev" title="Function Transforms" href="function_transforms.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.21.1" />
<meta name="docsearch:version" content="0.22.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
</a></div>
@@ -160,6 +160,7 @@
<li class="toctree-l1"><a class="reference internal" href="numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="distributed.html">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="using_streams.html">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
@@ -228,6 +229,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
@@ -242,6 +244,13 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
@@ -324,6 +333,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
@@ -379,6 +389,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
@@ -403,6 +415,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
@@ -695,6 +708,7 @@
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
@@ -703,9 +717,14 @@
<div class="sidebar-primary-items__end sidebar-primary__section">
<div class="sidebar-primary-item">
<div id="ethical-ad-placement"
class="flat"
data-ea-publisher="readthedocs"
data-ea-type="readthedocs-sidebar"
data-ea-manual="true">
</div></div>
</div>
<div id="rtd-footer-container"></div>
</div>
@@ -856,6 +875,7 @@
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#pure-functions">Pure Functions</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#compiling-training-graphs">Compiling Training Graphs</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#transformations-with-compile">Transformations with Compile</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#shapeless-compilation">Shapeless Compilation</a></li>
</ul>
</nav>
</div>
@@ -878,7 +898,7 @@ that are good to be aware of for more complex graphs and advanced usage.</p>
<section id="basics-of-compile">
<h2>Basics of Compile<a class="headerlink" href="#basics-of-compile" title="Link to this heading">#</a></h2>
<p>Lets start with a simple example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>
@@ -902,7 +922,7 @@ graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>
@@ -946,7 +966,7 @@ function in a loop:</p>
<p>The <a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu.html#mlx.nn.gelu" title="mlx.nn.gelu"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.nn.gelu()</span></code></a> is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">mx</span><span class="o">.</span><span class="n">erf</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="p">)))</span> <span class="o">/</span> <span class="mi">2</span>
</pre></div>
</div>
@@ -957,9 +977,9 @@ the operations in the <code class="docutils literal notranslate"><span class="pr
<p>Lets compare the runtime of the regular function versus the compiled
function. Well use the following timing helper which does a warm up and
handles synchronization:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">time</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
<span class="k">def</span> <span class="nf">timeit</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">timeit</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="c1"># warm up</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">fun</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
@@ -987,7 +1007,7 @@ five times faster.</p>
inputs. This means you cant evaluate arrays (for example to print their
contents) inside compiled functions.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="n">z</span> <span class="o">=</span> <span class="o">-</span><span class="n">x</span>
<span class="nb">print</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="c1"># Crash</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
@@ -1000,7 +1020,7 @@ globally disable compilation using the <a class="reference internal" href="../py
<code class="docutils literal notranslate"><span class="pre">MLX_DISABLE_COMPILE</span></code> flag. For example the following is okay even though
<code class="docutils literal notranslate"><span class="pre">fun</span></code> is compiled:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="n">z</span> <span class="o">=</span> <span class="o">-</span><span class="n">x</span>
<span class="nb">print</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="c1"># Okay</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
@@ -1017,7 +1037,7 @@ effects. For example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="p">[]</span>
<span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">state</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
@@ -1035,7 +1055,7 @@ computation graph. Printing such an array results in a crash.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="p">[]</span>
<span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">state</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">),</span> <span class="n">state</span>
@@ -1047,13 +1067,13 @@ computation graph. Printing such an array results in a crash.</p>
</div>
<p>In some cases returning updated state can be pretty inconvenient. Hence,
<a class="reference internal" href="../python/_autosummary/mlx.core.compile.html#mlx.core.compile" title="mlx.core.compile"><code class="xref py py-func docutils literal notranslate"><span class="pre">compile()</span></code></a> has a parameter to capture implicit outputs:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">functools</span><span class="w"> </span><span class="kn">import</span> <span class="n">partial</span>
<span class="n">state</span> <span class="o">=</span> <span class="p">[]</span>
<span class="c1"># Tell compile to capture state as an output</span>
<span class="nd">@partial</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">state</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">state</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">),</span> <span class="n">state</span>
@@ -1071,7 +1091,7 @@ constants. For example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)]</span>
<span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">state</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># Prints array(2, dtype=float32)</span>
@@ -1088,12 +1108,12 @@ constants. For example:</p>
again have two options. The first option is to simply pass <code class="docutils literal notranslate"><span class="pre">state</span></code> as input
to the function. In some cases this can be pretty inconvenient. Hence,
<a class="reference internal" href="../python/_autosummary/mlx.core.compile.html#mlx.core.compile" title="mlx.core.compile"><code class="xref py py-func docutils literal notranslate"><span class="pre">compile()</span></code></a> also has a parameter to capture implicit inputs:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">functools</span><span class="w"> </span><span class="kn">import</span> <span class="n">partial</span>
<span class="n">state</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)]</span>
<span class="c1"># Tell compile to capture state as an input</span>
<span class="nd">@partial</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">,</span> <span class="n">inputs</span><span class="o">=</span><span class="n">state</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">state</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># Prints array(2, dtype=float32)</span>
@@ -1114,9 +1134,9 @@ of a common setup: training a model with <a class="reference internal" href="../
<a class="reference internal" href="../python/optimizers/optimizer.html#mlx.optimizers.Optimizer" title="mlx.optimizers.Optimizer"><code class="xref py py-obj docutils literal notranslate"><span class="pre">mlx.optimizers.Optimizer</span></code></a> with state. We will show how to compile the
full forward, backward, and update with <a class="reference internal" href="../python/_autosummary/mlx.core.compile.html#mlx.core.compile" title="mlx.core.compile"><code class="xref py py-func docutils literal notranslate"><span class="pre">compile()</span></code></a>.</p>
<p>To start, here is the simple example without any compilation:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="kn">import</span> <span class="nn">mlx.optimizers</span> <span class="k">as</span> <span class="nn">optim</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.optimizers</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">optim</span>
<span class="c1"># 4 examples with 10 features each</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
@@ -1130,7 +1150,7 @@ full forward, backward, and update with <a class="reference internal" href="../p
<span class="c1"># SGD with momentum</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
@@ -1145,10 +1165,10 @@ full forward, backward, and update with <a class="reference internal" href="../p
</div>
<p>To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Heres the same example but compiled:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="kn">import</span> <span class="nn">mlx.optimizers</span> <span class="k">as</span> <span class="nn">optim</span>
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.optimizers</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">optim</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">functools</span><span class="w"> </span><span class="kn">import</span> <span class="n">partial</span>
<span class="c1"># 4 examples with 10 features each</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
@@ -1162,7 +1182,7 @@ appropriate input and output captures. Heres the same example but compiled:</
<span class="c1"># SGD with momentum</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
@@ -1170,7 +1190,7 @@ appropriate input and output captures. Heres the same example but compiled:</
<span class="n">state</span> <span class="o">=</span> <span class="p">[</span><span class="n">model</span><span class="o">.</span><span class="n">state</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state</span><span class="p">]</span>
<span class="nd">@partial</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">,</span> <span class="n">inputs</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">state</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">step</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="n">loss_and_grad_fn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">)</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">loss_and_grad_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span>
@@ -1224,10 +1244,10 @@ function simply pass it through <a class="reference internal" href="../python/_a
good practice is to compile the outer most function to give <a class="reference internal" href="../python/_autosummary/mlx.core.compile.html#mlx.core.compile" title="mlx.core.compile"><code class="xref py py-func docutils literal notranslate"><span class="pre">compile()</span></code></a>
the most opportunity to optimize the computation graph:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">compile</span>
<span class="k">def</span> <span class="nf">inner</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">inner</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">mx</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">outer</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">outer</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="n">inner</span><span class="p">(</span><span class="n">inner</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="c1"># Compiling the outer function is good to do as it will likely</span>
@@ -1236,6 +1256,69 @@ the most opportunity to optimize the computation graph:</p>
</pre></div>
</div>
</section>
<section id="shapeless-compilation">
<span id="shapeless-compile"></span><h2>Shapeless Compilation<a class="headerlink" href="#shapeless-compilation" title="Link to this heading">#</a></h2>
<p>When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying <code class="docutils literal notranslate"><span class="pre">shapeless=True</span></code> to <a class="reference internal" href="../python/_autosummary/mlx.core.compile.html#mlx.core.compile" title="mlx.core.compile"><code class="xref py py-func docutils literal notranslate"><span class="pre">compile()</span></code></a>. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span>
<span class="n">compiled_fun</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">shapeless</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="o">-</span><span class="mf">2.0</span><span class="p">)</span>
<span class="c1"># Firt call compiles the function</span>
<span class="nb">print</span><span class="p">(</span><span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
<span class="c1"># Second call with different shapes</span>
<span class="c1"># does not recompile the function</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">1.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">6.0</span><span class="p">])</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mf">2.0</span><span class="p">,</span> <span class="mf">3.0</span><span class="p">])</span>
<span class="nb">print</span><span class="p">(</span><span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
</pre></div>
</div>
<p>Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">compiled_fun</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">shapeless</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="c1"># Error, can&#39;t reshape (5, 5, 3) to (6, -1)</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
<p>The second call to the <code class="docutils literal notranslate"><span class="pre">compiled_fun</span></code> fails because of the call to
<a class="reference internal" href="../python/_autosummary/mlx.core.reshape.html#mlx.core.reshape" title="mlx.core.reshape"><code class="xref py py-func docutils literal notranslate"><span class="pre">reshape()</span></code></a> which uses the static shape of <code class="docutils literal notranslate"><span class="pre">x</span></code> in the first call. We can
fix this by using <a class="reference internal" href="../python/_autosummary/mlx.core.flatten.html#mlx.core.flatten" title="mlx.core.flatten"><code class="xref py py-func docutils literal notranslate"><span class="pre">flatten()</span></code></a> to avoid hardcoding the shape of <code class="docutils literal notranslate"><span class="pre">x</span></code>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">compiled_fun</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">shapeless</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="c1"># Ok</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">compiled_fun</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
</section>
</section>
@@ -1290,6 +1373,7 @@ the most opportunity to optimize the computation graph:</p>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#pure-functions">Pure Functions</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#compiling-training-graphs">Compiling Training Graphs</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#transformations-with-compile">Transformations with Compile</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#shapeless-compilation">Shapeless Compilation</a></li>
</ul>
</nav></div>
@@ -1338,8 +1422,8 @@ By MLX Contributors
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
</footer>

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Distributed Communication &#8212; MLX 0.21.1 documentation</title>
<title>Distributed Communication &#8212; MLX 0.22.0 documentation</title>
@@ -16,8 +16,8 @@
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this give us a css class that will be invisible only if js is disabled
<!--
this give us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
@@ -27,19 +27,19 @@
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
<!-- So that users can add custom icons -->
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../_static/documentation_options.js?v=acb17c73"></script>
<script src="../_static/documentation_options.js?v=c952a61b"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -51,7 +51,7 @@
<link rel="prev" title="Conversion to NumPy and Other Frameworks" href="numpy.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.21.1" />
<meta name="docsearch:version" content="0.22.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
</a></div>
@@ -160,6 +160,7 @@
<li class="toctree-l1"><a class="reference internal" href="numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1 current active"><a class="current reference internal" href="#">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="using_streams.html">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
@@ -228,6 +229,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
@@ -242,6 +244,13 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
@@ -324,6 +333,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
@@ -379,6 +389,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
@@ -403,6 +415,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
@@ -695,6 +708,7 @@
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
@@ -703,9 +717,14 @@
<div class="sidebar-primary-items__end sidebar-primary__section">
<div class="sidebar-primary-item">
<div id="ethical-ad-placement"
class="flat"
data-ea-publisher="readthedocs"
data-ea-type="readthedocs-sidebar"
data-ea-manual="true">
</div></div>
</div>
<div id="rtd-footer-container"></div>
</div>
@@ -882,7 +901,7 @@ best way to do distributed computing on Macs using MLX.</p>
<h2>Getting Started<a class="headerlink" href="#getting-started" title="Link to this heading">#</a></h2>
<p>MLX already comes with the ability to “talk” to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="n">world</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">()</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_sum</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">10</span><span class="p">))</span>
@@ -961,7 +980,7 @@ dataset and optimizer initialization.</p>
<span class="n">optimizer</span> <span class="o">=</span> <span class="o">...</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="o">...</span>
<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">step</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">loss_grad_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span>
<span class="k">return</span> <span class="n">loss</span>
@@ -974,23 +993,24 @@ dataset and optimizer initialization.</p>
<p>All we have to do to average the gradients across machines is perform an
<a class="reference internal" href="../python/_autosummary/mlx.core.distributed.all_sum.html#mlx.core.distributed.all_sum" title="mlx.core.distributed.all_sum"><code class="xref py py-func docutils literal notranslate"><span class="pre">all_sum()</span></code></a> and divide by the size of the <a class="reference internal" href="../python/_autosummary/mlx.core.distributed.Group.html#mlx.core.distributed.Group" title="mlx.core.distributed.Group"><code class="xref py py-class docutils literal notranslate"><span class="pre">Group</span></code></a>. Namely we
have to <a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map.html#mlx.utils.tree_map" title="mlx.utils.tree_map"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.utils.tree_map()</span></code></a> the gradients with following function.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">all_avg</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">all_avg</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_sum</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">()</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
</pre></div>
</div>
<p>Putting everything together our training loop step looks as follows with
everything else remaining the same.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">mlx.utils</span> <span class="kn">import</span> <span class="n">tree_map</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">mlx.utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">tree_map</span>
<span class="k">def</span> <span class="nf">all_reduce_grads</span><span class="p">(</span><span class="n">grads</span><span class="p">):</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">()</span>
<span class="k">def</span><span class="w"> </span><span class="nf">all_reduce_grads</span><span class="p">(</span><span class="n">grads</span><span class="p">):</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">()</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="n">grads</span>
<span class="k">return</span> <span class="n">tree_map</span><span class="p">(</span>
<span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_sum</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span><span class="p">,</span>
<span class="n">grads</span><span class="p">)</span>
<span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_sum</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span><span class="p">,</span>
<span class="n">grads</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">step</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">loss_grad_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="n">grads</span> <span class="o">=</span> <span class="n">all_reduce_grads</span><span class="p">(</span><span class="n">grads</span><span class="p">)</span> <span class="c1"># &lt;--- This line was added</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span>
@@ -1110,8 +1130,8 @@ By MLX Contributors
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
</footer>

1243
docs/build/html/usage/export.html vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Function Transforms &#8212; MLX 0.21.1 documentation</title>
<title>Function Transforms &#8212; MLX 0.22.0 documentation</title>
@@ -16,8 +16,8 @@
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this give us a css class that will be invisible only if js is disabled
<!--
this give us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
@@ -27,19 +27,19 @@
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
<!-- So that users can add custom icons -->
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../_static/documentation_options.js?v=acb17c73"></script>
<script src="../_static/documentation_options.js?v=c952a61b"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -51,7 +51,7 @@
<link rel="prev" title="Saving and Loading Arrays" href="saving_and_loading.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.21.1" />
<meta name="docsearch:version" content="0.22.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
</a></div>
@@ -160,6 +160,7 @@
<li class="toctree-l1"><a class="reference internal" href="numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="distributed.html">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="using_streams.html">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
@@ -228,6 +229,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
@@ -242,6 +244,13 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
@@ -324,6 +333,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
@@ -379,6 +389,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
@@ -403,6 +415,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
@@ -695,6 +708,7 @@
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
@@ -703,9 +717,14 @@
<div class="sidebar-primary-items__end sidebar-primary__section">
<div class="sidebar-primary-item">
<div id="ethical-ad-placement"
class="flat"
data-ea-publisher="readthedocs"
data-ea-type="readthedocs-sidebar"
data-ea-manual="true">
</div></div>
</div>
<div id="rtd-footer-container"></div>
</div>
@@ -908,7 +927,7 @@ graphs.</p>
saw above. You can use the <a class="reference internal" href="../python/_autosummary/mlx.core.grad.html#mlx.core.grad" title="mlx.core.grad"><code class="xref py py-func docutils literal notranslate"><span class="pre">grad()</span></code></a> and <a class="reference internal" href="../python/_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-func docutils literal notranslate"><span class="pre">value_and_grad()</span></code></a> function to
compute gradients of more complex functions. By default these functions compute
the gradient with respect to the first argument:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">loss_fn</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">w</span> <span class="o">*</span> <span class="n">x</span> <span class="o">-</span> <span class="n">y</span><span class="p">))</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>
@@ -947,7 +966,7 @@ containers of arrays (specifically any of <a class="reference external" href="ht
<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#dict" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">dict</span></code></a>).</p>
<p>Suppose we wanted a weight and a bias parameter in the above example. A nice
way to do that is the following:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">loss_fn</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="n">w</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">params</span><span class="p">[</span><span class="s2">&quot;weight&quot;</span><span class="p">],</span> <span class="n">params</span><span class="p">[</span><span class="s2">&quot;bias&quot;</span><span class="p">]</span>
<span class="n">h</span> <span class="o">=</span> <span class="n">w</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">b</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">h</span> <span class="o">-</span> <span class="n">y</span><span class="p">))</span>
@@ -986,7 +1005,7 @@ We will prioritize including it.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">xs</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4096</span><span class="p">,</span> <span class="mi">100</span><span class="p">))</span>
<span class="n">ys</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">4096</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">naive_add</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">naive_add</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">):</span>
<span class="k">return</span> <span class="p">[</span><span class="n">xs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">ys</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">xs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
</pre></div>
</div>
@@ -1000,7 +1019,7 @@ We will prioritize including it.</p>
corresponding input to vectorize over. Similarly, use <code class="docutils literal notranslate"><span class="pre">out_axes</span></code> to specify
where the vectorized axes should be in the outputs.</p>
<p>Lets time these two different versions:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">timeit</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">timeit</span>
<span class="nb">print</span><span class="p">(</span><span class="n">timeit</span><span class="o">.</span><span class="n">timeit</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">naive_add</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">)),</span> <span class="n">number</span><span class="o">=</span><span class="mi">100</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="n">timeit</span><span class="o">.</span><span class="n">timeit</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">vmap_add</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">)),</span> <span class="n">number</span><span class="o">=</span><span class="mi">100</span><span class="p">))</span>
@@ -1109,8 +1128,8 @@ By MLX Contributors
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
</footer>

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Indexing Arrays &#8212; MLX 0.21.1 documentation</title>
<title>Indexing Arrays &#8212; MLX 0.22.0 documentation</title>
@@ -16,8 +16,8 @@
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this give us a css class that will be invisible only if js is disabled
<!--
this give us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
@@ -27,19 +27,19 @@
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
<!-- So that users can add custom icons -->
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../_static/documentation_options.js?v=acb17c73"></script>
<script src="../_static/documentation_options.js?v=c952a61b"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -51,7 +51,7 @@
<link rel="prev" title="Unified Memory" href="unified_memory.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.21.1" />
<meta name="docsearch:version" content="0.22.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
</a></div>
@@ -160,6 +160,7 @@
<li class="toctree-l1"><a class="reference internal" href="numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="distributed.html">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="using_streams.html">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
@@ -228,6 +229,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
@@ -242,6 +244,13 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
@@ -324,6 +333,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
@@ -379,6 +389,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
@@ -403,6 +415,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
@@ -695,6 +708,7 @@
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
@@ -703,9 +717,14 @@
<div class="sidebar-primary-items__end sidebar-primary__section">
<div class="sidebar-primary-item">
<div id="ethical-ad-placement"
class="flat"
data-ea-publisher="readthedocs"
data-ea-type="readthedocs-sidebar"
data-ea-manual="true">
</div></div>
</div>
<div id="rtd-footer-container"></div>
</div>
@@ -866,9 +885,9 @@
<section id="indexing-arrays">
<span id="indexing"></span><h1>Indexing Arrays<a class="headerlink" href="#indexing-arrays" title="Link to this heading">#</a></h1>
<p>For the most part, indexing an MLX <a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-obj docutils literal notranslate"><span class="pre">array</span></code></a> works the same as indexing a
NumPy <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.1)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a>. See the <a class="reference external" href="https://numpy.org/doc/stable/user/basics.indexing.html">NumPy documentation</a> for more details on
NumPy <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.2)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a>. See the <a class="reference external" href="https://numpy.org/doc/stable/user/basics.indexing.html">NumPy documentation</a> for more details on
how that works.</p>
<p>For example, you can use regular integers and slices (<a class="reference external" href="https://docs.python.org/3/library/functions.html#slice" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">slice</span></code></a>) to index arrays:</p>
<p>For example, you can use regular integers and slices (<a class="reference internal" href="../python/_autosummary/mlx.core.slice.html#mlx.core.slice" title="mlx.core.slice"><code class="xref py py-obj docutils literal notranslate"><span class="pre">slice</span></code></a>) to index arrays:</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span>&gt;&gt;&gt;<span class="w"> </span><span class="nv">arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span>mx.arange<span class="o">(</span><span class="m">10</span><span class="o">)</span>
&gt;&gt;&gt;<span class="w"> </span>arr<span class="o">[</span><span class="m">3</span><span class="o">]</span>
array<span class="o">(</span><span class="m">3</span>,<span class="w"> </span><span class="nv">dtype</span><span class="o">=</span>int32<span class="o">)</span>
@@ -904,7 +923,7 @@ array<span class="o">([[</span><span class="m">0</span>,<span class="w"> </span>
array<span class="o">([</span><span class="m">5</span>,<span class="w"> </span><span class="m">7</span><span class="o">]</span>,<span class="w"> </span><span class="nv">dtype</span><span class="o">=</span>int32<span class="o">)</span>
</pre></div>
</div>
<p>Mixing and matching integers, <a class="reference external" href="https://docs.python.org/3/library/functions.html#slice" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">slice</span></code></a>, <code class="docutils literal notranslate"><span class="pre">...</span></code>, and <a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-obj docutils literal notranslate"><span class="pre">array</span></code></a> indices
<p>Mixing and matching integers, <a class="reference internal" href="../python/_autosummary/mlx.core.slice.html#mlx.core.slice" title="mlx.core.slice"><code class="xref py py-obj docutils literal notranslate"><span class="pre">slice</span></code></a>, <code class="docutils literal notranslate"><span class="pre">...</span></code>, and <a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-obj docutils literal notranslate"><span class="pre">array</span></code></a> indices
works just as in NumPy.</p>
<p>Other functions which may be useful for indexing arrays are <a class="reference internal" href="../python/_autosummary/mlx.core.take.html#mlx.core.take" title="mlx.core.take"><code class="xref py py-func docutils literal notranslate"><span class="pre">take()</span></code></a> and
<a class="reference internal" href="../python/_autosummary/mlx.core.take_along_axis.html#mlx.core.take_along_axis" title="mlx.core.take_along_axis"><code class="xref py py-func docutils literal notranslate"><span class="pre">take_along_axis()</span></code></a>.</p>
@@ -925,8 +944,8 @@ kernel would be extremely inefficient.</p>
<p>Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which output
<em>shapes</em> are dependent on input <em>data</em>. Other examples of these types of
operations which MLX does not yet support include <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.nonzero.html#numpy.nonzero" title="(in NumPy v2.1)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.nonzero()</span></code></a> and the
single input version of <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.where.html#numpy.where" title="(in NumPy v2.1)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.where()</span></code></a>.</p>
operations which MLX does not yet support include <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.nonzero.html#numpy.nonzero" title="(in NumPy v2.2)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.nonzero()</span></code></a> and the
single input version of <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.where.html#numpy.where" title="(in NumPy v2.2)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.where()</span></code></a>.</p>
</section>
<section id="in-place-updates">
<h2>In Place Updates<a class="headerlink" href="#in-place-updates" title="Link to this heading">#</a></h2>
@@ -950,7 +969,7 @@ array<span class="o">([</span><span class="m">1</span>,<span class="w"> </span><
</div>
<p>Transformations of functions which use in-place updates are allowed and work as
expected. For example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">idx</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">idx</span><span class="p">):</span>
<span class="n">x</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="mf">2.0</span>
<span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
@@ -1059,8 +1078,8 @@ By MLX Contributors
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
</footer>

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Lazy Evaluation &#8212; MLX 0.21.1 documentation</title>
<title>Lazy Evaluation &#8212; MLX 0.22.0 documentation</title>
@@ -16,8 +16,8 @@
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this give us a css class that will be invisible only if js is disabled
<!--
this give us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
@@ -27,19 +27,19 @@
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
<!-- So that users can add custom icons -->
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../_static/documentation_options.js?v=acb17c73"></script>
<script src="../_static/documentation_options.js?v=c952a61b"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -51,7 +51,7 @@
<link rel="prev" title="Quick Start Guide" href="quick_start.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.21.1" />
<meta name="docsearch:version" content="0.22.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
</a></div>
@@ -160,6 +160,7 @@
<li class="toctree-l1"><a class="reference internal" href="numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="distributed.html">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="using_streams.html">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
@@ -228,6 +229,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
@@ -242,6 +244,13 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
@@ -324,6 +333,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
@@ -379,6 +389,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
@@ -403,6 +415,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
@@ -695,6 +708,7 @@
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
@@ -703,9 +717,14 @@
<div class="sidebar-primary-items__end sidebar-primary__section">
<div class="sidebar-primary-item">
<div id="ethical-ad-placement"
class="flat"
data-ea-publisher="readthedocs"
data-ea-type="readthedocs-sidebar"
data-ea-manual="true">
</div></div>
</div>
<div id="rtd-footer-container"></div>
</div>
@@ -889,7 +908,7 @@ integrate compilation for future performance enhancements.</p>
<h3>Only Compute What You Use<a class="headerlink" href="#only-compute-what-you-use" title="Link to this heading">#</a></h3>
<p>In MLX you do not need to worry as much about computing outputs that are never
used. For example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">fun1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">expensive_fun</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
<span class="k">return</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span>
@@ -953,7 +972,7 @@ stochastic gradient descent). A natural and usually efficient place to use
</div>
<p>An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you <code class="docutils literal notranslate"><span class="pre">print</span></code> an array, convert it to an
<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.1)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a>, or otherwise access its memory via <a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#memoryview" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">memoryview</span></code></a>,
<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.2)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a>, or otherwise access its memory via <a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#memoryview" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">memoryview</span></code></a>,
the graph will be evaluated. Saving arrays via <a class="reference internal" href="../python/_autosummary/mlx.core.save.html#mlx.core.save" title="mlx.core.save"><code class="xref py py-func docutils literal notranslate"><span class="pre">save()</span></code></a> (or any other MLX
saving functions) will also evaluate the array.</p>
<p>Calling <a class="reference internal" href="../python/_autosummary/mlx.core.array.item.html#mlx.core.array.item" title="mlx.core.array.item"><code class="xref py py-func docutils literal notranslate"><span class="pre">array.item()</span></code></a> on a scalar array will also evaluate it. In the
@@ -968,7 +987,7 @@ perfectly fine. This is effectively a no-op.</p>
<p>Using scalar arrays for control-flow will cause an evaluation.</p>
</div>
<p>Here is an example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="n">h</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">first_layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">if</span> <span class="n">y</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># An evaluation is done here!</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">second_layer_a</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
@@ -1083,8 +1102,8 @@ By MLX Contributors
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
</footer>

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Conversion to NumPy and Other Frameworks &#8212; MLX 0.21.1 documentation</title>
<title>Conversion to NumPy and Other Frameworks &#8212; MLX 0.22.0 documentation</title>
@@ -16,8 +16,8 @@
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this give us a css class that will be invisible only if js is disabled
<!--
this give us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
@@ -27,19 +27,19 @@
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
<!-- So that users can add custom icons -->
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../_static/documentation_options.js?v=acb17c73"></script>
<script src="../_static/documentation_options.js?v=c952a61b"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -51,7 +51,7 @@
<link rel="prev" title="Compilation" href="compile.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.21.1" />
<meta name="docsearch:version" content="0.22.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
</a></div>
@@ -160,6 +160,7 @@
<li class="toctree-l1 current active"><a class="current reference internal" href="#">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="distributed.html">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="using_streams.html">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
@@ -228,6 +229,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
@@ -242,6 +244,13 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
@@ -324,6 +333,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
@@ -379,6 +389,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
@@ -403,6 +415,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
@@ -695,6 +708,7 @@
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
@@ -703,9 +717,14 @@
<div class="sidebar-primary-items__end sidebar-primary__section">
<div class="sidebar-primary-item">
<div id="ethical-ad-placement"
class="flat"
data-ea-publisher="readthedocs"
data-ea-type="readthedocs-sidebar"
data-ea-manual="true">
</div></div>
</div>
<div id="rtd-footer-container"></div>
</div>
@@ -872,8 +891,8 @@
<li><p><a class="reference external" href="https://dmlc.github.io/dlpack/latest/">DLPack</a>.</p></li>
</ul>
<p>Lets convert an array to NumPy and back.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># copy of a</span>
@@ -898,7 +917,7 @@ Otherwise, you will receive an error like: <code class="docutils literal notrans
This means writing to the view is reflected in the original array.</p>
<p>While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.</p>
<p>Lets demonstrate this in an example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">f</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="n">x_view</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">x_view</span><span class="p">[:]</span> <span class="o">*=</span> <span class="n">x_view</span> <span class="c1"># modify memory without telling mx</span>
<span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
@@ -924,8 +943,8 @@ even though no in-place operations on MLX memory are executed.</p>
multi-dimensional arrays. Casting to NumPy first is advised for now.</p>
</div>
<p>PyTorch supports the buffer protocol, but it requires an explicit <a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#memoryview" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">memoryview</span></code></a>.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">torch</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">memoryview</span><span class="p">(</span><span class="n">a</span><span class="p">))</span>
@@ -937,8 +956,8 @@ multi-dimensional arrays. Casting to NumPy first is advised for now.</p>
<section id="jax">
<h2>JAX<a class="headerlink" href="#jax" title="Link to this heading">#</a></h2>
<p>JAX fully supports the buffer protocol.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">jax.numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">jnp</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
@@ -949,8 +968,8 @@ multi-dimensional arrays. Casting to NumPy first is advised for now.</p>
<section id="tensorflow">
<h2>TensorFlow<a class="headerlink" href="#tensorflow" title="Link to this heading">#</a></h2>
<p>TensorFlow supports the buffer protocol, but it requires an explicit <a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#memoryview" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">memoryview</span></code></a>.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tf</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="nb">memoryview</span><span class="p">(</span><span class="n">a</span><span class="p">))</span>
@@ -1057,8 +1076,8 @@ By MLX Contributors
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
</footer>

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Quick Start Guide &#8212; MLX 0.21.1 documentation</title>
<title>Quick Start Guide &#8212; MLX 0.22.0 documentation</title>
@@ -16,8 +16,8 @@
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this give us a css class that will be invisible only if js is disabled
<!--
this give us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
@@ -27,19 +27,19 @@
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
<!-- So that users can add custom icons -->
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../_static/documentation_options.js?v=acb17c73"></script>
<script src="../_static/documentation_options.js?v=c952a61b"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -51,7 +51,7 @@
<link rel="prev" title="Build and Install" href="../install.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.21.1" />
<meta name="docsearch:version" content="0.22.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
</a></div>
@@ -160,6 +160,7 @@
<li class="toctree-l1"><a class="reference internal" href="numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="distributed.html">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="using_streams.html">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
@@ -228,6 +229,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
@@ -242,6 +244,13 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
@@ -324,6 +333,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
@@ -379,6 +389,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
@@ -403,6 +415,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
@@ -695,6 +708,7 @@
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
@@ -703,9 +717,14 @@
<div class="sidebar-primary-items__end sidebar-primary__section">
<div class="sidebar-primary-item">
<div id="ethical-ad-placement"
class="flat"
data-ea-publisher="readthedocs"
data-ea-type="readthedocs-sidebar"
data-ea-manual="true">
</div></div>
</div>
<div id="rtd-footer-container"></div>
</div>
@@ -868,7 +887,7 @@
<section id="basics">
<h2>Basics<a class="headerlink" href="#basics" title="Link to this heading">#</a></h2>
<p>Import <code class="docutils literal notranslate"><span class="pre">mlx.core</span></code> and make an <a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code></a>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&gt;&gt;</span> <span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&gt;&gt;</span> <span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="o">&gt;&gt;</span> <span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span>
<span class="o">&gt;&gt;</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span>
<span class="p">[</span><span class="mi">4</span><span class="p">]</span>
@@ -883,7 +902,7 @@
until they are needed. To force an array to be evaluated use
<a class="reference internal" href="../python/_autosummary/mlx.core.eval.html#mlx.core.eval" title="mlx.core.eval"><code class="xref py py-func docutils literal notranslate"><span class="pre">eval()</span></code></a>. Arrays will automatically be evaluated in a few cases. For
example, inspecting a scalar with <a class="reference internal" href="../python/_autosummary/mlx.core.array.item.html#mlx.core.array.item" title="mlx.core.array.item"><code class="xref py py-meth docutils literal notranslate"><span class="pre">array.item()</span></code></a>, printing an array,
or converting an array from <a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code></a> to <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.1)"><code class="xref py py-class docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a> all
or converting an array from <a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code></a> to <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.2)"><code class="xref py py-class docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a> all
automatically evaluate the array.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&gt;&gt;</span> <span class="n">c</span> <span class="o">=</span> <span class="n">a</span> <span class="o">+</span> <span class="n">b</span> <span class="c1"># c not yet evaluated</span>
<span class="o">&gt;&gt;</span> <span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">c</span><span class="p">)</span> <span class="c1"># evaluates c</span>
@@ -891,7 +910,7 @@ automatically evaluate the array.</p>
<span class="o">&gt;&gt;</span> <span class="nb">print</span><span class="p">(</span><span class="n">c</span><span class="p">)</span> <span class="c1"># Also evaluates c</span>
<span class="n">array</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">8</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float32</span><span class="p">)</span>
<span class="o">&gt;&gt;</span> <span class="n">c</span> <span class="o">=</span> <span class="n">a</span> <span class="o">+</span> <span class="n">b</span>
<span class="o">&gt;&gt;</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="o">&gt;&gt;</span> <span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="o">&gt;&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">c</span><span class="p">)</span> <span class="c1"># Also evaluates c</span>
<span class="n">array</span><span class="p">([</span><span class="mf">2.</span><span class="p">,</span> <span class="mf">4.</span><span class="p">,</span> <span class="mf">6.</span><span class="p">,</span> <span class="mf">8.</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float32</span><span class="p">)</span>
</pre></div>
@@ -1015,8 +1034,8 @@ By MLX Contributors
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
</footer>

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Saving and Loading Arrays &#8212; MLX 0.21.1 documentation</title>
<title>Saving and Loading Arrays &#8212; MLX 0.22.0 documentation</title>
@@ -16,8 +16,8 @@
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this give us a css class that will be invisible only if js is disabled
<!--
this give us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
@@ -27,19 +27,19 @@
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
<!-- So that users can add custom icons -->
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../_static/documentation_options.js?v=acb17c73"></script>
<script src="../_static/documentation_options.js?v=c952a61b"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -51,7 +51,7 @@
<link rel="prev" title="Indexing Arrays" href="indexing.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.21.1" />
<meta name="docsearch:version" content="0.22.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
</a></div>
@@ -160,6 +160,7 @@
<li class="toctree-l1"><a class="reference internal" href="numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="distributed.html">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="using_streams.html">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
@@ -228,6 +229,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
@@ -242,6 +244,13 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
@@ -324,6 +333,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
@@ -379,6 +389,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
@@ -403,6 +415,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
@@ -695,6 +708,7 @@
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
@@ -703,9 +717,14 @@
<div class="sidebar-primary-items__end sidebar-primary__section">
<div class="sidebar-primary-item">
<div id="ethical-ad-placement"
class="flat"
data-ea-publisher="readthedocs"
data-ea-type="readthedocs-sidebar"
data-ea-manual="true">
</div></div>
</div>
<div id="rtd-footer-container"></div>
</div>
@@ -915,7 +934,7 @@ array<span class="o">([</span><span class="m">1</span><span class="o">]</span>,<
&gt;&gt;&gt;<span class="w"> </span>mx.savez<span class="o">(</span><span class="s2">&quot;arrays&quot;</span>,<span class="w"> </span>a,<span class="w"> </span><span class="nv">b</span><span class="o">=</span>b<span class="o">)</span>
</pre></div>
</div>
<p>For compatibility with <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.savez.html#numpy.savez" title="(in NumPy v2.1)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.savez()</span></code></a> the MLX <a class="reference internal" href="../python/_autosummary/mlx.core.savez.html#mlx.core.savez" title="mlx.core.savez"><code class="xref py py-func docutils literal notranslate"><span class="pre">savez()</span></code></a> takes arrays
<p>For compatibility with <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.savez.html#numpy.savez" title="(in NumPy v2.2)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.savez()</span></code></a> the MLX <a class="reference internal" href="../python/_autosummary/mlx.core.savez.html#mlx.core.savez" title="mlx.core.savez"><code class="xref py py-func docutils literal notranslate"><span class="pre">savez()</span></code></a> takes arrays
as arguments. If the keywords are missing, then default names will be
provided. This can be loaded with:</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span>&gt;&gt;&gt;<span class="w"> </span>mx.load<span class="o">(</span><span class="s2">&quot;arrays.npz&quot;</span><span class="o">)</span>
@@ -1011,8 +1030,8 @@ By MLX Contributors
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
</footer>

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Unified Memory &#8212; MLX 0.21.1 documentation</title>
<title>Unified Memory &#8212; MLX 0.22.0 documentation</title>
@@ -16,8 +16,8 @@
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this give us a css class that will be invisible only if js is disabled
<!--
this give us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
@@ -27,19 +27,19 @@
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
<!-- So that users can add custom icons -->
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../_static/documentation_options.js?v=acb17c73"></script>
<script src="../_static/documentation_options.js?v=c952a61b"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -51,7 +51,7 @@
<link rel="prev" title="Lazy Evaluation" href="lazy_evaluation.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.21.1" />
<meta name="docsearch:version" content="0.22.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
</a></div>
@@ -160,6 +160,7 @@
<li class="toctree-l1"><a class="reference internal" href="numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="distributed.html">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="using_streams.html">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
@@ -228,6 +229,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
@@ -242,6 +244,13 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
@@ -324,6 +333,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
@@ -379,6 +389,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
@@ -403,6 +415,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
@@ -695,6 +708,7 @@
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
@@ -703,9 +717,14 @@
<div class="sidebar-primary-items__end sidebar-primary__section">
<div class="sidebar-primary-item">
<div id="ethical-ad-placement"
class="flat"
data-ea-publisher="readthedocs"
data-ea-type="readthedocs-sidebar"
data-ea-manual="true">
</div></div>
</div>
<div id="rtd-footer-container"></div>
</div>
@@ -899,7 +918,7 @@ available.</p>
<h2>A Simple Example<a class="headerlink" href="#a-simple-example" title="Link to this heading">#</a></h2>
<p>Here is a more interesting (albeit slightly contrived example) of how unified
memory can be helpful. Suppose we have the following computation:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">fun</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">d1</span><span class="p">,</span> <span class="n">d2</span><span class="p">):</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">fun</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">d1</span><span class="p">,</span> <span class="n">d2</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="n">d1</span><span class="p">)</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">500</span><span class="p">):</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="n">d2</span><span class="p">)</span>
@@ -1016,8 +1035,8 @@ By MLX Contributors
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
</footer>

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Using Streams &#8212; MLX 0.21.1 documentation</title>
<title>Using Streams &#8212; MLX 0.22.0 documentation</title>
@@ -16,8 +16,8 @@
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this give us a css class that will be invisible only if js is disabled
<!--
this give us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
@@ -27,19 +27,19 @@
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
<!-- So that users can add custom icons -->
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../_static/documentation_options.js?v=acb17c73"></script>
<script src="../_static/documentation_options.js?v=c952a61b"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -47,11 +47,11 @@
<link rel="icon" href="../_static/mlx_logo.png"/>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Linear Regression" href="../examples/linear_regression.html" />
<link rel="next" title="Exporting Functions" href="export.html" />
<link rel="prev" title="Distributed Communication" href="distributed.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="0.21.1" />
<meta name="docsearch:version" content="0.22.0" />
</head>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
</a></div>
@@ -160,6 +160,7 @@
<li class="toctree-l1"><a class="reference internal" href="numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="distributed.html">Distributed Communication</a></li>
<li class="toctree-l1 current active"><a class="current reference internal" href="#">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
@@ -228,6 +229,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
@@ -242,6 +244,13 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
@@ -324,6 +333,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
@@ -379,6 +389,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
@@ -403,6 +415,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
@@ -695,6 +708,7 @@
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
@@ -703,9 +717,14 @@
<div class="sidebar-primary-items__end sidebar-primary__section">
<div class="sidebar-primary-item">
<div id="ethical-ad-placement"
class="flat"
data-ea-publisher="readthedocs"
data-ea-type="readthedocs-sidebar"
data-ea-manual="true">
</div></div>
</div>
<div id="rtd-footer-container"></div>
</div>
@@ -898,11 +917,11 @@ run on the default stream of the provided device
</div>
</a>
<a class="right-next"
href="../examples/linear_regression.html"
href="export.html"
title="next page">
<div class="prev-next-info">
<p class="prev-next-subtitle">next</p>
<p class="prev-next-title">Linear Regression</p>
<p class="prev-next-title">Exporting Functions</p>
</div>
<i class="fa-solid fa-angle-right"></i>
</a>
@@ -972,8 +991,8 @@ By MLX Contributors
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
</footer>