docs update

Awni Hannun
2024-04-11 17:33:33 -07:00
committed by CircleCI Docs
parent 1d2cadbc78
commit f77d99b285
413 changed files with 9992 additions and 2202 deletions


@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<title>Developer Documentation &#8212; MLX 0.9.0 documentation</title>
<title>Developer Documentation &#8212; MLX 0.10.0 documentation</title>
@@ -36,7 +36,7 @@
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=5b4479735964841361fd" />
<script src="../_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=5b4479735964841361fd"></script>
<script src="../_static/documentation_options.js?v=2a76c96f"></script>
<script src="../_static/documentation_options.js?v=cb265169"></script>
<script src="../_static/doctools.js?v=888ff710"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=efea14e4"></script>
@@ -131,8 +131,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.9.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.9.0 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.10.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.10.0 documentation - Home"/>`);</script>
</a></div>
@@ -286,6 +286,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erf.html">mlx.core.erf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erfinv.html">mlx.core.erfinv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exp.html">mlx.core.exp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expm1.html">mlx.core.expm1</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expand_dims.html">mlx.core.expand_dims</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eye.html">mlx.core.eye</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.flatten.html">mlx.core.flatten</a></li>
@@ -318,6 +319,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.max.html">mlx.core.max</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.maximum.html">mlx.core.maximum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.mean.html">mlx.core.mean</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.meshgrid.html">mlx.core.meshgrid</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.min.html">mlx.core.min</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.minimum.html">mlx.core.minimum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.moveaxis.html">mlx.core.moveaxis</a></li>
@@ -352,6 +354,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.square.html">mlx.core.square</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.squeeze.html">mlx.core.squeeze</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stack.html">mlx.core.stack</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.std.html">mlx.core.std</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stop_gradient.html">mlx.core.stop_gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.subtract.html">mlx.core.subtract</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sum.html">mlx.core.sum</a></li>
@@ -379,6 +382,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.gumbel.html">mlx.core.random.gumbel</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.key.html">mlx.core.random.key</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.normal.html">mlx.core.random.normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.multivariate_normal.html">mlx.core.random.multivariate_normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.randint.html">mlx.core.random.randint</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.seed.html">mlx.core.random.seed</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.split.html">mlx.core.random.split</a></li>
@@ -432,6 +436,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.get_cache_memory.html">mlx.core.metal.get_cache_memory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.set_memory_limit.html">mlx.core.metal.set_memory_limit</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.set_cache_limit.html">mlx.core.metal.set_cache_limit</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.start_capture.html">mlx.core.metal.start_capture</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.stop_capture.html">mlx.core.metal.stop_capture</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
@@ -763,12 +769,12 @@ document.write(`
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#operations-and-primitives">Operations and Primitives</a><ul class="visible nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#operations">Operations</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitives">Primitives</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#using-the-primitives">Using the Primitives</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#using-the-primitive">Using the Primitive</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-primitive">Implementing the Primitive</a><ul class="visible nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-cpu-backend">Implementing the CPU Backend</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-gpu-backend">Implementing the GPU Backend</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-cpu-back-end">Implementing the CPU Back-end</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-gpu-back-end">Implementing the GPU Back-end</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitive-transforms">Primitive Transforms</a></li>
</ul>
</li>
@@ -796,61 +802,45 @@ document.write(`
<section id="developer-documentation">
<h1>Developer Documentation<a class="headerlink" href="#developer-documentation" title="Link to this heading">#</a></h1>
<p>MLX provides an open and flexible backend to which users may add operations
and specialized implementations without much hassle. While the library supplies
efficient operations that can be used and composed for any number of
applications, there may arise cases where new functionalities or highly
optimized implementations are needed. For such cases, you may design and
implement your own operations that link to and build on top of <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code>.
We will introduce the inner-workings of MLX and go over a simple example to
learn the steps involved in adding new operations to MLX with your own CPU
and GPU implementations.</p>
<p>You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.</p>
<section id="introducing-the-example">
<h2>Introducing the Example<a class="headerlink" href="#introducing-the-example" title="Link to this heading">#</a></h2>
<p>Let's say that you would like an operation that takes in two arrays,
<code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, scales them both by some coefficients <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>
respectively, and then adds them together to get the result
<code class="docutils literal notranslate"><span class="pre">z</span> <span class="pre">=</span> <span class="pre">alpha</span> <span class="pre">*</span> <span class="pre">x</span> <span class="pre">+</span> <span class="pre">beta</span> <span class="pre">*</span> <span class="pre">y</span></code>. Well, you can very easily do that by just
writing out a function as follows:</p>
<p>Let's say you would like an operation that takes in two arrays, <code class="docutils literal notranslate"><span class="pre">x</span></code> and
<code class="docutils literal notranslate"><span class="pre">y</span></code>, scales them both by coefficients <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code> respectively,
and then adds them together to get the result <code class="docutils literal notranslate"><span class="pre">z</span> <span class="pre">=</span> <span class="pre">alpha</span> <span class="pre">*</span> <span class="pre">x</span> <span class="pre">+</span> <span class="pre">beta</span> <span class="pre">*</span> <span class="pre">y</span></code>.
You can do that in MLX directly:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="k">def</span> <span class="nf">simple_axpby</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
<span class="k">return</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">y</span>
</pre></div>
</div>
<p>This function performs that operation while leaving the implementations and
differentiation to MLX.</p>
<p>However, you work with vector math libraries often and realize that the
<code class="docutils literal notranslate"><span class="pre">axpby</span></code> routine defines the same operation <code class="docutils literal notranslate"><span class="pre">Y</span> <span class="pre">=</span> <span class="pre">(alpha</span> <span class="pre">*</span> <span class="pre">X)</span> <span class="pre">+</span> <span class="pre">(beta</span> <span class="pre">*</span> <span class="pre">Y)</span></code>.
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
<code class="docutils literal notranslate"><span class="pre">axpby</span></code> routine provided by the <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a> framework. Continuing to impose
our assumptions onto you, let's also assume that you want to learn how to add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.</p>
<p>Well, what a coincidence! You are in the right place. Over the course of this
example, we will learn:</p>
<p>This function performs that operation while leaving the implementation and
function transformations to MLX.</p>
<p>However, you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:</p>
<ul class="simple">
<li><p>The structure of the MLX library from the frontend API to the backend implementations.</p></li>
<li><p>How to implement your own CPU backend that redirects to <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a> when appropriate (and a fallback if needed).</p></li>
<li><p>How to implement your own GPU implementation using Metal.</p></li>
<li><p>How to add your own <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code>.</p></li>
<li><p>How to build your implementations, link them to MLX, and bind them to Python.</p></li>
<li><p>The structure of the MLX library.</p></li>
<li><p>Implementing a CPU operation that redirects to <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a> when appropriate.</p></li>
<li><p>Implementing a GPU operation using Metal.</p></li>
<li><p>Adding the <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code> function transformation.</p></li>
<li><p>Building a custom extension and binding it to Python.</p></li>
</ul>
</section>
<section id="operations-and-primitives">
<h2>Operations and Primitives<a class="headerlink" href="#operations-and-primitives" title="Link to this heading">#</a></h2>
<p>In one sentence, operations in MLX build the computation graph, and primitives
provide the rules for evaluation and transformations of said graph. Let's start
by discussing operations in more detail.</p>
<p>Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.</p>
<section id="operations">
<h3>Operations<a class="headerlink" href="#operations" title="Link to this heading">#</a></h3>
<p>Operations are the frontend functions that operate on arrays. They are defined
in the C++ API (<a class="reference internal" href="../cpp/ops.html#cpp-ops"><span class="std std-ref">Operations</span></a>) and then we provide bindings to these
operations in the Python API (<a class="reference internal" href="../python/ops.html#ops"><span class="std std-ref">Operations</span></a>).</p>
<p>We would like an operation, <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code>, that takes in two arrays <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>,
and two scalars, <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>. This is how we would define it in the
C++ API:</p>
<p>Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (<a class="reference internal" href="../cpp/ops.html#cpp-ops"><span class="std std-ref">Operations</span></a>), and the Python API (<a class="reference internal" href="../python/ops.html#ops"><span class="std std-ref">Operations</span></a>) binds them.</p>
<p>We would like an operation, <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code>, that takes in two arrays <code class="docutils literal notranslate"><span class="pre">x</span></code> and
<code class="docutils literal notranslate"><span class="pre">y</span></code>, and two scalars, <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>. This is how to define it in
C++:</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/**</span>
<span class="cm">* Scale and sum two vectors element-wise</span>
<span class="cm">* z = alpha * x + beta * y</span>
@@ -867,9 +857,7 @@ C++ API:</p>
<span class="p">);</span>
</pre></div>
</div>
<p>This operation itself can call other operations within it if needed. So, the
simplest way to go about implementing this operation would be to do so in terms
of existing operations.</p>
<p>The simplest way to implement this operation is in terms of existing operations:</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
@@ -886,19 +874,17 @@ of existing operations.</p>
<span class="p">}</span>
</pre></div>
</div>
<p>However, as we discussed earlier, this is not our goal. The operations themselves
do not contain the implementations that act on the data, nor do they contain the
rules of transformations. Rather, they are an easy-to-use interface that builds
on top of the building blocks we call <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.</p>
<p>The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy-to-use interface that uses <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> building blocks.</p>
</section>
<section id="primitives">
<h3>Primitives<a class="headerlink" href="#primitives" title="Link to this heading">#</a></h3>
<p>A <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> is part of the computation graph of an <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code>. It
defines how to create an output given a set of input <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code>. Further,
a <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> is a class that contains rules on how it is evaluated
on the CPU or GPU, and how it acts under transformations such as <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and
<code class="docutils literal notranslate"><span class="pre">jvp</span></code>. These words on their own can be a bit abstract, so lets take a step
back and go to our example to give ourselves a more concrete image.</p>
defines how to create output arrays given input arrays. Further, a
<code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> has methods to run on the CPU or GPU and for function
transformations such as <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code>. Let's go back to our example to be
more concrete:</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Axpby</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="k">public</span><span class="w"> </span><span class="n">Primitive</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">public</span><span class="o">:</span>
<span class="w"> </span><span class="k">explicit</span><span class="w"> </span><span class="n">Axpby</span><span class="p">(</span><span class="n">Stream</span><span class="w"> </span><span class="n">stream</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">)</span>
@@ -911,11 +897,15 @@ back and go to our example to give ourselves a more concrete image.</p>
<span class="cm"> * To avoid unnecessary allocations, the evaluation function</span>
<span class="cm"> * is responsible for allocating space for the array.</span>
<span class="cm"> */</span>
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_cpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_gpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_cpu</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_gpu</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
<span class="w"> </span><span class="cm">/** The Jacobian-vector product. */</span>
<span class="w"> </span><span class="n">array</span><span class="w"> </span><span class="nf">jvp</span><span class="p">(</span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;</span><span class="w"> </span><span class="n">jvp</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
@@ -924,7 +914,8 @@ back and go to our example to give ourselves a more concrete image.</p>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;</span><span class="w"> </span><span class="n">vjp</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">argnums</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
<span class="w"> </span><span class="cm">/**</span>
<span class="cm"> * The primitive must know how to vectorize itself across</span>
@@ -932,7 +923,7 @@ back and go to our example to give ourselves a more concrete image.</p>
<span class="cm"> * representing the vectorized computation and the axis which</span>
<span class="cm"> * corresponds to the output vectorized dimension.</span>
<span class="cm"> */</span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="n">array</span><span class="p">,</span><span class="w"> </span><span class="kt">int</span><span class="o">&gt;</span><span class="w"> </span><span class="n">vmap</span><span class="p">(</span>
<span class="w"> </span><span class="k">virtual</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;</span><span class="p">,</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&gt;</span><span class="w"> </span><span class="n">vmap</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
@@ -953,20 +944,20 @@ back and go to our example to give ourselves a more concrete image.</p>
<span class="p">};</span>
</pre></div>
</div>
<p>The <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> class derives from the base <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> class and
follows the above demonstrated interface. <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> treats <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and
<code class="docutils literal notranslate"><span class="pre">beta</span></code> as parameters. It then provides implementations of how the array <code class="docutils literal notranslate"><span class="pre">out</span></code>
is produced given <code class="docutils literal notranslate"><span class="pre">inputs</span></code> through <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code> and
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code>. Further, it provides rules of transformations in
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::jvp()</span></code>, <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vjp()</span></code>, and <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vmap()</span></code>.</p>
<p>The <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> class derives from the base <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> class.
<code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> treats <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code> as parameters. It then provides
implementations of how the output array is produced given the inputs through
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code> and <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code>. It also provides rules
of transformations in <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::jvp()</span></code>, <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vjp()</span></code>, and
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vmap()</span></code>.</p>
</section>
<section id="using-the-primitives">
<h3>Using the Primitives<a class="headerlink" href="#using-the-primitives" title="Link to this heading">#</a></h3>
<p>Operations can use this <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> to add a new <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> to
the computation graph. An <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> can be constructed by providing its
data type, shape, the <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> that computes it, and the
<code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> inputs that are passed to the primitive.</p>
<p>Let's re-implement our operation now in terms of our <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> primitive.</p>
<section id="using-the-primitive">
<h3>Using the Primitive<a class="headerlink" href="#using-the-primitive" title="Link to this heading">#</a></h3>
<p>Operations can use this <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> to add a new <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> to the
computation graph. An <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> can be constructed by providing its data
type, shape, the <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> that computes it, and the <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code>
inputs that are passed to the primitive.</p>
<p>Let's reimplement our operation now in terms of our <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> primitive.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
@@ -1012,25 +1003,24 @@ data type, shape, the <code class="xref py py-class docutils literal notranslate
</section>
<section id="implementing-the-primitive">
<h2>Implementing the Primitive<a class="headerlink" href="#implementing-the-primitive" title="Link to this heading">#</a></h2>
<p>No computation happens when we call the operation alone. In effect, the
operation only builds the computation graph. When we evaluate the output
array, MLX schedules the execution of the computation graph, and calls
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code> or <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> depending on the
stream/device specified by the user.</p>
<p>No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code> or
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> depending on the stream/device specified by the user.</p>
<div class="admonition warning">
<p class="admonition-title">Warning</p>
<p>When <code class="xref py py-meth docutils literal notranslate"><span class="pre">Primitive::eval_cpu()</span></code> or <code class="xref py py-meth docutils literal notranslate"><span class="pre">Primitive::eval_gpu()</span></code> are called,
no memory has been allocated for the output array. It falls on the implementation
of these functions to allocate memory as needed</p>
of these functions to allocate memory as needed.</p>
</div>
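<p>As a point of reference, that allocation typically looks like the following
sketch (not part of the example source), which wraps the same allocator call
used later in this tutorial:</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span>#include "mlx/allocator.h"
#include "mlx/mlx.h"

using namespace mlx::core;

// Sketch: give the output array a dense buffer sized for all of its
// elements before the eval function writes any results into it.
void allocate_output(array&amp; out) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
</pre></div>
</div>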
<section id="implementing-the-cpu-backend">
<h3>Implementing the CPU Backend<a class="headerlink" href="#implementing-the-cpu-backend" title="Link to this heading">#</a></h3>
<p>Let's start by trying to implement a naive and generic version of
<section id="implementing-the-cpu-back-end">
<h3>Implementing the CPU Back-end<a class="headerlink" href="#implementing-the-cpu-back-end" title="Link to this heading">#</a></h3>
<p>Let's start by implementing a naive and generic version of
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>. We declared this as a private member function of
<code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> earlier called <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval()</span></code>.</p>
<p>Our naive method will go over each element of the output array, find the
corresponding input elements of <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> and perform the operation
pointwise. This is captured in the templated function <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby_impl()</span></code>.</p>
point-wise. This is captured in the templated function <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby_impl()</span></code>.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o">&lt;</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">&gt;</span>
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
@@ -1066,16 +1056,16 @@ pointwise. This is captured in the templated function <code class="xref py py-me
<span class="p">}</span>
</pre></div>
</div>
<p>Now, we would like our implementation to be able to do this pointwise operation
for all incoming floating point arrays. Accordingly, we add dispatches for
<code class="docutils literal notranslate"><span class="pre">float32</span></code>, <code class="docutils literal notranslate"><span class="pre">float16</span></code>, <code class="docutils literal notranslate"><span class="pre">bfloat16</span></code> and <code class="docutils literal notranslate"><span class="pre">complex64</span></code>. We throw an error
if we encounter an unexpected type.</p>
<p>Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for <code class="docutils literal notranslate"><span class="pre">float32</span></code>, <code class="docutils literal notranslate"><span class="pre">float16</span></code>, <code class="docutils literal notranslate"><span class="pre">bfloat16</span></code> and
<code class="docutils literal notranslate"><span class="pre">complex64</span></code>. We throw an error if we encounter an unexpected type.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Fall back implementation for evaluation on CPU */</span>
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Check the inputs (registered in the op while constructing the out array)</span>
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
<span class="w"> </span><span class="c1">// Dispatch to the correct dtype</span>
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float32</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
@@ -1088,29 +1078,27 @@ if we encounter an unexpected type.</p>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o">&lt;</span><span class="n">complex64_t</span><span class="o">&gt;</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span>
<span class="w"> </span><span class="s">&quot;Axpby is only supported for floating point types.&quot;</span><span class="p">);</span>
<span class="w"> </span><span class="s">&quot;[Axpby] Only supports floating point types.&quot;</span><span class="p">);</span>
<span class="w"> </span><span class="p">}</span>
<span class="p">}</span>
</pre></div>
</div>
<p>We have a fallback implementation! Now, to do what we are really here to do.
Remember we wanted to use the <code class="docutils literal notranslate"><span class="pre">axpby</span></code> routine provided by the <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a>
framework? Well, there are 3 complications to keep in mind:</p>
<p>This is good as a fallback implementation. We can use the <code class="docutils literal notranslate"><span class="pre">axpby</span></code> routine
provided by the <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a> framework for a faster implementation in certain
cases:</p>
<ol class="arabic simple">
<li><p>Accelerate does not provide implementations of <code class="docutils literal notranslate"><span class="pre">axpby</span></code> for half precision
floats. We can only direct to it for <code class="docutils literal notranslate"><span class="pre">float32</span></code> types</p></li>
<li><p>Accelerate assumes the inputs <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> are contiguous and all elements
have fixed strides between them. Possibly due to broadcasts and transposes,
we aren't guaranteed that the inputs fit this requirement. We can
only direct to Accelerate if both <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> are row contiguous or
column contiguous.</p></li>
<li><p>Accelerate performs the routine <code class="docutils literal notranslate"><span class="pre">Y</span> <span class="pre">=</span> <span class="pre">(alpha</span> <span class="pre">*</span> <span class="pre">X)</span> <span class="pre">+</span> <span class="pre">(beta</span> <span class="pre">*</span> <span class="pre">Y)</span></code> inplace.
MLX expects to write out the answer to a new array. We must copy the elements
of <code class="docutils literal notranslate"><span class="pre">y</span></code> into the output array and use that as an input to <code class="docutils literal notranslate"><span class="pre">axpby</span></code></p></li>
floats. We can only use it for <code class="docutils literal notranslate"><span class="pre">float32</span></code> types.</p></li>
<li><p>Accelerate assumes the inputs <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> are row contiguous or column contiguous (see the sketch after this list).</p></li>
<li><p>Accelerate performs the routine <code class="docutils literal notranslate"><span class="pre">Y</span> <span class="pre">=</span> <span class="pre">(alpha</span> <span class="pre">*</span> <span class="pre">X)</span> <span class="pre">+</span> <span class="pre">(beta</span> <span class="pre">*</span> <span class="pre">Y)</span></code> in-place.
MLX expects to write the output to a new array. We must copy the elements
of <code class="docutils literal notranslate"><span class="pre">y</span></code> into the output and use that as an input to <code class="docutils literal notranslate"><span class="pre">axpby</span></code>.</p></li>
</ol>
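<p>Putting points 1 and 2 together, the eligibility check for the Accelerate path
looks roughly like the following sketch (the exact condition in the full example
source may differ; the contiguity flags assumed here are read from
<code class="docutils literal notranslate"><span class="pre">array::flags()</span></code>):</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span>#include "mlx/mlx.h"

using namespace mlx::core;

// Sketch: only take the Accelerate path for float32 data where x and y
// are contiguous in the same mode (both row or both column contiguous).
bool can_use_accelerate(const array&amp; x, const array&amp; y, const array&amp; out) {
  bool both_row = x.flags().row_contiguous &amp;&amp; y.flags().row_contiguous;
  bool both_col = x.flags().col_contiguous &amp;&amp; y.flags().col_contiguous;
  return out.dtype() == float32 &amp;&amp; (both_row || both_col);
}
</pre></div>
</div>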
<p>Let's write out an implementation that uses Accelerate in the right conditions.
It must simply allocate data for the output, copy elements of <code class="docutils literal notranslate"><span class="pre">y</span></code> into it,
and then call <code class="xref py py-meth docutils literal notranslate"><span class="pre">catlas_saxpby()</span></code> from Accelerate.</p>
<p>Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies <code class="docutils literal notranslate"><span class="pre">y</span></code> into it, and then calls
<code class="xref py py-func docutils literal notranslate"><span class="pre">catlas_saxpby()</span></code> from Accelerate.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o">&lt;</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">&gt;</span>
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl_accelerate</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
@@ -1121,17 +1109,7 @@ and then call the <code class="xref py py-meth docutils literal notranslate"><sp
<span class="w"> </span><span class="c1">// Accelerate library provides catlas_saxpby which does</span>
<span class="w"> </span><span class="c1">// Y = (alpha * X) + (beta * Y) in place</span>
<span class="w"> </span><span class="c1">// To use it, we first copy the data in y over to the output array</span>
<span class="w"> </span><span class="c1">// This specialization requires both x and y be contiguous in the same mode</span>
<span class="w"> </span><span class="c1">// i.e: corresponding linear indices in both point to corresponding elements</span>
<span class="w"> </span><span class="c1">// The data in the output array is allocated to match the strides in y</span>
<span class="w"> </span><span class="c1">// such that x, y, and out are contiguous in the same mode and</span>
<span class="w"> </span><span class="c1">// no transposition is needed</span>
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span>
<span class="w"> </span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">data_size</span><span class="p">()</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">itemsize</span><span class="p">()),</span>
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data_size</span><span class="p">(),</span>
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span>
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">flags</span><span class="p">());</span>
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
<span class="w"> </span><span class="c1">// We then copy over the elements using the contiguous vector specialization</span>
<span class="w"> </span><span class="n">copy_inplace</span><span class="p">(</span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">CopyType</span><span class="o">::</span><span class="n">Vector</span><span class="p">);</span>
@@ -1155,14 +1133,17 @@ and then call the <code class="xref py py-meth docutils literal notranslate"><sp
<span class="p">}</span>
</pre></div>
</div>
<p>Great! But what about the inputs that do not fit the criteria for Accelerate?
Luckily, we can always just direct back to <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval()</span></code>.</p>
<p>With this in mind, let's finally implement our <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>.</p>
<p>For inputs that do not fit the criteria for Accelerate, we fall back to
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval()</span></code>. With this in mind, let's finish our
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Evaluate primitive on CPU using accelerate specializations */</span>
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_cpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_cpu</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
<span class="w"> </span><span class="c1">// Accelerate specialization for contiguous single precision float arrays</span>
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float32</span><span class="w"> </span><span class="o">&amp;&amp;</span>
@@ -1172,33 +1153,32 @@ Luckily, we can always just direct back to <code class="xref py py-meth docutils
<span class="w"> </span><span class="k">return</span><span class="p">;</span>
<span class="w"> </span><span class="p">}</span>
<span class="w"> </span><span class="c1">// Fall back to common backend if specializations are not available</span>
<span class="w"> </span><span class="n">eval</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Fall back to common back-end if specializations are not available</span>
<span class="w"> </span><span class="n">eval</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">outputs</span><span class="p">);</span>
<span class="p">}</span>
</pre></div>
</div>
<p>We have now hit a milestone! Just this much is enough to run the operation
<code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> on a CPU stream!</p>
<p>If you do not plan on running the operation on the GPU or using transforms on
<p>Just this much is enough to run the operation <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code>, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.</p>
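<p>As a quick sanity check, here is a minimal Python sketch of what that looks
like. It assumes the Python binding built later in this document; the shapes
and scalars are arbitrary:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import mlx.core as mx
from mlx_sample_extensions import axpby

x = mx.ones((3, 4))
y = mx.ones((3, 4))

# Force the primitive onto a CPU stream; only the CPU back-end above is needed
z = axpby(x, y, 4.0, 2.0, stream=mx.cpu)
print(z)  # every element is 4.0 * 1.0 + 2.0 * 1.0 = 6.0
</pre></div>
</div>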
</section>
<section id="implementing-the-gpu-backend">
<h3>Implementing the GPU Backend<a class="headerlink" href="#implementing-the-gpu-backend" title="Link to this heading">#</a></h3>
<section id="implementing-the-gpu-back-end">
<h3>Implementing the GPU Back-end<a class="headerlink" href="#implementing-the-gpu-back-end" title="Link to this heading">#</a></h3>
<p>Apple silicon devices address their GPUs using the <a class="reference external" href="https://developer.apple.com/documentation/metal?language=objc">Metal</a> shading language, and
all GPU kernels in MLX are written using metal.</p>
GPU kernels in MLX are written using Metal.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Here are some helpful resources if you are new to metal!</p>
<p>Here are some helpful resources if you are new to Metal:</p>
<ul class="simple">
<li><p>A walkthrough of the Metal compute pipeline: <a class="reference external" href="https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc">Metal Example</a></p></li>
<li><p>Documentation for the Metal Shading Language: <a class="reference external" href="https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf">Metal Specification</a></p></li>
<li><p>Using Metal from C++: <a class="reference external" href="https://developer.apple.com/metal/cpp/">Metal-cpp</a></p></li>
</ul>
</div>
<p>Lets keep the GPU algorithm simple. We will launch exactly as many threads
as there are elements in the output. Each thread will pick the element it needs
from <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, do the pointwise operation, and then update its assigned
<p>Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, do the point-wise operation, and update its assigned
element in the output.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o">&lt;</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">&gt;</span>
<span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">(</span>
@@ -1223,8 +1203,7 @@ element in the output.</p>
</pre></div>
</div>
<p>We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify the right kernel for
each data type.</p>
each instantiation a unique host name so we can identify it.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cp">#define instantiate_axpby(type_name, type) \</span>
<span class="cp"> template [[host_name(&quot;axpby_general_&quot; #type_name)]] \</span>
<span class="cp"> [[kernel]] void axpby_general&lt;type&gt;( \</span>
@@ -1245,25 +1224,18 @@ each data type.</p>
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">complex64</span><span class="p">,</span><span class="w"> </span><span class="n">complex64_t</span><span class="p">);</span>
</pre></div>
</div>
<p>This kernel will be compiled into a Metal library <code class="docutils literal notranslate"><span class="pre">mlx_ext.metallib</span></code> as we
will see later in <a class="reference internal" href="#building-with-cmake"><span class="std std-ref">Building with CMake</span></a>. In the following example, we
assume that the library <code class="docutils literal notranslate"><span class="pre">mlx_ext.metallib</span></code> will always be co-located with
the executable or shared library calling the <code class="xref py py-meth docutils literal notranslate"><span class="pre">register_library()</span></code> function.
The <code class="xref py py-meth docutils literal notranslate"><span class="pre">register_library()</span></code> function takes the library's name and a potential
path (or, in this case, a function that can produce the path of the Metal
library) and tries to load that library if it hasn't already been registered
by the relevant static <code class="xref py py-class docutils literal notranslate"><span class="pre">mlx::core::metal::Device</span></code> object. This is why
it is important to package your C++ library with the Metal library. We will
go over this process in more detail later.</p>
<p>The logic to determine the kernel, set the inputs, resolve the grid dimensions
and dispatch it to the GPU are contained in <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> as shown
<p>The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> as shown
below.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Evaluate primitive on GPU */</span>
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_gpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_gpu</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Prepare inputs</span>
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
<span class="w"> </span><span class="c1">// Each primitive carries the stream it should execute on</span>
<span class="w"> </span><span class="c1">// and each stream carries its device identifiers</span>
@@ -1274,7 +1246,7 @@ below.</p>
<span class="w"> </span><span class="c1">// Allocate output memory</span>
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
<span class="w"> </span><span class="c1">// Resolve name of kernel (corresponds to axpby.metal)</span>
<span class="w"> </span><span class="c1">// Resolve name of kernel</span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">ostringstream</span><span class="w"> </span><span class="n">kname</span><span class="p">;</span>
<span class="w"> </span><span class="n">kname</span><span class="w"> </span><span class="o">&lt;&lt;</span><span class="w"> </span><span class="s">&quot;axpby_&quot;</span><span class="w"> </span><span class="o">&lt;&lt;</span><span class="w"> </span><span class="s">&quot;general_&quot;</span><span class="w"> </span><span class="o">&lt;&lt;</span><span class="w"> </span><span class="n">type_to_name</span><span class="p">(</span><span class="n">out</span><span class="p">);</span>
@@ -1328,24 +1300,21 @@ below.</p>
</pre></div>
</div>
<p>We can now call the <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> operation on both the CPU and the GPU!</p>
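<p>Continuing the Python sketch from the CPU section (which, again, assumes the
binding built later in this document), the same call can be dispatched to
either device by passing a different stream or device:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import mlx.core as mx
from mlx_sample_extensions import axpby

x = mx.ones((3, 4))
y = mx.ones((3, 4))

z_cpu = axpby(x, y, 4.0, 2.0, stream=mx.cpu)  # runs Axpby::eval_cpu()
z_gpu = axpby(x, y, 4.0, 2.0, stream=mx.gpu)  # runs Axpby::eval_gpu()
mx.eval(z_cpu, z_gpu)
</pre></div>
</div>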
<p>A few things to note about MLX and metal before moving on. MLX keeps track
of the active <code class="docutils literal notranslate"><span class="pre">compute_encoder</span></code>. We rely on <code class="xref py py-meth docutils literal notranslate"><span class="pre">d.get_command_encoder()</span></code>
to give us the active metal compute command encoder instead of building a
new one and calling <code class="xref py py-meth docutils literal notranslate"><span class="pre">compute_encoder-&gt;end_encoding()</span></code> at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
<code class="xref py py-class docutils literal notranslate"><span class="pre">metal::Device</span></code> if you would like to study this routine further.</p>
<p>A few things to note about MLX and Metal before moving on. MLX keeps track of
the active <code class="docutils literal notranslate"><span class="pre">command_buffer</span></code> and the <code class="docutils literal notranslate"><span class="pre">MTLCommandBuffer</span></code> to which it is
associated. We rely on <code class="xref py py-meth docutils literal notranslate"><span class="pre">d.get_command_encoder()</span></code> to give us the active
metal compute command encoder instead of building a new one and calling
<code class="xref py py-meth docutils literal notranslate"><span class="pre">compute_encoder-&gt;end_encoding()</span></code> at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.</p>
</section>
<section id="primitive-transforms">
<h3>Primitive Transforms<a class="headerlink" href="#primitive-transforms" title="Link to this heading">#</a></h3>
<p>Now that we have come this far, lets also learn how to add implementations to
transformations in a <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>. These transformations can be built on
top of our operations, including the one we just defined now. Which then gives
us the following <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::jvp()</span></code> and <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vjp()</span></code> implementations.</p>
<p>Next, let's add implementations for transformations in a <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.
These transformations can be built on top of other operations, including the
one we just defined:</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** The Jacobian-vector product. */</span>
<span class="n">array</span><span class="w"> </span><span class="nf">Axpby::jvp</span><span class="p">(</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;</span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">jvp</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
@@ -1360,12 +1329,12 @@ us the following <code class="xref py py-meth docutils literal notranslate"><spa
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">argnums</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">&gt;</span><span class="w"> </span><span class="mi">1</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">argnums</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">alpha_</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dtype</span><span class="p">());</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="p">{</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">())};</span>
<span class="w"> </span><span class="p">}</span>
<span class="w"> </span><span class="c1">// If, argnums = {0, 1}, we take contributions from both</span>
<span class="w"> </span><span class="c1">// which gives us jvp = tangent_x * alpha + tangent_y * beta</span>
<span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby</span><span class="p">(</span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="p">{</span><span class="n">axpby</span><span class="p">(</span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">())};</span>
<span class="w"> </span><span class="p">}</span>
<span class="p">}</span>
</pre></div>
@@ -1373,26 +1342,27 @@ us the following <code class="xref py py-meth docutils literal notranslate"><spa
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** The vector-Jacobian product. */</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;</span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vjp</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">cotangents</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">argnums</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="cm">/* unused */</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Reverse mode diff</span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;</span><span class="w"> </span><span class="n">vjps</span><span class="p">;</span>
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="k">auto</span><span class="w"> </span><span class="n">arg</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">arg</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">alpha_</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">cotan</span><span class="p">.</span><span class="n">dtype</span><span class="p">());</span>
<span class="w"> </span><span class="n">vjps</span><span class="p">.</span><span class="n">push_back</span><span class="p">(</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">()));</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">cotangents</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dtype</span><span class="p">());</span>
<span class="w"> </span><span class="n">vjps</span><span class="p">.</span><span class="n">push_back</span><span class="p">(</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">cotangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">()));</span>
<span class="w"> </span><span class="p">}</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">vjps</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
<p>Finally, you need not have a transformation fully defined to start using your
own <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.</p>
<p>Note, a transformation does not need to be fully defined to start using
the <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Vectorize primitive along given axis */</span>
<span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="n">array</span><span class="p">,</span><span class="w"> </span><span class="kt">int</span><span class="o">&gt;</span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vmap</span><span class="p">(</span>
<span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;</span><span class="p">,</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&gt;</span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vmap</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span><span class="s">&quot;Axpby has no vmap implementation.&quot;</span><span class="p">);</span>
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span><span class="s">&quot;[Axpby] vmap not implemented.&quot;</span><span class="p">);</span>
<span class="p">}</span>
</pre></div>
</div>
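<p>To make this concrete, here is a short Python sketch (once more assuming the
binding built later in this document): <code class="docutils literal notranslate"><span class="pre">grad()</span></code> only relies on the
<code class="docutils literal notranslate"><span class="pre">vjp</span></code> defined above, so it works even though <code class="docutils literal notranslate"><span class="pre">vmap</span></code> raises:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import mlx.core as mx
from mlx_sample_extensions import axpby

x = mx.ones((3,))
y = mx.ones((3,))

# d/dx of sum(4 * x + 2 * y) is 4 everywhere; this only exercises Axpby::vjp()
grad_fn = mx.grad(lambda x: axpby(x, y, 4.0, 2.0).sum())
print(grad_fn(x))  # every element is 4.0

# mx.vmap(lambda x: axpby(x, y, 4.0, 2.0))(x) would raise the
# runtime_error thrown by Axpby::vmap() above
</pre></div>
</div>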
@@ -1416,64 +1386,63 @@ own <code class="xref py py-class docutils literal notranslate"><span class="pre
<ul class="simple">
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/axpby/</span></code> defines the C++ extension library</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> sets out the structure for the
associated python package</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/bindings.cpp</span></code> provides python bindings for our operation</p></li>
associated Python package</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/bindings.cpp</span></code> provides Python bindings for our operation</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/CMakeLists.txt</span></code> holds CMake rules to build the library and
python bindings</p></li>
Python bindings</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/setup.py</span></code> holds the <code class="docutils literal notranslate"><span class="pre">setuptools</span></code> rules to build and install
the python package</p></li>
the Python package</p></li>
</ul>
<section id="binding-to-python">
<h3>Binding to Python<a class="headerlink" href="#binding-to-python" title="Link to this heading">#</a></h3>
<p>We use <a class="reference external" href="https://pybind11.readthedocs.io/en/stable/">PyBind11</a> to build a Python API for the C++ library. Since bindings for
<p>We use <a class="reference external" href="https://nanobind.readthedocs.io/en/latest/">nanobind</a> to build a Python API for the C++ library. Since bindings for
components such as <a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a>, <a class="reference internal" href="../python/_autosummary/mlx.core.stream.html#mlx.core.stream" title="mlx.core.stream"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.stream</span></code></a>, etc. are
already provided, adding our <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> is simple!</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">PYBIND11_MODULE</span><span class="p">(</span><span class="n">mlx_sample_extensions</span><span class="p">,</span><span class="w"> </span><span class="n">m</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">doc</span><span class="p">()</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s">&quot;Sample C++ and metal extensions for MLX&quot;</span><span class="p">;</span>
already provided, adding our <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> is simple.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">NB_MODULE</span><span class="p">(</span><span class="n">_ext</span><span class="p">,</span><span class="w"> </span><span class="n">m</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">doc</span><span class="p">()</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s">&quot;Sample extension for MLX&quot;</span><span class="p">;</span>
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span>
<span class="w"> </span><span class="s">&quot;axpby&quot;</span><span class="p">,</span>
<span class="w"> </span><span class="o">&amp;</span><span class="n">axpby</span><span class="p">,</span>
<span class="w"> </span><span class="s">&quot;x&quot;</span><span class="n">_a</span><span class="p">,</span>
<span class="w"> </span><span class="s">&quot;y&quot;</span><span class="n">_a</span><span class="p">,</span>
<span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">pos_only</span><span class="p">(),</span>
<span class="w"> </span><span class="s">&quot;alpha&quot;</span><span class="n">_a</span><span class="p">,</span>
<span class="w"> </span><span class="s">&quot;beta&quot;</span><span class="n">_a</span><span class="p">,</span>
<span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">kw_only</span><span class="p">(),</span>
<span class="w"> </span><span class="s">&quot;stream&quot;</span><span class="n">_a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">none</span><span class="p">(),</span>
<span class="w"> </span><span class="sa">R</span><span class="s">&quot;</span><span class="dl">pbdoc(</span>
<span class="s"> Scale and sum two vectors element-wise</span>
<span class="s"> ``z = alpha * x + beta * y``</span>
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span>
<span class="w"> </span><span class="s">&quot;axpby&quot;</span><span class="p">,</span>
<span class="w"> </span><span class="o">&amp;</span><span class="n">axpby</span><span class="p">,</span>
<span class="w"> </span><span class="s">&quot;x&quot;</span><span class="n">_a</span><span class="p">,</span>
<span class="w"> </span><span class="s">&quot;y&quot;</span><span class="n">_a</span><span class="p">,</span>
<span class="w"> </span><span class="s">&quot;alpha&quot;</span><span class="n">_a</span><span class="p">,</span>
<span class="w"> </span><span class="s">&quot;beta&quot;</span><span class="n">_a</span><span class="p">,</span>
<span class="w"> </span><span class="n">nb</span><span class="o">::</span><span class="n">kw_only</span><span class="p">(),</span>
<span class="w"> </span><span class="s">&quot;stream&quot;</span><span class="n">_a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">nb</span><span class="o">::</span><span class="n">none</span><span class="p">(),</span>
<span class="w"> </span><span class="sa">R</span><span class="s">&quot;</span><span class="dl">(</span>
<span class="s"> Scale and sum two vectors element-wise</span>
<span class="s"> ``z = alpha * x + beta * y``</span>
<span class="s"> Follows numpy style broadcasting between ``x`` and ``y``</span>
<span class="s"> Inputs are upcasted to floats if needed</span>
<span class="s"> Follows numpy style broadcasting between ``x`` and ``y``</span>
<span class="s"> Inputs are upcasted to floats if needed</span>
<span class="s"> Args:</span>
<span class="s"> x (array): Input array.</span>
<span class="s"> y (array): Input array.</span>
<span class="s"> alpha (float): Scaling factor for ``x``.</span>
<span class="s"> beta (float): Scaling factor for ``y``.</span>
<span class="s"> Args:</span>
<span class="s"> x (array): Input array.</span>
<span class="s"> y (array): Input array.</span>
<span class="s"> alpha (float): Scaling factor for ``x``.</span>
<span class="s"> beta (float): Scaling factor for ``y``.</span>
<span class="s"> Returns:</span>
<span class="s"> array: ``alpha * x + beta * y``</span>
<span class="s"> </span><span class="dl">)pbdoc</span><span class="s">&quot;</span><span class="p">);</span>
<span class="p">}</span>
<span class="s"> Returns:</span>
<span class="s"> array: ``alpha * x + beta * y``</span>
<span class="s"> </span><span class="dl">)</span><span class="s">&quot;</span><span class="p">);</span>
<span class="w"> </span><span class="p">}</span>
</pre></div>
</div>
<p>Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.</p>
<div class="admonition warning">
<p class="admonition-title">Warning</p>
<p><code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> needs to be imported before importing
<code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx_sample_extensions</span></code> as defined by the pybind11 module above to
<p><code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> must be imported before importing
<code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx_sample_extensions</span></code> as defined by the nanobind module above to
ensure that the casters for <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> components like
<a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> are available.</p>
</div>
</section>
<section id="building-with-cmake">
<span id="id1"></span><h3>Building with CMake<a class="headerlink" href="#building-with-cmake" title="Link to this heading">#</a></h3>
<p>Building the C++ extension library itself is simple, it only requires that you
<code class="docutils literal notranslate"><span class="pre">find_package(MLX</span> <span class="pre">CONFIG)</span></code> and then link it to your library.</p>
<p>Building the C++ extension library only requires that you <code class="docutils literal notranslate"><span class="pre">find_package(MLX</span>
<span class="pre">CONFIG)</span></code> and then link it to your library.</p>
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="c"># Add library</span>
<span class="nb">add_library</span><span class="p">(</span><span class="s">mlx_ext</span><span class="p">)</span>
@@ -1493,11 +1462,11 @@ ensure that the casters for <code class="xref py py-mod docutils literal notrans
<span class="nb">target_link_libraries</span><span class="p">(</span><span class="s">mlx_ext</span><span class="w"> </span><span class="s">PUBLIC</span><span class="w"> </span><span class="s">mlx</span><span class="p">)</span>
</pre></div>
</div>
<p>We also need to build the attached metal library. For convenience, we provide a
<p>We also need to build the attached Metal library. For convenience, we provide a
<code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx_build_metallib()</span></code> function that builds a <code class="docutils literal notranslate"><span class="pre">.metallib</span></code> target given
sources, headers, destinations, etc. (defined in <code class="docutils literal notranslate"><span class="pre">cmake/extension.cmake</span></code> and
automatically imported with the MLX package).</p>
<p>Here is what that looks like in practice!</p>
<p>Here is what that looks like in practice:</p>
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="c"># Build metallib</span>
<span class="nb">if</span><span class="p">(</span><span class="s">MLX_BUILD_METAL</span><span class="p">)</span>
@@ -1517,15 +1486,17 @@ automatically imported with MLX package).</p>
<span class="nb">endif</span><span class="p">()</span>
</pre></div>
</div>
<p>Finally, we build the <a class="reference external" href="https://pybind11.readthedocs.io/en/stable/">Pybind11</a> bindings</p>
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="nb">pybind11_add_module</span><span class="p">(</span>
<span class="w"> </span><span class="s">mlx_sample_extensions</span>
<span class="w"> </span><span class="o">${</span><span class="nv">CMAKE_CURRENT_LIST_DIR</span><span class="o">}</span><span class="s">/bindings.cpp</span>
<p>Finally, we build the <a class="reference external" href="https://nanobind.readthedocs.io/en/latest/">nanobind</a> bindings:</p>
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="nb">nanobind_add_module</span><span class="p">(</span>
<span class="w"> </span><span class="s">_ext</span>
<span class="w"> </span><span class="s">NB_STATIC</span><span class="w"> </span><span class="s">STABLE_ABI</span><span class="w"> </span><span class="s">LTO</span><span class="w"> </span><span class="s">NOMINSIZE</span>
<span class="w"> </span><span class="s">NB_DOMAIN</span><span class="w"> </span><span class="s">mlx</span>
<span class="w"> </span><span class="o">${</span><span class="nv">CMAKE_CURRENT_LIST_DIR</span><span class="o">}</span><span class="s">/bindings.cpp</span>
<span class="p">)</span>
<span class="nb">target_link_libraries</span><span class="p">(</span><span class="s">mlx_sample_extensions</span><span class="w"> </span><span class="s">PRIVATE</span><span class="w"> </span><span class="s">mlx_ext</span><span class="p">)</span>
<span class="nb">target_link_libraries</span><span class="p">(</span><span class="s">_ext</span><span class="w"> </span><span class="s">PRIVATE</span><span class="w"> </span><span class="s">mlx_ext</span><span class="p">)</span>
<span class="nb">if</span><span class="p">(</span><span class="s">BUILD_SHARED_LIBS</span><span class="p">)</span>
<span class="w"> </span><span class="nb">target_link_options</span><span class="p">(</span><span class="s">mlx_sample_extensions</span><span class="w"> </span><span class="s">PRIVATE</span><span class="w"> </span><span class="s">-Wl,-rpath,@loader_path</span><span class="p">)</span>
<span class="w"> </span><span class="nb">target_link_options</span><span class="p">(</span><span class="s">_ext</span><span class="w"> </span><span class="s">PRIVATE</span><span class="w"> </span><span class="s">-Wl,-rpath,@loader_path</span><span class="p">)</span>
<span class="nb">endif</span><span class="p">()</span>
</pre></div>
</div>
@@ -1533,7 +1504,7 @@ automatically imported with MLX package).</p>
<section id="building-with-setuptools">
<h3>Building with <code class="docutils literal notranslate"><span class="pre">setuptools</span></code><a class="headerlink" href="#building-with-setuptools" title="Link to this heading">#</a></h3>
<p>Once we have set out the CMake build rules as described above, we can use the
build utilities defined in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.extension</span></code> for a simple build process.</p>
build utilities defined in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.extension</span></code>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">mlx</span> <span class="kn">import</span> <span class="n">extension</span>
<span class="kn">from</span> <span class="nn">setuptools</span> <span class="kn">import</span> <span class="n">setup</span>
@@ -1542,13 +1513,13 @@ build utilities defined in <code class="xref py py-mod docutils literal notransl
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;mlx_sample_extensions&quot;</span><span class="p">,</span>
<span class="n">version</span><span class="o">=</span><span class="s2">&quot;0.0.0&quot;</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Sample C++ and Metal extensions for MLX primitives.&quot;</span><span class="p">,</span>
<span class="n">ext_modules</span><span class="o">=</span><span class="p">[</span><span class="n">extension</span><span class="o">.</span><span class="n">CMakeExtension</span><span class="p">(</span><span class="s2">&quot;mlx_sample_extensions&quot;</span><span class="p">)],</span>
<span class="n">ext_modules</span><span class="o">=</span><span class="p">[</span><span class="n">extension</span><span class="o">.</span><span class="n">CMakeExtension</span><span class="p">(</span><span class="s2">&quot;mlx_sample_extensions._ext&quot;</span><span class="p">)],</span>
<span class="n">cmdclass</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;build_ext&quot;</span><span class="p">:</span> <span class="n">extension</span><span class="o">.</span><span class="n">CMakeBuild</span><span class="p">},</span>
<span class="n">packages</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;mlx_sample_extensions&quot;</span><span class="p">],</span>
<span class="n">package_dir</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;&quot;</span><span class="p">:</span> <span class="s2">&quot;mlx_sample_extensions&quot;</span><span class="p">},</span>
<span class="n">package_data</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;mlx_sample_extensions&quot;</span> <span class="p">:</span> <span class="p">[</span><span class="s2">&quot;*.so&quot;</span><span class="p">,</span> <span class="s2">&quot;*.dylib&quot;</span><span class="p">,</span> <span class="s2">&quot;*.metallib&quot;</span><span class="p">]},</span>
<span class="n">packages</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;mlx_sample_extensions&quot;</span><span class="p">],</span>
<span class="n">package_data</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;mlx_sample_extensions&quot;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&quot;*.so&quot;</span><span class="p">,</span> <span class="s2">&quot;*.dylib&quot;</span><span class="p">,</span> <span class="s2">&quot;*.metallib&quot;</span><span class="p">]},</span>
<span class="n">extras_require</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;dev&quot;</span><span class="p">:[]},</span>
<span class="n">zip_safe</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">python_requires</span><span class="o">=</span><span class="s2">&quot;&gt;=3.7&quot;</span><span class="p">,</span>
<span class="n">python_requires</span><span class="o">=</span><span class="s2">&quot;&gt;=3.8&quot;</span><span class="p">,</span>
<span class="p">)</span>
</pre></div>
</div>
@@ -1557,34 +1528,36 @@ build utilities defined in <code class="xref py py-mod docutils literal notransl
<p>We treat <code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> as the package directory
even though it only contains an <code class="docutils literal notranslate"><span class="pre">__init__.py</span></code> to ensure the following:</p>
<ul class="simple">
<li><p><code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> is always imported before importing <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx_sample_extensions</span></code></p></li>
<li><p><code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> must be imported before importing <code class="xref py py-mod docutils literal notranslate"><span class="pre">_ext</span></code></p></li>
<li><p>The C++ extension library and the Metal library are co-located with the Python
bindings and copied together if the package is installed</p></li>
</ul>
</div>
<p>You can build inplace for development using
<p>To build the package, first install the build dependencies with <code class="docutils literal notranslate"><span class="pre">pip</span> <span class="pre">install</span>
<span class="pre">-r</span> <span class="pre">requirements.txt</span></code>. You can then build inplace for development using
<code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">setup.py</span> <span class="pre">build_ext</span> <span class="pre">-j8</span> <span class="pre">--inplace</span></code> (in <code class="docutils literal notranslate"><span class="pre">extensions/</span></code>)</p>
<p>This will result in a directory structure as follows:</p>
<p>This results in the directory structure:</p>
<div class="line-block">
<div class="line">extensions</div>
<div class="line">├── mlx_sample_extensions</div>
<div class="line">│ ├── __init__.py</div>
<div class="line">│ ├── libmlx_ext.dylib # C++ extension library</div>
<div class="line">│ ├── mlx_ext.metallib # Metal library</div>
<div class="line">│ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding</div>
<div class="line">│ └── _ext.cpython-3x-darwin.so # Python Binding</div>
<div class="line"></div>
</div>
<p>When you try to install using the command <code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">-m</span> <span class="pre">pip</span> <span class="pre">install</span> <span class="pre">.</span></code>
(in <code class="docutils literal notranslate"><span class="pre">extensions/</span></code>), the package will be installed with the same structure as
<code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> and the C++ and metal library will be
copied along with the python binding since they are specified as <code class="docutils literal notranslate"><span class="pre">package_data</span></code>.</p>
<p>When you try to install using the command <code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">-m</span> <span class="pre">pip</span> <span class="pre">install</span> <span class="pre">.</span></code> (in
<code class="docutils literal notranslate"><span class="pre">extensions/</span></code>), the package will be installed with the same structure as
<code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> and the C++ and Metal library will be
copied along with the Python binding since they are specified as
<code class="docutils literal notranslate"><span class="pre">package_data</span></code>.</p>
</section>
</section>
<section id="usage">
<h2>Usage<a class="headerlink" href="#usage" title="Link to this heading">#</a></h2>
<p>After installing the extension as described above, you should be able to simply
import the python package and play with it as you would any other MLX operation!</p>
<p>Lets looks at a simple script and its results!</p>
import the Python package and play with it as you would any other MLX operation.</p>
<p>Let's look at a simple script and its results:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">from</span> <span class="nn">mlx_sample_extensions</span> <span class="kn">import</span> <span class="n">axpby</span>
@@ -1606,7 +1579,7 @@ import the python package and play with it as you would any other MLX operation!
<section id="results">
<h3>Results<a class="headerlink" href="#results" title="Link to this heading">#</a></h3>
<p>Let's run a quick benchmark and see how our new <code class="docutils literal notranslate"><span class="pre">axpby</span></code> operation compares
with the naive <code class="xref py py-meth docutils literal notranslate"><span class="pre">simple_axpby()</span></code> we defined at first on the CPU.</p>
with the naive <code class="xref py py-meth docutils literal notranslate"><span class="pre">simple_axpby()</span></code> we first defined on the CPU.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">from</span> <span class="nn">mlx_sample_extensions</span> <span class="kn">import</span> <span class="n">axpby</span>
<span class="kn">import</span> <span class="nn">time</span>
@@ -1624,7 +1597,7 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
<span class="n">alpha</span> <span class="o">=</span> <span class="mf">4.0</span>
<span class="n">beta</span> <span class="o">=</span> <span class="mf">2.0</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">bench</span><span class="p">(</span><span class="n">f</span><span class="p">):</span>
<span class="c1"># Warm up</span>
@@ -1646,21 +1619,18 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Simple axpby: </span><span class="si">{</span><span class="n">simple_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> s | Custom axpby: </span><span class="si">{</span><span class="n">custom_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> s&quot;</span><span class="p">)</span>
</pre></div>
</div>
<p>Results:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">Simple</span> <span class="n">axpby</span><span class="p">:</span> <span class="mf">0.114</span> <span class="n">s</span> <span class="o">|</span> <span class="n">Custom</span> <span class="n">axpby</span><span class="p">:</span> <span class="mf">0.109</span> <span class="n">s</span>
</pre></div>
</div>
<p>We see some modest improvements right away!</p>
<p>The results are <code class="docutils literal notranslate"><span class="pre">Simple</span> <span class="pre">axpby:</span> <span class="pre">0.114</span> <span class="pre">s</span> <span class="pre">|</span> <span class="pre">Custom</span> <span class="pre">axpby:</span> <span class="pre">0.109</span> <span class="pre">s</span></code>. We see
modest improvements right away!</p>
<p>This operation can now be used to build other operations, in
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and as part of graph transformations like
<code class="xref py py-meth docutils literal notranslate"><span class="pre">grad()</span></code>!</p>
<code class="xref py py-meth docutils literal notranslate"><span class="pre">grad()</span></code>.</p>
</section>
</section>
<section id="scripts">
<h2>Scripts<a class="headerlink" href="#scripts" title="Link to this heading">#</a></h2>
<div class="admonition-download-the-code admonition">
<p class="admonition-title">Download the code</p>
<p>The full example code is available in <a class="reference external" href="code">mlx</a>.</p>
<p>The full example code is available in <a class="reference external" href="https://github.com/ml-explore/mlx/tree/main/examples/extensions/">mlx</a>.</p>
</div>
</section>
</section>
@@ -1713,12 +1683,12 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#operations-and-primitives">Operations and Primitives</a><ul class="visible nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#operations">Operations</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitives">Primitives</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#using-the-primitives">Using the Primitives</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#using-the-primitive">Using the Primitive</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-primitive">Implementing the Primitive</a><ul class="visible nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-cpu-backend">Implementing the CPU Backend</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-gpu-backend">Implementing the GPU Backend</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-cpu-back-end">Implementing the CPU Back-end</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-gpu-back-end">Implementing the GPU Back-end</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitive-transforms">Primitive Transforms</a></li>
</ul>
</li>

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<title>Metal Debugger &#8212; MLX 0.9.0 documentation</title>
<title>Metal Debugger &#8212; MLX 0.10.0 documentation</title>
@@ -36,7 +36,7 @@
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=5b4479735964841361fd" />
<script src="../_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=5b4479735964841361fd"></script>
<script src="../_static/documentation_options.js?v=2a76c96f"></script>
<script src="../_static/documentation_options.js?v=cb265169"></script>
<script src="../_static/doctools.js?v=888ff710"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=efea14e4"></script>
@@ -130,8 +130,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.9.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.9.0 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.10.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.10.0 documentation - Home"/>`);</script>
</a></div>
@@ -285,6 +285,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erf.html">mlx.core.erf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erfinv.html">mlx.core.erfinv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exp.html">mlx.core.exp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expm1.html">mlx.core.expm1</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expand_dims.html">mlx.core.expand_dims</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eye.html">mlx.core.eye</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.flatten.html">mlx.core.flatten</a></li>
@@ -317,6 +318,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.max.html">mlx.core.max</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.maximum.html">mlx.core.maximum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.mean.html">mlx.core.mean</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.meshgrid.html">mlx.core.meshgrid</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.min.html">mlx.core.min</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.minimum.html">mlx.core.minimum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.moveaxis.html">mlx.core.moveaxis</a></li>
@@ -351,6 +353,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.square.html">mlx.core.square</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.squeeze.html">mlx.core.squeeze</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stack.html">mlx.core.stack</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.std.html">mlx.core.std</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stop_gradient.html">mlx.core.stop_gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.subtract.html">mlx.core.subtract</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sum.html">mlx.core.sum</a></li>
@@ -378,6 +381,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.gumbel.html">mlx.core.random.gumbel</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.key.html">mlx.core.random.key</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.normal.html">mlx.core.random.normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.multivariate_normal.html">mlx.core.random.multivariate_normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.randint.html">mlx.core.random.randint</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.seed.html">mlx.core.random.seed</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.split.html">mlx.core.random.split</a></li>
@@ -431,6 +435,8 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.get_cache_memory.html">mlx.core.metal.get_cache_memory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.set_memory_limit.html">mlx.core.metal.set_memory_limit</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.set_cache_limit.html">mlx.core.metal.set_cache_limit</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.start_capture.html">mlx.core.metal.start_capture</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.stop_capture.html">mlx.core.metal.stop_capture</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
@@ -773,25 +779,39 @@ document.write(`
<section id="metal-debugger">
<h1>Metal Debugger<a class="headerlink" href="#metal-debugger" title="Link to this heading">#</a></h1>
<p>Profiling is a key step for performance optimization. You can build MLX with
the <code class="docutils literal notranslate"><span class="pre">MLX_METAL_DEBUG</span></code> option to improve the Metal debugging and optimization
workflow. The <code class="docutils literal notranslate"><span class="pre">MLX_METAL_DEBUG</span></code> debug option:</p>
the <code class="docutils literal notranslate"><span class="pre">MLX_METAL_DEBUG</span></code> option to improve the Metal debugging and
optimization workflow. The <code class="docutils literal notranslate"><span class="pre">MLX_METAL_DEBUG</span></code> debug option:</p>
<ul class="simple">
<li><p>Records source during Metal compilation, for later inspection while
debugging.</p></li>
<li><p>Labels Metal objects such as command queues, improving capture readability.</p></li>
</ul>
<p>The <code class="docutils literal notranslate"><span class="pre">metal::start_capture</span></code> function initiates a capture of all MLX GPU work.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="kt">int</span><span class="w"> </span><span class="nf">main</span><span class="p">()</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="n">metal</span><span class="o">::</span><span class="n">start_capture</span><span class="p">(</span><span class="s">&quot;/Users/Jane/Developer/MLX.gputrace&quot;</span><span class="p">);</span>
<p>To build with debugging enabled in Python prepend
<code class="docutils literal notranslate"><span class="pre">CMAKE_ARGS=&quot;-DMLX_METAL_DEBUG=ON&quot;</span></code> to the build call.</p>
<p>The <a class="reference internal" href="../python/_autosummary/mlx.core.metal.start_capture.html#mlx.core.metal.start_capture" title="mlx.core.metal.start_capture"><code class="xref py py-func docutils literal notranslate"><span class="pre">metal.start_capture()</span></code></a> function initiates a capture of all MLX GPU
work.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>To capture a GPU trace you must run the application with
<code class="docutils literal notranslate"><span class="pre">MTL_CAPTURE_ENABLED=1</span></code>.</p>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">arange</span><span class="p">(</span><span class="mf">10.f</span><span class="p">,</span><span class="w"> </span><span class="mf">20.f</span><span class="p">,</span><span class="w"> </span><span class="mf">1.f</span><span class="p">,</span><span class="w"> </span><span class="n">float32</span><span class="p">);</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">b</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">arange</span><span class="p">(</span><span class="mf">30.f</span><span class="p">,</span><span class="w"> </span><span class="mf">40.f</span><span class="p">,</span><span class="w"> </span><span class="mf">1.f</span><span class="p">,</span><span class="w"> </span><span class="n">float32</span><span class="p">);</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">c</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">add</span><span class="p">(</span><span class="n">a</span><span class="p">,</span><span class="w"> </span><span class="n">b</span><span class="p">);</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">))</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">))</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="w"> </span><span class="n">eval</span><span class="p">(</span><span class="n">c</span><span class="p">);</span>
<span class="n">trace_file</span> <span class="o">=</span> <span class="s2">&quot;mlx_trace.gputrace&quot;</span>
<span class="w"> </span><span class="n">metal</span><span class="o">::</span><span class="n">stop_capture</span><span class="p">();</span>
<span class="p">}</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">mx</span><span class="o">.</span><span class="n">metal</span><span class="o">.</span><span class="n">start_capture</span><span class="p">(</span><span class="n">trace_file</span><span class="p">):</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Make sure to run with MTL_CAPTURE_ENABLED=1 and &quot;</span>
<span class="sa">f</span><span class="s2">&quot;that the path </span><span class="si">{</span><span class="n">trace_file</span><span class="si">}</span><span class="s2"> does not already exist.&quot;</span><span class="p">)</span>
<span class="n">exit</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
<span class="n">mx</span><span class="o">.</span><span class="n">metal</span><span class="o">.</span><span class="n">stop_capture</span><span class="p">()</span>
</pre></div>
</div>
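<p>For reference, a typical end-to-end sequence might look like the following. This is a
sketch only: it assumes a from-source install of the Python package and that the capture
script above is saved as <code class="docutils literal notranslate"><span class="pre">capture.py</span></code>;
the exact install command and script name may differ in your setup.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Build MLX with Metal debugging enabled (run from the repository root).
CMAKE_ARGS=&quot;-DMLX_METAL_DEBUG=ON&quot; pip install -e .

# Run the capture script with Metal capture enabled.
MTL_CAPTURE_ENABLED=1 python capture.py
</pre></div>
</div>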
<p>You can open and replay the GPU trace in Xcode. The <code class="docutils literal notranslate"><span class="pre">Dependencies</span></code> view
@@ -800,8 +820,8 @@ documentation</a> for more information.</p>
<img alt="../_images/capture.png" class="dark-light" src="../_images/capture.png" />
<section id="xcode-workflow">
<h2>Xcode Workflow<a class="headerlink" href="#xcode-workflow" title="Link to this heading">#</a></h2>
<p>You can skip saving to a path by running within Xcode. First, generate an Xcode
project using CMake.</p>
<p>You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">mkdir</span> <span class="n">build</span> <span class="o">&amp;&amp;</span> <span class="n">cd</span> <span class="n">build</span>
<span class="n">cmake</span> <span class="o">..</span> <span class="o">-</span><span class="n">DMLX_METAL_DEBUG</span><span class="o">=</span><span class="n">ON</span> <span class="o">-</span><span class="n">G</span> <span class="n">Xcode</span>
<span class="nb">open</span> <span class="n">mlx</span><span class="o">.</span><span class="n">xcodeproj</span>