mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
rebase
This commit is contained in:
73
docs/build/html/dev/custom_metal_kernels.html
vendored
73
docs/build/html/dev/custom_metal_kernels.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
|
||||
<title>Custom Metal Kernels — MLX 0.21.1 documentation</title>
|
||||
<title>Custom Metal Kernels — MLX 0.22.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
|
||||
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
|
||||
</script>
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
-->
|
||||
<noscript>
|
||||
<style>
|
||||
@@ -27,19 +27,19 @@
|
||||
</noscript>
|
||||
|
||||
<!-- Loaded before other Sphinx assets -->
|
||||
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
|
||||
|
||||
<!-- So that users can add custom icons -->
|
||||
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
|
||||
<!-- Pre-loaded scripts that we'll load fully later -->
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
|
||||
|
||||
<script src="../_static/documentation_options.js?v=acb17c73"></script>
|
||||
<script src="../_static/documentation_options.js?v=c952a61b"></script>
|
||||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||||
@@ -47,10 +47,11 @@
|
||||
<link rel="icon" href="../_static/mlx_logo.png"/>
|
||||
<link rel="index" title="Index" href="../genindex.html" />
|
||||
<link rel="search" title="Search" href="../search.html" />
|
||||
<link rel="next" title="Using MLX in C++" href="mlx_in_cpp.html" />
|
||||
<link rel="prev" title="Metal Debugger" href="metal_debugger.html" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<meta name="docsearch:language" content="en"/>
|
||||
<meta name="docsearch:version" content="0.21.1" />
|
||||
<meta name="docsearch:version" content="0.22.0" />
|
||||
</head>
|
||||
|
||||
|
||||
@@ -129,8 +130,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -159,6 +160,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/distributed.html">Distributed Communication</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/export.html">Exporting Functions</a></li>
|
||||
</ul>
|
||||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
||||
<ul class="nav bd-sidenav">
|
||||
@@ -227,6 +229,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
@@ -241,6 +244,13 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
|
||||
@@ -323,6 +333,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
|
||||
@@ -378,6 +389,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
|
||||
@@ -402,6 +415,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
|
||||
@@ -694,6 +708,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="extensions.html">Custom Extensions in MLX</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="metal_debugger.html">Metal Debugger</a></li>
|
||||
<li class="toctree-l1 current active"><a class="current reference internal" href="#">Custom Metal Kernels</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="mlx_in_cpp.html">Using MLX in C++</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
@@ -702,9 +717,14 @@
|
||||
|
||||
|
||||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||||
<div class="sidebar-primary-item">
|
||||
<div id="ethical-ad-placement"
|
||||
class="flat"
|
||||
data-ea-publisher="readthedocs"
|
||||
data-ea-type="readthedocs-sidebar"
|
||||
data-ea-manual="true">
|
||||
</div></div>
|
||||
</div>
|
||||
|
||||
<div id="rtd-footer-container"></div>
|
||||
|
||||
|
||||
</div>
|
||||
@@ -870,7 +890,7 @@
|
||||
<section id="simple-example">
|
||||
<h2>Simple Example<a class="headerlink" href="#simple-example" title="Link to this heading">#</a></h2>
|
||||
<p>Let’s write a custom kernel that computes <code class="docutils literal notranslate"><span class="pre">exp</span></code> elementwise:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
|
||||
<span class="n">source</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
|
||||
<span class="s2"> T tmp = inp[elem];</span>
|
||||
@@ -963,7 +983,7 @@ when indexing.</p>
|
||||
input array <code class="docutils literal notranslate"><span class="pre">a</span></code> if any are present in <code class="docutils literal notranslate"><span class="pre">source</span></code>.
|
||||
We can then use MLX’s built in indexing utils to fetch the right elements for each thread.</p>
|
||||
<p>Let’s convert <code class="docutils literal notranslate"><span class="pre">myexp</span></code> above to support arbitrarily strided arrays without relying on a copy from <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code>:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
|
||||
<span class="n">source</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
|
||||
<span class="s2"> // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included</span>
|
||||
@@ -1002,7 +1022,7 @@ We can then use MLX’s built in indexing utils to fetch the right elements for
|
||||
<h2>Complex Example<a class="headerlink" href="#complex-example" title="Link to this heading">#</a></h2>
|
||||
<p>Let’s implement a more complex example: <code class="docutils literal notranslate"><span class="pre">grid_sample</span></code> in <code class="docutils literal notranslate"><span class="pre">"bilinear"</span></code> mode.</p>
|
||||
<p>We’ll start with the following MLX implementation using standard ops:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">grid_sample_ref</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">grid_sample_ref</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">):</span>
|
||||
<span class="n">N</span><span class="p">,</span> <span class="n">H_in</span><span class="p">,</span> <span class="n">W_in</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="n">ix</span> <span class="o">=</span> <span class="p">((</span><span class="n">grid</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
|
||||
<span class="n">iy</span> <span class="o">=</span> <span class="p">((</span><span class="n">grid</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
|
||||
@@ -1048,7 +1068,7 @@ We can then use MLX’s built in indexing utils to fetch the right elements for
|
||||
to write a fast GPU kernel for both the forward and backward passes.</p>
|
||||
<p>First we’ll implement the forward pass as a fused kernel:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">custom_function</span>
|
||||
<span class="k">def</span> <span class="nf">grid_sample</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">grid_sample</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">):</span>
|
||||
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">"`x` must be 4D."</span>
|
||||
<span class="k">assert</span> <span class="n">grid</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">"`grid` must be 4D."</span>
|
||||
@@ -1155,7 +1175,7 @@ See section 6.15 of the <a class="reference external" href="https://developer.ap
|
||||
</ul>
|
||||
<p>We can then implement the backwards pass as follows:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@grid_sample</span><span class="o">.</span><span class="n">vjp</span>
|
||||
<span class="k">def</span> <span class="nf">grid_sample_vjp</span><span class="p">(</span><span class="n">primals</span><span class="p">,</span> <span class="n">cotangent</span><span class="p">,</span> <span class="n">_</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">grid_sample_vjp</span><span class="p">(</span><span class="n">primals</span><span class="p">,</span> <span class="n">cotangent</span><span class="p">,</span> <span class="n">_</span><span class="p">):</span>
|
||||
<span class="n">x</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">primals</span>
|
||||
<span class="n">B</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="n">_</span><span class="p">,</span> <span class="n">gN</span><span class="p">,</span> <span class="n">gM</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">grid</span><span class="o">.</span><span class="n">shape</span>
|
||||
@@ -1303,6 +1323,15 @@ See section 6.15 of the <a class="reference external" href="https://developer.ap
|
||||
<p class="prev-next-title">Metal Debugger</p>
|
||||
</div>
|
||||
</a>
|
||||
<a class="right-next"
|
||||
href="mlx_in_cpp.html"
|
||||
title="next page">
|
||||
<div class="prev-next-info">
|
||||
<p class="prev-next-subtitle">next</p>
|
||||
<p class="prev-next-title">Using MLX in C++</p>
|
||||
</div>
|
||||
<i class="fa-solid fa-angle-right"></i>
|
||||
</a>
|
||||
</div>
|
||||
</footer>
|
||||
|
||||
@@ -1372,8 +1401,8 @@ By MLX Contributors
|
||||
</div>
|
||||
|
||||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
|
||||
|
||||
<footer class="bd-footer">
|
||||
</footer>
|
||||
|
101
docs/build/html/dev/extensions.html
vendored
101
docs/build/html/dev/extensions.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
|
||||
<title>Custom Extensions in MLX — MLX 0.21.1 documentation</title>
|
||||
<title>Custom Extensions in MLX — MLX 0.22.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
|
||||
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
|
||||
</script>
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
-->
|
||||
<noscript>
|
||||
<style>
|
||||
@@ -27,19 +27,19 @@
|
||||
</noscript>
|
||||
|
||||
<!-- Loaded before other Sphinx assets -->
|
||||
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
|
||||
|
||||
<!-- So that users can add custom icons -->
|
||||
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
|
||||
<!-- Pre-loaded scripts that we'll load fully later -->
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
|
||||
|
||||
<script src="../_static/documentation_options.js?v=acb17c73"></script>
|
||||
<script src="../_static/documentation_options.js?v=c952a61b"></script>
|
||||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||||
@@ -51,7 +51,7 @@
|
||||
<link rel="prev" title="Operations" href="../cpp/ops.html" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<meta name="docsearch:language" content="en"/>
|
||||
<meta name="docsearch:version" content="0.21.1" />
|
||||
<meta name="docsearch:version" content="0.22.0" />
|
||||
</head>
|
||||
|
||||
|
||||
@@ -130,8 +130,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -160,6 +160,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/distributed.html">Distributed Communication</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/export.html">Exporting Functions</a></li>
|
||||
</ul>
|
||||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
||||
<ul class="nav bd-sidenav">
|
||||
@@ -228,6 +229,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
@@ -242,6 +244,13 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
|
||||
@@ -324,6 +333,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
|
||||
@@ -379,6 +389,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
|
||||
@@ -403,6 +415,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
|
||||
@@ -695,6 +708,7 @@
|
||||
<li class="toctree-l1 current active"><a class="current reference internal" href="#">Custom Extensions in MLX</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="metal_debugger.html">Metal Debugger</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="custom_metal_kernels.html">Custom Metal Kernels</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="mlx_in_cpp.html">Using MLX in C++</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
@@ -703,9 +717,14 @@
|
||||
|
||||
|
||||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||||
<div class="sidebar-primary-item">
|
||||
<div id="ethical-ad-placement"
|
||||
class="flat"
|
||||
data-ea-publisher="readthedocs"
|
||||
data-ea-type="readthedocs-sidebar"
|
||||
data-ea-manual="true">
|
||||
</div></div>
|
||||
</div>
|
||||
|
||||
<div id="rtd-footer-container"></div>
|
||||
|
||||
|
||||
</div>
|
||||
@@ -895,9 +914,9 @@ explains how to do that with a simple example.</p>
|
||||
<code class="docutils literal notranslate"><span class="pre">y</span></code>, scales them both by coefficients <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code> respectively,
|
||||
and then adds them together to get the result <code class="docutils literal notranslate"><span class="pre">z</span> <span class="pre">=</span> <span class="pre">alpha</span> <span class="pre">*</span> <span class="pre">x</span> <span class="pre">+</span> <span class="pre">beta</span> <span class="pre">*</span> <span class="pre">y</span></code>.
|
||||
You can do that in MLX directly:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">simple_axpby</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">simple_axpby</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">y</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
@@ -1273,8 +1292,8 @@ element in the output.</p>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="o">&</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">3</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="o">&</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">4</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int</span><span class="o">*</span><span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">5</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">size_t</span><span class="o">*</span><span class="w"> </span><span class="n">x_strides</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">6</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">size_t</span><span class="o">*</span><span class="w"> </span><span class="n">y_strides</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">7</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int64_t</span><span class="o">*</span><span class="w"> </span><span class="n">x_strides</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">6</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int64_t</span><span class="o">*</span><span class="w"> </span><span class="n">y_strides</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">7</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int</span><span class="o">&</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">8</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">uint</span><span class="w"> </span><span class="n">index</span><span class="w"> </span><span class="p">[[</span><span class="n">thread_position_in_grid</span><span class="p">]])</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="c1">// Convert linear indices to offsets in array</span>
|
||||
@@ -1289,24 +1308,10 @@ element in the output.</p>
|
||||
</div>
|
||||
<p>We then need to instantiate this template for all floating point types and give
|
||||
each instantiation a unique host name so we can identify it.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cp">#define instantiate_axpby(type_name, type) \</span>
|
||||
<span class="cp"> template [[host_name("axpby_general_" #type_name)]] \</span>
|
||||
<span class="cp"> [[kernel]] void axpby_general<type>( \</span>
|
||||
<span class="cp"> device const type* x [[buffer(0)]], \</span>
|
||||
<span class="cp"> device const type* y [[buffer(1)]], \</span>
|
||||
<span class="cp"> device type* out [[buffer(2)]], \</span>
|
||||
<span class="cp"> constant const float& alpha [[buffer(3)]], \</span>
|
||||
<span class="cp"> constant const float& beta [[buffer(4)]], \</span>
|
||||
<span class="cp"> constant const int* shape [[buffer(5)]], \</span>
|
||||
<span class="cp"> constant const size_t* x_strides [[buffer(6)]], \</span>
|
||||
<span class="cp"> constant const size_t* y_strides [[buffer(7)]], \</span>
|
||||
<span class="cp"> constant const int& ndim [[buffer(8)]], \</span>
|
||||
<span class="cp"> uint index [[thread_position_in_grid]]);</span>
|
||||
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">float32</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">float16</span><span class="p">,</span><span class="w"> </span><span class="n">half</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">bfloat16</span><span class="p">,</span><span class="w"> </span><span class="n">bfloat16_t</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">complex64</span><span class="p">,</span><span class="w"> </span><span class="n">complex64_t</span><span class="p">);</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">instantiate_kernel</span><span class="p">(</span><span class="s">"axpby_general_float32"</span><span class="p">,</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="p">)</span>
|
||||
<span class="n">instantiate_kernel</span><span class="p">(</span><span class="s">"axpby_general_float16"</span><span class="p">,</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">,</span><span class="w"> </span><span class="n">float16_t</span><span class="p">)</span>
|
||||
<span class="n">instantiate_kernel</span><span class="p">(</span><span class="s">"axpby_general_bfloat16"</span><span class="p">,</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">,</span><span class="w"> </span><span class="n">bfloat16_t</span><span class="p">)</span>
|
||||
<span class="n">instantiate_kernel</span><span class="p">(</span><span class="s">"axpby_general_complex64"</span><span class="p">,</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">,</span><span class="w"> </span><span class="n">complex64_t</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>The logic to determine the kernel, set the inputs, resolve the grid dimensions,
|
||||
@@ -1589,8 +1594,8 @@ automatically imported with MLX package).</p>
|
||||
<h3>Building with <code class="docutils literal notranslate"><span class="pre">setuptools</span></code><a class="headerlink" href="#building-with-setuptools" title="Link to this heading">#</a></h3>
|
||||
<p>Once we have set out the CMake build rules as described above, we can use the
|
||||
build utilities defined in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.extension</span></code>:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">mlx</span> <span class="kn">import</span> <span class="n">extension</span>
|
||||
<span class="kn">from</span> <span class="nn">setuptools</span> <span class="kn">import</span> <span class="n">setup</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">mlx</span><span class="w"> </span><span class="kn">import</span> <span class="n">extension</span>
|
||||
<span class="kn">from</span><span class="w"> </span><span class="nn">setuptools</span><span class="w"> </span><span class="kn">import</span> <span class="n">setup</span>
|
||||
|
||||
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">"__main__"</span><span class="p">:</span>
|
||||
<span class="n">setup</span><span class="p">(</span>
|
||||
@@ -1642,8 +1647,8 @@ copied along with the Python binding since they are specified as
|
||||
<p>After installing the extension as described above, you should be able to simply
|
||||
import the Python package and play with it as you would any other MLX operation.</p>
|
||||
<p>Let’s look at a simple script and its results:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<span class="kn">from</span> <span class="nn">mlx_sample_extensions</span> <span class="kn">import</span> <span class="n">axpby</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||||
<span class="kn">from</span><span class="w"> </span><span class="nn">mlx_sample_extensions</span><span class="w"> </span><span class="kn">import</span> <span class="n">axpby</span>
|
||||
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
|
||||
<span class="n">b</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
|
||||
@@ -1664,13 +1669,13 @@ import the Python package and play with it as you would any other MLX operation.
|
||||
<h3>Results<a class="headerlink" href="#results" title="Link to this heading">#</a></h3>
|
||||
<p>Let’s run a quick benchmark and see how our new <code class="docutils literal notranslate"><span class="pre">axpby</span></code> operation compares
|
||||
with the naive <code class="xref py py-meth docutils literal notranslate"><span class="pre">simple_axpby()</span></code> we first defined on the CPU.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<span class="kn">from</span> <span class="nn">mlx_sample_extensions</span> <span class="kn">import</span> <span class="n">axpby</span>
|
||||
<span class="kn">import</span> <span class="nn">time</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||||
<span class="kn">from</span><span class="w"> </span><span class="nn">mlx_sample_extensions</span><span class="w"> </span><span class="kn">import</span> <span class="n">axpby</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
|
||||
|
||||
<span class="n">mx</span><span class="o">.</span><span class="n">set_default_device</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">simple_axpby</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">simple_axpby</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">y</span>
|
||||
|
||||
<span class="n">M</span> <span class="o">=</span> <span class="mi">256</span>
|
||||
@@ -1683,7 +1688,7 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
|
||||
|
||||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">bench</span><span class="p">(</span><span class="n">f</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">bench</span><span class="p">(</span><span class="n">f</span><span class="p">):</span>
|
||||
<span class="c1"># Warm up</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
|
||||
@@ -1837,8 +1842,8 @@ By MLX Contributors
|
||||
</div>
|
||||
|
||||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
|
||||
|
||||
<footer class="bd-footer">
|
||||
</footer>
|
||||
|
55
docs/build/html/dev/metal_debugger.html
vendored
55
docs/build/html/dev/metal_debugger.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
|
||||
<title>Metal Debugger — MLX 0.21.1 documentation</title>
|
||||
<title>Metal Debugger — MLX 0.22.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
|
||||
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
|
||||
</script>
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
-->
|
||||
<noscript>
|
||||
<style>
|
||||
@@ -27,19 +27,19 @@
|
||||
</noscript>
|
||||
|
||||
<!-- Loaded before other Sphinx assets -->
|
||||
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
|
||||
|
||||
<!-- So that users can add custom icons -->
|
||||
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
|
||||
<!-- Pre-loaded scripts that we'll load fully later -->
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
|
||||
|
||||
<script src="../_static/documentation_options.js?v=acb17c73"></script>
|
||||
<script src="../_static/documentation_options.js?v=c952a61b"></script>
|
||||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||||
@@ -51,7 +51,7 @@
|
||||
<link rel="prev" title="Custom Extensions in MLX" href="extensions.html" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<meta name="docsearch:language" content="en"/>
|
||||
<meta name="docsearch:version" content="0.21.1" />
|
||||
<meta name="docsearch:version" content="0.22.0" />
|
||||
</head>
|
||||
|
||||
|
||||
@@ -130,8 +130,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -160,6 +160,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/distributed.html">Distributed Communication</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/export.html">Exporting Functions</a></li>
|
||||
</ul>
|
||||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
||||
<ul class="nav bd-sidenav">
|
||||
@@ -228,6 +229,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
@@ -242,6 +244,13 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
|
||||
@@ -324,6 +333,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
|
||||
@@ -379,6 +389,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
|
||||
@@ -403,6 +415,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
|
||||
@@ -695,6 +708,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="extensions.html">Custom Extensions in MLX</a></li>
|
||||
<li class="toctree-l1 current active"><a class="current reference internal" href="#">Metal Debugger</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="custom_metal_kernels.html">Custom Metal Kernels</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="mlx_in_cpp.html">Using MLX in C++</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
@@ -703,9 +717,14 @@
|
||||
|
||||
|
||||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||||
<div class="sidebar-primary-item">
|
||||
<div id="ethical-ad-placement"
|
||||
class="flat"
|
||||
data-ea-publisher="readthedocs"
|
||||
data-ea-type="readthedocs-sidebar"
|
||||
data-ea-manual="true">
|
||||
</div></div>
|
||||
</div>
|
||||
|
||||
<div id="rtd-footer-container"></div>
|
||||
|
||||
|
||||
</div>
|
||||
@@ -881,7 +900,7 @@ work.</p>
|
||||
<p>To capture a GPU trace you must run the application with
|
||||
<code class="docutils literal notranslate"><span class="pre">MTL_CAPTURE_ENABLED=1</span></code>.</p>
|
||||
</div>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||||
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">))</span>
|
||||
<span class="n">b</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">))</span>
|
||||
@@ -1012,8 +1031,8 @@ By MLX Contributors
|
||||
</div>
|
||||
|
||||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
|
||||
|
||||
<footer class="bd-footer">
|
||||
</footer>
|
||||
|
1059
docs/build/html/dev/mlx_in_cpp.html
vendored
Normal file
1059
docs/build/html/dev/mlx_in_cpp.html
vendored
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user