mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
docs update
This commit is contained in:

committed by
CircleCI Docs

parent
1d2cadbc78
commit
f77d99b285
468
docs/build/html/dev/extensions.html
vendored
468
docs/build/html/dev/extensions.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
|
||||
|
||||
<title>Developer Documentation — MLX 0.9.0 documentation</title>
|
||||
<title>Developer Documentation — MLX 0.10.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=5b4479735964841361fd" />
|
||||
<script src="../_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=5b4479735964841361fd"></script>
|
||||
|
||||
<script src="../_static/documentation_options.js?v=2a76c96f"></script>
|
||||
<script src="../_static/documentation_options.js?v=cb265169"></script>
|
||||
<script src="../_static/doctools.js?v=888ff710"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=efea14e4"></script>
|
||||
@@ -131,8 +131,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.9.0 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.9.0 documentation - Home"/>`);</script>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.10.0 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.10.0 documentation - Home"/>`);</script>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -286,6 +286,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erf.html">mlx.core.erf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erfinv.html">mlx.core.erfinv</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exp.html">mlx.core.exp</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expm1.html">mlx.core.expm1</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expand_dims.html">mlx.core.expand_dims</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eye.html">mlx.core.eye</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.flatten.html">mlx.core.flatten</a></li>
|
||||
@@ -318,6 +319,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.max.html">mlx.core.max</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.maximum.html">mlx.core.maximum</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.mean.html">mlx.core.mean</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.meshgrid.html">mlx.core.meshgrid</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.min.html">mlx.core.min</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.minimum.html">mlx.core.minimum</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.moveaxis.html">mlx.core.moveaxis</a></li>
|
||||
@@ -352,6 +354,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.square.html">mlx.core.square</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.squeeze.html">mlx.core.squeeze</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stack.html">mlx.core.stack</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.std.html">mlx.core.std</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stop_gradient.html">mlx.core.stop_gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.subtract.html">mlx.core.subtract</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sum.html">mlx.core.sum</a></li>
|
||||
@@ -379,6 +382,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.gumbel.html">mlx.core.random.gumbel</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.key.html">mlx.core.random.key</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.normal.html">mlx.core.random.normal</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.multivariate_normal.html">mlx.core.random.multivariate_normal</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.randint.html">mlx.core.random.randint</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.seed.html">mlx.core.random.seed</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.split.html">mlx.core.random.split</a></li>
|
||||
@@ -432,6 +436,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.get_cache_memory.html">mlx.core.metal.get_cache_memory</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.set_memory_limit.html">mlx.core.metal.set_memory_limit</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.set_cache_limit.html">mlx.core.metal.set_cache_limit</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.start_capture.html">mlx.core.metal.start_capture</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.stop_capture.html">mlx.core.metal.stop_capture</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
@@ -763,12 +769,12 @@ document.write(`
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#operations-and-primitives">Operations and Primitives</a><ul class="visible nav section-nav flex-column">
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#operations">Operations</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitives">Primitives</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#using-the-primitives">Using the Primitives</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#using-the-primitive">Using the Primitive</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-primitive">Implementing the Primitive</a><ul class="visible nav section-nav flex-column">
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-cpu-backend">Implementing the CPU Backend</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-gpu-backend">Implementing the GPU Backend</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-cpu-back-end">Implementing the CPU Back-end</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-gpu-back-end">Implementing the GPU Back-end</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitive-transforms">Primitive Transforms</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
@@ -796,61 +802,45 @@ document.write(`
|
||||
|
||||
<section id="developer-documentation">
|
||||
<h1>Developer Documentation<a class="headerlink" href="#developer-documentation" title="Link to this heading">#</a></h1>
|
||||
<p>MLX provides an open and flexible backend to which users may add operations
|
||||
and specialized implementations without much hassle. While the library supplies
|
||||
efficient operations that can be used and composed for any number of
|
||||
applications, there may arise cases where new functionalities or highly
|
||||
optimized implementations are needed. For such cases, you may design and
|
||||
implement your own operations that link to and build on top of <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code>.
|
||||
We will introduce the inner workings of MLX and go over a simple example to
|
||||
learn the steps involved in adding new operations to MLX with your own CPU
|
||||
and GPU implementations.</p>
|
||||
<p>You can extend MLX with custom operations on the CPU or GPU. This guide
|
||||
explains how to do that with a simple example.</p>
|
||||
<section id="introducing-the-example">
|
||||
<h2>Introducing the Example<a class="headerlink" href="#introducing-the-example" title="Link to this heading">#</a></h2>
|
||||
<p>Let’s say that you would like an operation that takes in two arrays,
|
||||
<code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, scales them both by some coefficients <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>
|
||||
respectively, and then adds them together to get the result
|
||||
<code class="docutils literal notranslate"><span class="pre">z</span> <span class="pre">=</span> <span class="pre">alpha</span> <span class="pre">*</span> <span class="pre">x</span> <span class="pre">+</span> <span class="pre">beta</span> <span class="pre">*</span> <span class="pre">y</span></code>. Well, you can very easily do that by just
|
||||
writing out a function as follows:</p>
|
||||
<p>Let’s say you would like an operation that takes in two arrays, <code class="docutils literal notranslate"><span class="pre">x</span></code> and
|
||||
<code class="docutils literal notranslate"><span class="pre">y</span></code>, scales them both by coefficients <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code> respectively,
|
||||
and then adds them together to get the result <code class="docutils literal notranslate"><span class="pre">z</span> <span class="pre">=</span> <span class="pre">alpha</span> <span class="pre">*</span> <span class="pre">x</span> <span class="pre">+</span> <span class="pre">beta</span> <span class="pre">*</span> <span class="pre">y</span></code>.
|
||||
You can do that in MLX directly:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">simple_axpby</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">y</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>This function performs that operation while leaving the implementations and
|
||||
differentiation to MLX.</p>
|
||||
<p>However, you work with vector math libraries often and realize that the
|
||||
<code class="docutils literal notranslate"><span class="pre">axpby</span></code> routine defines the same operation <code class="docutils literal notranslate"><span class="pre">Y</span> <span class="pre">=</span> <span class="pre">(alpha</span> <span class="pre">*</span> <span class="pre">X)</span> <span class="pre">+</span> <span class="pre">(beta</span> <span class="pre">*</span> <span class="pre">Y)</span></code>.
|
||||
You would really like the part of your applications that does this operation
|
||||
on the CPU to be very fast - so you decide that you want it to rely on the
|
||||
<code class="docutils literal notranslate"><span class="pre">axpby</span></code> routine provided by the <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a> framework. Continuing to impose
|
||||
our assumptions on to you, let’s also assume that you want to learn how to add
|
||||
your own implementation for the gradients of your new operation while going
|
||||
over the ins-and-outs of the MLX framework.</p>
|
||||
<p>Well, what a coincidence! You are in the right place. Over the course of this
|
||||
example, we will learn:</p>
|
||||
<p>This function performs that operation while leaving the implementation and
|
||||
function transformations to MLX.</p>
|
||||
<p>However, you may need to customize the underlying implementation, perhaps to
|
||||
make it faster or for custom differentiation. In this tutorial we will go
|
||||
through adding custom extensions. It will cover:</p>
|
||||
<ul class="simple">
|
||||
<li><p>The structure of the MLX library from the frontend API to the backend implementations.</p></li>
|
||||
<li><p>How to implement your own CPU backend that redirects to <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a> when appropriate (and a fallback if needed).</p></li>
|
||||
<li><p>How to implement your own GPU implementation using Metal.</p></li>
|
||||
<li><p>How to add your own <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code>.</p></li>
|
||||
<li><p>How to build your implementations, link them to MLX, and bind them to Python.</p></li>
|
||||
<li><p>The structure of the MLX library.</p></li>
|
||||
<li><p>Implementing a CPU operation that redirects to <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a> when appropriate.</p></li>
|
||||
<li><p>Implementing a GPU operation using Metal.</p></li>
|
||||
<li><p>Adding the <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code> function transformations.</p></li>
|
||||
<li><p>Building a custom extension and binding it to Python.</p></li>
|
||||
</ul>
|
||||
</section>
|
||||
<section id="operations-and-primitives">
|
||||
<h2>Operations and Primitives<a class="headerlink" href="#operations-and-primitives" title="Link to this heading">#</a></h2>
|
||||
<p>In one sentence, operations in MLX build the computation graph, and primitives
|
||||
provide the rules for evaluation and transformations of said graph. Let’s start
|
||||
by discussing operations in more detail.</p>
|
||||
<p>Operations in MLX build the computation graph. Primitives provide the rules for
|
||||
evaluating and transforming the graph. Let’s start by discussing operations in
|
||||
more detail.</p>
|
||||
<section id="operations">
|
||||
<h3>Operations<a class="headerlink" href="#operations" title="Link to this heading">#</a></h3>
|
||||
<p>Operations are the frontend functions that operate on arrays. They are defined
|
||||
in the C++ API (<a class="reference internal" href="../cpp/ops.html#cpp-ops"><span class="std std-ref">Operations</span></a>) and then we provide bindings to these
|
||||
operations in the Python API (<a class="reference internal" href="../python/ops.html#ops"><span class="std std-ref">Operations</span></a>).</p>
|
||||
<p>We would like an operation, <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> that takes in two arrays <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>,
|
||||
and two scalars, <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>. This is how we would define it in the
|
||||
C++ API:</p>
|
||||
<p>Operations are the front-end functions that operate on arrays. They are defined
|
||||
in the C++ API (<a class="reference internal" href="../cpp/ops.html#cpp-ops"><span class="std std-ref">Operations</span></a>), and the Python API (<a class="reference internal" href="../python/ops.html#ops"><span class="std std-ref">Operations</span></a>) binds them.</p>
|
||||
<p>We would like an operation, <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> that takes in two arrays <code class="docutils literal notranslate"><span class="pre">x</span></code> and
|
||||
<code class="docutils literal notranslate"><span class="pre">y</span></code>, and two scalars, <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>. This is how to define it in
|
||||
C++:</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/**</span>
|
||||
<span class="cm">* Scale and sum two vectors element-wise</span>
|
||||
<span class="cm">* z = alpha * x + beta * y</span>
|
||||
@@ -867,9 +857,7 @@ C++ API:</p>
|
||||
<span class="p">);</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>This operation itself can call other operations within it if needed. So, the
|
||||
simplest way to go about implementing this operation would be to do so in terms
|
||||
of existing operations.</p>
|
||||
<p>The simplest way to implement this operation is in terms of existing operations:</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
|
||||
@@ -886,19 +874,17 @@ of existing operations.</p>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>However, as we discussed earlier, this is not our goal. The operations themselves
|
||||
do not contain the implementations that act on the data, nor do they contain the
|
||||
rules of transformations. Rather, they are an easy-to-use interface that builds
|
||||
on top of the building blocks we call <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.</p>
|
||||
<p>The operations themselves do not contain the implementations that act on the
|
||||
data, nor do they contain the rules of transformations. Rather, they are an
|
||||
easy-to-use interface that uses <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> building blocks.</p>
|
||||
</section>
|
||||
<section id="primitives">
|
||||
<h3>Primitives<a class="headerlink" href="#primitives" title="Link to this heading">#</a></h3>
|
||||
<p>A <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> is part of the computation graph of an <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code>. It
|
||||
defines how to create an output given a set of input <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> . Further,
|
||||
a <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> is a class that contains rules on how it is evaluated
|
||||
on the CPU or GPU, and how it acts under transformations such as <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and
|
||||
<code class="docutils literal notranslate"><span class="pre">jvp</span></code>. These words on their own can be a bit abstract, so let’s take a step
|
||||
back and go to our example to give ourselves a more concrete image.</p>
|
||||
defines how to create output arrays given input arrays. Further, a
|
||||
<code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> has methods to run on the CPU or GPU and for function
|
||||
transformations such as <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code>. Let’s go back to our example to be
|
||||
more concrete:</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Axpby</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="k">public</span><span class="w"> </span><span class="n">Primitive</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">public</span><span class="o">:</span>
|
||||
<span class="w"> </span><span class="k">explicit</span><span class="w"> </span><span class="n">Axpby</span><span class="p">(</span><span class="n">Stream</span><span class="w"> </span><span class="n">stream</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">)</span>
|
||||
@@ -911,11 +897,15 @@ back and go to our example to give ourselves a more concrete image.</p>
|
||||
<span class="cm"> * To avoid unnecessary allocations, the evaluation function</span>
|
||||
<span class="cm"> * is responsible for allocating space for the array.</span>
|
||||
<span class="cm"> */</span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_cpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_gpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_cpu</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_gpu</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
|
||||
<span class="w"> </span><span class="cm">/** The Jacobian-vector product. */</span>
|
||||
<span class="w"> </span><span class="n">array</span><span class="w"> </span><span class="nf">jvp</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">jvp</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
@@ -924,7 +914,8 @@ back and go to our example to give ourselves a more concrete image.</p>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">vjp</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
|
||||
<span class="w"> </span><span class="cm">/**</span>
|
||||
<span class="cm"> * The primitive must know how to vectorize itself across</span>
|
||||
@@ -932,7 +923,7 @@ back and go to our example to give ourselves a more concrete image.</p>
|
||||
<span class="cm"> * representing the vectorized computation and the axis which</span>
|
||||
<span class="cm"> * corresponds to the output vectorized dimension.</span>
|
||||
<span class="cm"> */</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">array</span><span class="p">,</span><span class="w"> </span><span class="kt">int</span><span class="o">></span><span class="w"> </span><span class="n">vmap</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">virtual</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="p">,</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>></span><span class="w"> </span><span class="n">vmap</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
|
||||
@@ -953,20 +944,20 @@ back and go to our example to give ourselves a more concrete image.</p>
|
||||
<span class="p">};</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>The <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> class derives from the base <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> class and
|
||||
follows the above demonstrated interface. <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> treats <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and
|
||||
<code class="docutils literal notranslate"><span class="pre">beta</span></code> as parameters. It then provides implementations of how the array <code class="docutils literal notranslate"><span class="pre">out</span></code>
|
||||
is produced given <code class="docutils literal notranslate"><span class="pre">inputs</span></code> through <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code> and
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code>. Further, it provides rules of transformations in
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::jvp()</span></code>, <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vjp()</span></code>, and <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vmap()</span></code>.</p>
|
||||
<p>The <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> class derives from the base <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> class.
|
||||
<code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> treats <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code> as parameters. It then provides
|
||||
implementations of how the output array is produced given the inputs through
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code> and <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code>. It also provides rules
|
||||
of transformations in <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::jvp()</span></code>, <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vjp()</span></code>, and
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vmap()</span></code>.</p>
|
||||
</section>
|
||||
<section id="using-the-primitives">
|
||||
<h3>Using the Primitives<a class="headerlink" href="#using-the-primitives" title="Link to this heading">#</a></h3>
|
||||
<p>Operations can use this <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> to add a new <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> to
|
||||
the computation graph. An <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> can be constructed by providing its
|
||||
data type, shape, the <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> that computes it, and the
|
||||
<code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> inputs that are passed to the primitive.</p>
|
||||
<p>Let’s re-implement our operation now in terms of our <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> primitive.</p>
|
||||
<section id="using-the-primitive">
|
||||
<h3>Using the Primitive<a class="headerlink" href="#using-the-primitive" title="Link to this heading">#</a></h3>
|
||||
<p>Operations can use this <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> to add a new <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> to the
|
||||
computation graph. An <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> can be constructed by providing its data
|
||||
type, shape, the <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> that computes it, and the <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code>
|
||||
inputs that are passed to the primitive.</p>
|
||||
<p>Let’s reimplement our operation now in terms of our <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> primitive.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
|
||||
@@ -1012,25 +1003,24 @@ data type, shape, the <code class="xref py py-class docutils literal notranslate
|
||||
</section>
|
||||
<section id="implementing-the-primitive">
|
||||
<h2>Implementing the Primitive<a class="headerlink" href="#implementing-the-primitive" title="Link to this heading">#</a></h2>
|
||||
<p>No computation happens when we call the operation alone. In effect, the
|
||||
operation only builds the computation graph. When we evaluate the output
|
||||
array, MLX schedules the execution of the computation graph, and calls
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code> or <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> depending on the
|
||||
stream/device specified by the user.</p>
|
||||
<p>No computation happens when we call the operation alone. The operation only
|
||||
builds the computation graph. When we evaluate the output array, MLX schedules
|
||||
the execution of the computation graph, and calls <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code> or
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> depending on the stream/device specified by the user.</p>
|
||||
<div class="admonition warning">
|
||||
<p class="admonition-title">Warning</p>
|
||||
<p>When <code class="xref py py-meth docutils literal notranslate"><span class="pre">Primitive::eval_cpu()</span></code> or <code class="xref py py-meth docutils literal notranslate"><span class="pre">Primitive::eval_gpu()</span></code> are called,
|
||||
no memory has been allocated for the output array. It falls on the implementation
|
||||
of these functions to allocate memory as needed</p>
|
||||
of these functions to allocate memory as needed.</p>
|
||||
</div>
|
||||
<section id="implementing-the-cpu-backend">
|
||||
<h3>Implementing the CPU Backend<a class="headerlink" href="#implementing-the-cpu-backend" title="Link to this heading">#</a></h3>
|
||||
<p>Let’s start by trying to implement a naive and generic version of
|
||||
<section id="implementing-the-cpu-back-end">
|
||||
<h3>Implementing the CPU Back-end<a class="headerlink" href="#implementing-the-cpu-back-end" title="Link to this heading">#</a></h3>
|
||||
<p>Let’s start by implementing a naive and generic version of
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>. We declared this as a private member function of
|
||||
<code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> earlier called <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval()</span></code>.</p>
|
||||
<p>Our naive method will go over each element of the output array, find the
|
||||
corresponding input elements of <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> and perform the operation
|
||||
pointwise. This is captured in the templated function <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby_impl()</span></code>.</p>
|
||||
point-wise. This is captured in the templated function <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby_impl()</span></code>.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
|
||||
@@ -1066,16 +1056,16 @@ pointwise. This is captured in the templated function <code class="xref py py-me
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Now, we would like our implementation to be able to do this pointwise operation
|
||||
for all incoming floating point arrays. Accordingly, we add dispatches for
|
||||
<code class="docutils literal notranslate"><span class="pre">float32</span></code>, <code class="docutils literal notranslate"><span class="pre">float16</span></code>, <code class="docutils literal notranslate"><span class="pre">bfloat16</span></code> and <code class="docutils literal notranslate"><span class="pre">complex64</span></code>. We throw an error
|
||||
if we encounter an unexpected type.</p>
|
||||
<p>Our implementation should work for all incoming floating point arrays.
|
||||
Accordingly, we add dispatches for <code class="docutils literal notranslate"><span class="pre">float32</span></code>, <code class="docutils literal notranslate"><span class="pre">float16</span></code>, <code class="docutils literal notranslate"><span class="pre">bfloat16</span></code> and
|
||||
<code class="docutils literal notranslate"><span class="pre">complex64</span></code>. We throw an error if we encounter an unexpected type.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Fall back implementation for evaluation on CPU */</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="c1">// Check the inputs (registered in the op while constructing the out array)</span>
|
||||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Dispatch to the correct dtype</span>
|
||||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float32</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
@@ -1088,29 +1078,27 @@ if we encounter an unexpected type.</p>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">complex64_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="s">"Axpby is only supported for floating point types."</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="s">"[Axpby] Only supports floating point types."</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We have a fallback implementation! Now, to do what we are really here to do.
|
||||
Remember we wanted to use the <code class="docutils literal notranslate"><span class="pre">axpby</span></code> routine provided by the <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a>
|
||||
framework? Well, there are 3 complications to keep in mind:</p>
|
||||
<p>This is good as a fallback implementation. We can use the <code class="docutils literal notranslate"><span class="pre">axpby</span></code> routine
|
||||
provided by the <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a> framework for a faster implementation in certain
|
||||
cases:</p>
|
||||
<ol class="arabic simple">
|
||||
<li><p>Accelerate does not provide implementations of <code class="docutils literal notranslate"><span class="pre">axpby</span></code> for half precision
|
||||
floats. We can only direct to it for <code class="docutils literal notranslate"><span class="pre">float32</span></code> types</p></li>
|
||||
<li><p>Accelerate assumes the inputs <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> are contiguous and all elements
|
||||
have fixed strides between them. Possibly due to broadcasts and transposes,
|
||||
we aren’t guaranteed that the inputs fit this requirement. We can
|
||||
only direct to Accelerate if both <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> are row contiguous or
|
||||
column contiguous.</p></li>
|
||||
<li><p>Accelerate performs the routine <code class="docutils literal notranslate"><span class="pre">Y</span> <span class="pre">=</span> <span class="pre">(alpha</span> <span class="pre">*</span> <span class="pre">X)</span> <span class="pre">+</span> <span class="pre">(beta</span> <span class="pre">*</span> <span class="pre">Y)</span></code> inplace.
|
||||
MLX expects to write out the answer to a new array. We must copy the elements
|
||||
of <code class="docutils literal notranslate"><span class="pre">y</span></code> into the output array and use that as an input to <code class="docutils literal notranslate"><span class="pre">axpby</span></code></p></li>
|
||||
floats. We can only use it for <code class="docutils literal notranslate"><span class="pre">float32</span></code> types.</p></li>
|
||||
<li><p>Accelerate assumes the inputs <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> are contiguous and all
|
||||
elements have fixed strides between them. We only direct to Accelerate
|
||||
if both <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> are row contiguous or column contiguous.</p></li>
|
||||
<li><p>Accelerate performs the routine <code class="docutils literal notranslate"><span class="pre">Y</span> <span class="pre">=</span> <span class="pre">(alpha</span> <span class="pre">*</span> <span class="pre">X)</span> <span class="pre">+</span> <span class="pre">(beta</span> <span class="pre">*</span> <span class="pre">Y)</span></code> in-place.
|
||||
MLX expects to write the output to a new array. We must copy the elements
|
||||
of <code class="docutils literal notranslate"><span class="pre">y</span></code> into the output and use that as an input to <code class="docutils literal notranslate"><span class="pre">axpby</span></code>.</p></li>
|
||||
</ol>
|
||||
<p>Let’s write out an implementation that uses Accelerate in the right conditions.
|
||||
It must simply allocate data for the output, copy elements of <code class="docutils literal notranslate"><span class="pre">y</span></code> into it,
|
||||
and then call the <code class="xref py py-meth docutils literal notranslate"><span class="pre">catlas_saxpby()</span></code> from accelerate.</p>
|
||||
<p>Let’s write an implementation that uses Accelerate in the right conditions.
|
||||
It allocates data for the output, copies <code class="docutils literal notranslate"><span class="pre">y</span></code> into it, and then calls the
|
||||
<code class="xref py py-func docutils literal notranslate"><span class="pre">catlas_saxpby()</span></code> from Accelerate.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl_accelerate</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
|
||||
@@ -1121,17 +1109,7 @@ and then call the <code class="xref py py-meth docutils literal notranslate"><sp
|
||||
<span class="w"> </span><span class="c1">// Accelerate library provides catlas_saxpby which does</span>
|
||||
<span class="w"> </span><span class="c1">// Y = (alpha * X) + (beta * Y) in place</span>
|
||||
<span class="w"> </span><span class="c1">// To use it, we first copy the data in y over to the output array</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// This specialization requires both x and y be contiguous in the same mode</span>
|
||||
<span class="w"> </span><span class="c1">// i.e: corresponding linear indices in both point to corresponding elements</span>
|
||||
<span class="w"> </span><span class="c1">// The data in the output array is allocated to match the strides in y</span>
|
||||
<span class="w"> </span><span class="c1">// such that x, y, and out are contiguous in the same mode and</span>
|
||||
<span class="w"> </span><span class="c1">// no transposition is needed</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">data_size</span><span class="p">()</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">itemsize</span><span class="p">()),</span>
|
||||
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data_size</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">flags</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// We then copy over the elements using the contiguous vector specialization</span>
|
||||
<span class="w"> </span><span class="n">copy_inplace</span><span class="p">(</span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">CopyType</span><span class="o">::</span><span class="n">Vector</span><span class="p">);</span>
|
||||
@@ -1155,14 +1133,17 @@ and then call the <code class="xref py py-meth docutils literal notranslate"><sp
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Great! But what about the inputs that do not fit the criteria for accelerate?
|
||||
Luckily, we can always just direct back to <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval()</span></code>.</p>
|
||||
<p>With this in mind, lets finally implement our <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>.</p>
|
||||
<p>For inputs that do not fit the criteria for Accelerate, we fall back to
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval()</span></code>. With this in mind, let’s finish our
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Evaluate primitive on CPU using accelerate specializations */</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_cpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_cpu</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Accelerate specialization for contiguous single precision float arrays</span>
|
||||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float32</span><span class="w"> </span><span class="o">&&</span>
|
||||
@@ -1172,33 +1153,32 @@ Luckily, we can always just direct back to <code class="xref py py-meth docutils
|
||||
<span class="w"> </span><span class="k">return</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Fall back to common backend if specializations are not available</span>
|
||||
<span class="w"> </span><span class="n">eval</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="c1">// Fall back to common back-end if specializations are not available</span>
|
||||
<span class="w"> </span><span class="n">eval</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">outputs</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We have now hit a milestone! Just this much is enough to run the operation
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> on a CPU stream!</p>
|
||||
<p>If you do not plan on running the operation on the GPU or using transforms on
|
||||
<p>Just this much is enough to run the operation <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> on a CPU stream! If
|
||||
you do not plan on running the operation on the GPU or using transforms on
|
||||
computation graphs that contain <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code>, you can stop implementing the
|
||||
primitive here and enjoy the speed-ups you get from the Accelerate library.</p>
|
||||
</section>
|
||||
<section id="implementing-the-gpu-backend">
|
||||
<h3>Implementing the GPU Backend<a class="headerlink" href="#implementing-the-gpu-backend" title="Link to this heading">#</a></h3>
|
||||
<section id="implementing-the-gpu-back-end">
|
||||
<h3>Implementing the GPU Back-end<a class="headerlink" href="#implementing-the-gpu-back-end" title="Link to this heading">#</a></h3>
|
||||
<p>Apple silicon devices address their GPUs using the <a class="reference external" href="https://developer.apple.com/documentation/metal?language=objc">Metal</a> shading language, and
|
||||
all GPU kernels in MLX are written using metal.</p>
|
||||
GPU kernels in MLX are written using Metal.</p>
|
||||
<div class="admonition note">
|
||||
<p class="admonition-title">Note</p>
|
||||
<p>Here are some helpful resources if you are new to metal!</p>
|
||||
<p>Here are some helpful resources if you are new to Metal:</p>
|
||||
<ul class="simple">
|
||||
<li><p>A walkthrough of the Metal compute pipeline: <a class="reference external" href="https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc">Metal Example</a></p></li>
|
||||
<li><p>Documentation for the Metal Shading Language: <a class="reference external" href="https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf">Metal Specification</a></p></li>
|
||||
<li><p>Using Metal from C++: <a class="reference external" href="https://developer.apple.com/metal/cpp/">Metal-cpp</a></p></li>
|
||||
</ul>
|
||||
</div>
|
||||
<p>Let’s keep the GPU algorithm simple. We will launch exactly as many threads
|
||||
as there are elements in the output. Each thread will pick the element it needs
|
||||
from <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, do the pointwise operation, and then update its assigned
|
||||
<p>Let’s keep the GPU kernel simple. We will launch exactly as many threads as
|
||||
there are elements in the output. Each thread will pick the element it needs
|
||||
from <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, do the point-wise operation, and update its assigned
|
||||
element in the output.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span>
|
||||
<span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">(</span>
|
||||
@@ -1223,8 +1203,7 @@ element in the output.</p>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We then need to instantiate this template for all floating-point types and give
|
||||
each instantiation a unique host name so we can identify the right kernel for
|
||||
each data type.</p>
|
||||
each instantiation a unique host name so we can identify it.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cp">#define instantiate_axpby(type_name, type) \</span>
|
||||
<span class="cp"> template [[host_name("axpby_general_" #type_name)]] \</span>
|
||||
<span class="cp"> [[kernel]] void axpby_general<type>( \</span>
|
||||
@@ -1245,25 +1224,18 @@ each data type.</p>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">complex64</span><span class="p">,</span><span class="w"> </span><span class="n">complex64_t</span><span class="p">);</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>This kernel will be compiled into a metal library <code class="docutils literal notranslate"><span class="pre">mlx_ext.metallib</span></code> as we
|
||||
will see later in <a class="reference internal" href="#building-with-cmake"><span class="std std-ref">Building with CMake</span></a>. In the following example, we
|
||||
assume that the library <code class="docutils literal notranslate"><span class="pre">mlx_ext.metallib</span></code> will always be co-located with
|
||||
the executable/ shared-library calling the <code class="xref py py-meth docutils literal notranslate"><span class="pre">register_library()</span></code> function.
|
||||
The <code class="xref py py-meth docutils literal notranslate"><span class="pre">register_library()</span></code> function takes the library’s name and potential
|
||||
path (or in this case, a function that can produce the path of the metal
|
||||
library) and tries to load that library if it hasn’t already been registered
|
||||
by the relevant static <code class="xref py py-class docutils literal notranslate"><span class="pre">mlx::core::metal::Device</span></code> object. This is why,
|
||||
it is important to package your C++ library with the metal library. We will
|
||||
go over this process in more detail later.</p>
|
||||
<p>The logic to determine the kernel, set the inputs, resolve the grid dimensions
|
||||
and dispatch it to the GPU are contained in <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> as shown
|
||||
<p>The logic to determine the kernel, set the inputs, resolve the grid dimensions,
|
||||
and dispatch to the GPU is contained in <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> as shown
|
||||
below.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Evaluate primitive on GPU */</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_gpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_gpu</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="c1">// Prepare inputs</span>
|
||||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Each primitive carries the stream it should execute on</span>
|
||||
<span class="w"> </span><span class="c1">// and each stream carries its device identifiers</span>
|
||||
@@ -1274,7 +1246,7 @@ below.</p>
|
||||
<span class="w"> </span><span class="c1">// Allocate output memory</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Resolve name of kernel (corresponds to axpby.metal)</span>
|
||||
<span class="w"> </span><span class="c1">// Resolve name of kernel</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">ostringstream</span><span class="w"> </span><span class="n">kname</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="n">kname</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"axpby_"</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"general_"</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="n">type_to_name</span><span class="p">(</span><span class="n">out</span><span class="p">);</span>
|
||||
|
||||
@@ -1328,24 +1300,21 @@ below.</p>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We can now call the <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> operation on both the CPU and the GPU!</p>
|
||||
<p>A few things to note about MLX and metal before moving on. MLX keeps track
|
||||
of the active <code class="docutils literal notranslate"><span class="pre">compute_encoder</span></code>. We rely on <code class="xref py py-meth docutils literal notranslate"><span class="pre">d.get_command_encoder()</span></code>
|
||||
to give us the active metal compute command encoder instead of building a
|
||||
new one and calling <code class="xref py py-meth docutils literal notranslate"><span class="pre">compute_encoder->end_encoding()</span></code> at the end.
|
||||
MLX keeps adding kernels (compute pipelines) to the active command encoder
|
||||
until some specified limit is hit or the compute encoder needs to be flushed
|
||||
for synchronization. MLX also handles enqueuing and committing the associated
|
||||
command buffers as needed. We suggest taking a deeper dive into
|
||||
<code class="xref py py-class docutils literal notranslate"><span class="pre">metal::Device</span></code> if you would like to study this routine further.</p>
|
||||
<p>A few things to note about MLX and Metal before moving on. MLX keeps track of
|
||||
the active <code class="docutils literal notranslate"><span class="pre">command_buffer</span></code> and the <code class="docutils literal notranslate"><span class="pre">MTLCommandBuffer</span></code> with which it is
|
||||
associated. We rely on <code class="xref py py-meth docutils literal notranslate"><span class="pre">d.get_command_encoder()</span></code> to give us the active
|
||||
Metal compute command encoder instead of building a new one and calling
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">compute_encoder->end_encoding()</span></code> at the end. MLX adds kernels (compute
|
||||
pipelines) to the active command buffer until some specified limit is hit or
|
||||
the command buffer needs to be flushed for synchronization.</p>
|
||||
</section>
|
||||
<section id="primitive-transforms">
|
||||
<h3>Primitive Transforms<a class="headerlink" href="#primitive-transforms" title="Link to this heading">#</a></h3>
|
||||
<p>Now that we have come this far, let’s also learn how to add implementations to
|
||||
transformations in a <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>. These transformations can be built on
|
||||
top of our operations, including the one we just defined now. Which then gives
|
||||
us the following <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::jvp()</span></code> and <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vjp()</span></code> implementations.</p>
|
||||
<p>Next, let’s add implementations for transformations in a <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.
|
||||
These transformations can be built on top of other operations, including the
|
||||
one we just defined:</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** The Jacobian-vector product. */</span>
|
||||
<span class="n">array</span><span class="w"> </span><span class="nf">Axpby::jvp</span><span class="p">(</span>
|
||||
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">jvp</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
@@ -1360,12 +1329,12 @@ us the following <code class="xref py py-meth docutils literal notranslate"><spa
|
||||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">argnums</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">></span><span class="w"> </span><span class="mi">1</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">argnums</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">alpha_</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dtype</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="p">{</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">())};</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="w"> </span><span class="c1">// If, argnums = {0, 1}, we take contributions from both</span>
|
||||
<span class="w"> </span><span class="c1">// which gives us jvp = tangent_x * alpha + tangent_y * beta</span>
|
||||
<span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby</span><span class="p">(</span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="p">{</span><span class="n">axpby</span><span class="p">(</span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">())};</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
@@ -1373,26 +1342,27 @@ us the following <code class="xref py py-meth docutils literal notranslate"><spa
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** The vector-Jacobian product. */</span>
|
||||
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vjp</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">cotangents</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="cm">/* unused */</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="c1">// Reverse mode diff</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">vjps</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="k">auto</span><span class="w"> </span><span class="n">arg</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">arg</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">alpha_</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">cotan</span><span class="p">.</span><span class="n">dtype</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="n">vjps</span><span class="p">.</span><span class="n">push_back</span><span class="p">(</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">()));</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">cotangents</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dtype</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="n">vjps</span><span class="p">.</span><span class="n">push_back</span><span class="p">(</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">cotangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">()));</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">vjps</span><span class="p">;</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Finally, you need not have a transformation fully defined to start using your
|
||||
own <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.</p>
|
||||
<p>Note that a transformation does not need to be fully defined to start using
|
||||
the <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Vectorize primitive along given axis */</span>
|
||||
<span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">array</span><span class="p">,</span><span class="w"> </span><span class="kt">int</span><span class="o">></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vmap</span><span class="p">(</span>
|
||||
<span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="p">,</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vmap</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span><span class="s">"Axpby has no vmap implementation."</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span><span class="s">"[Axpby] vmap not implemented."</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
@@ -1416,64 +1386,63 @@ own <code class="xref py py-class docutils literal notranslate"><span class="pre
|
||||
<ul class="simple">
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/axpby/</span></code> defines the C++ extension library</p></li>
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> sets out the structure for the
|
||||
associated python package</p></li>
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/bindings.cpp</span></code> provides python bindings for our operation</p></li>
|
||||
associated Python package</p></li>
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/bindings.cpp</span></code> provides Python bindings for our operation</p></li>
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/CMakeLists.txt</span></code> holds CMake rules to build the library and
|
||||
python bindings</p></li>
|
||||
Python bindings</p></li>
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/setup.py</span></code> holds the <code class="docutils literal notranslate"><span class="pre">setuptools</span></code> rules to build and install
|
||||
the python package</p></li>
|
||||
the Python package</p></li>
|
||||
</ul>
|
||||
<section id="binding-to-python">
|
||||
<h3>Binding to Python<a class="headerlink" href="#binding-to-python" title="Link to this heading">#</a></h3>
|
||||
<p>We use <a class="reference external" href="https://pybind11.readthedocs.io/en/stable/">PyBind11</a> to build a Python API for the C++ library. Since bindings for
|
||||
<p>We use <a class="reference external" href="https://nanobind.readthedocs.io/en/latest/">nanobind</a> to build a Python API for the C++ library. Since bindings for
|
||||
components such as <a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a>, <a class="reference internal" href="../python/_autosummary/mlx.core.stream.html#mlx.core.stream" title="mlx.core.stream"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.stream</span></code></a>, etc. are
|
||||
already provided, adding our <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> is simple!</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">PYBIND11_MODULE</span><span class="p">(</span><span class="n">mlx_sample_extensions</span><span class="p">,</span><span class="w"> </span><span class="n">m</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">doc</span><span class="p">()</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s">"Sample C++ and metal extensions for MLX"</span><span class="p">;</span>
|
||||
already provided, adding our <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> is simple.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">NB_MODULE</span><span class="p">(</span><span class="n">_ext</span><span class="p">,</span><span class="w"> </span><span class="n">m</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">doc</span><span class="p">()</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s">"Sample extension for MLX"</span><span class="p">;</span>
|
||||
|
||||
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="s">"axpby"</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="o">&</span><span class="n">axpby</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="s">"x"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="s">"y"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">pos_only</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="s">"alpha"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="s">"beta"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">kw_only</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="s">"stream"</span><span class="n">_a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">none</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="sa">R</span><span class="s">"</span><span class="dl">pbdoc(</span>
|
||||
<span class="s"> Scale and sum two vectors element-wise</span>
|
||||
<span class="s"> ``z = alpha * x + beta * y``</span>
|
||||
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="s">"axpby"</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="o">&</span><span class="n">axpby</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="s">"x"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="s">"y"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="s">"alpha"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="s">"beta"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">nb</span><span class="o">::</span><span class="n">kw_only</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="s">"stream"</span><span class="n">_a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">nb</span><span class="o">::</span><span class="n">none</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="sa">R</span><span class="s">"</span><span class="dl">(</span>
|
||||
<span class="s"> Scale and sum two vectors element-wise</span>
|
||||
<span class="s"> ``z = alpha * x + beta * y``</span>
|
||||
|
||||
<span class="s"> Follows numpy style broadcasting between ``x`` and ``y``</span>
|
||||
<span class="s"> Inputs are upcasted to floats if needed</span>
|
||||
<span class="s"> Follows numpy style broadcasting between ``x`` and ``y``</span>
|
||||
<span class="s"> Inputs are upcasted to floats if needed</span>
|
||||
|
||||
<span class="s"> Args:</span>
|
||||
<span class="s"> x (array): Input array.</span>
|
||||
<span class="s"> y (array): Input array.</span>
|
||||
<span class="s"> alpha (float): Scaling factor for ``x``.</span>
|
||||
<span class="s"> beta (float): Scaling factor for ``y``.</span>
|
||||
<span class="s"> Args:</span>
|
||||
<span class="s"> x (array): Input array.</span>
|
||||
<span class="s"> y (array): Input array.</span>
|
||||
<span class="s"> alpha (float): Scaling factor for ``x``.</span>
|
||||
<span class="s"> beta (float): Scaling factor for ``y``.</span>
|
||||
|
||||
<span class="s"> Returns:</span>
|
||||
<span class="s"> array: ``alpha * x + beta * y``</span>
|
||||
<span class="s"> </span><span class="dl">)pbdoc</span><span class="s">"</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
<span class="s"> Returns:</span>
|
||||
<span class="s"> array: ``alpha * x + beta * y``</span>
|
||||
<span class="s"> </span><span class="dl">)</span><span class="s">"</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Most of the complexity in the above example comes from additional bells and
|
||||
whistles such as the literal names and doc-strings.</p>
|
||||
<div class="admonition warning">
|
||||
<p class="admonition-title">Warning</p>
|
||||
<p><code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> needs to be imported before importing
|
||||
<code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx_sample_extensions</span></code> as defined by the pybind11 module above to
|
||||
<p><code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> must be imported before importing
|
||||
<code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx_sample_extensions</span></code> as defined by the nanobind module above to
|
||||
ensure that the casters for <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> components like
|
||||
<a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> are available.</p>
|
||||
</div>
|
||||
</section>
|
||||
<section id="building-with-cmake">
|
||||
<span id="id1"></span><h3>Building with CMake<a class="headerlink" href="#building-with-cmake" title="Link to this heading">#</a></h3>
|
||||
<p>Building the C++ extension library itself is simple, it only requires that you
|
||||
<code class="docutils literal notranslate"><span class="pre">find_package(MLX</span> <span class="pre">CONFIG)</span></code> and then link it to your library.</p>
|
||||
<p>Building the C++ extension library only requires that you <code class="docutils literal notranslate"><span class="pre">find_package(MLX</span>
|
||||
<span class="pre">CONFIG)</span></code> and then link it to your library.</p>
|
||||
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="c"># Add library</span>
|
||||
<span class="nb">add_library</span><span class="p">(</span><span class="s">mlx_ext</span><span class="p">)</span>
|
||||
|
||||
@@ -1493,11 +1462,11 @@ ensure that the casters for <code class="xref py py-mod docutils literal notrans
|
||||
<span class="nb">target_link_libraries</span><span class="p">(</span><span class="s">mlx_ext</span><span class="w"> </span><span class="s">PUBLIC</span><span class="w"> </span><span class="s">mlx</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We also need to build the attached metal library. For convenience, we provide a
|
||||
<p>We also need to build the attached Metal library. For convenience, we provide a
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx_build_metallib()</span></code> function that builds a <code class="docutils literal notranslate"><span class="pre">.metallib</span></code> target given
|
||||
sources, headers, destinations, etc. (defined in <code class="docutils literal notranslate"><span class="pre">cmake/extension.cmake</span></code> and
|
||||
automatically imported with the MLX package).</p>
|
||||
<p>Here is what that looks like in practice!</p>
|
||||
<p>Here is what that looks like in practice:</p>
|
||||
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="c"># Build metallib</span>
|
||||
<span class="nb">if</span><span class="p">(</span><span class="s">MLX_BUILD_METAL</span><span class="p">)</span>
|
||||
|
||||
@@ -1517,15 +1486,17 @@ automatically imported with MLX package).</p>
|
||||
<span class="nb">endif</span><span class="p">()</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Finally, we build the <a class="reference external" href="https://pybind11.readthedocs.io/en/stable/">Pybind11</a> bindings</p>
|
||||
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="nb">pybind11_add_module</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="s">mlx_sample_extensions</span>
|
||||
<span class="w"> </span><span class="o">${</span><span class="nv">CMAKE_CURRENT_LIST_DIR</span><span class="o">}</span><span class="s">/bindings.cpp</span>
|
||||
<p>Finally, we build the <a class="reference external" href="https://nanobind.readthedocs.io/en/latest/">nanobind</a> bindings:</p>
|
||||
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="nb">nanobind_add_module</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="s">_ext</span>
|
||||
<span class="w"> </span><span class="s">NB_STATIC</span><span class="w"> </span><span class="s">STABLE_ABI</span><span class="w"> </span><span class="s">LTO</span><span class="w"> </span><span class="s">NOMINSIZE</span>
|
||||
<span class="w"> </span><span class="s">NB_DOMAIN</span><span class="w"> </span><span class="s">mlx</span>
|
||||
<span class="w"> </span><span class="o">${</span><span class="nv">CMAKE_CURRENT_LIST_DIR</span><span class="o">}</span><span class="s">/bindings.cpp</span>
|
||||
<span class="p">)</span>
|
||||
<span class="nb">target_link_libraries</span><span class="p">(</span><span class="s">mlx_sample_extensions</span><span class="w"> </span><span class="s">PRIVATE</span><span class="w"> </span><span class="s">mlx_ext</span><span class="p">)</span>
|
||||
<span class="nb">target_link_libraries</span><span class="p">(</span><span class="s">_ext</span><span class="w"> </span><span class="s">PRIVATE</span><span class="w"> </span><span class="s">mlx_ext</span><span class="p">)</span>
|
||||
|
||||
<span class="nb">if</span><span class="p">(</span><span class="s">BUILD_SHARED_LIBS</span><span class="p">)</span>
|
||||
<span class="w"> </span><span class="nb">target_link_options</span><span class="p">(</span><span class="s">mlx_sample_extensions</span><span class="w"> </span><span class="s">PRIVATE</span><span class="w"> </span><span class="s">-Wl,-rpath,@loader_path</span><span class="p">)</span>
|
||||
<span class="w"> </span><span class="nb">target_link_options</span><span class="p">(</span><span class="s">_ext</span><span class="w"> </span><span class="s">PRIVATE</span><span class="w"> </span><span class="s">-Wl,-rpath,@loader_path</span><span class="p">)</span>
|
||||
<span class="nb">endif</span><span class="p">()</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
@@ -1533,7 +1504,7 @@ automatically imported with MLX package).</p>
|
||||
<section id="building-with-setuptools">
|
||||
<h3>Building with <code class="docutils literal notranslate"><span class="pre">setuptools</span></code><a class="headerlink" href="#building-with-setuptools" title="Link to this heading">#</a></h3>
|
||||
<p>Once we have set out the CMake build rules as described above, we can use the
|
||||
build utilities defined in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.extension</span></code> for a simple build process.</p>
|
||||
build utilities defined in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.extension</span></code>:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">mlx</span> <span class="kn">import</span> <span class="n">extension</span>
|
||||
<span class="kn">from</span> <span class="nn">setuptools</span> <span class="kn">import</span> <span class="n">setup</span>
|
||||
|
||||
@@ -1542,13 +1513,13 @@ build utilities defined in <code class="xref py py-mod docutils literal notransl
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"mlx_sample_extensions"</span><span class="p">,</span>
|
||||
<span class="n">version</span><span class="o">=</span><span class="s2">"0.0.0"</span><span class="p">,</span>
|
||||
<span class="n">description</span><span class="o">=</span><span class="s2">"Sample C++ and Metal extensions for MLX primitives."</span><span class="p">,</span>
|
||||
<span class="n">ext_modules</span><span class="o">=</span><span class="p">[</span><span class="n">extension</span><span class="o">.</span><span class="n">CMakeExtension</span><span class="p">(</span><span class="s2">"mlx_sample_extensions"</span><span class="p">)],</span>
|
||||
<span class="n">ext_modules</span><span class="o">=</span><span class="p">[</span><span class="n">extension</span><span class="o">.</span><span class="n">CMakeExtension</span><span class="p">(</span><span class="s2">"mlx_sample_extensions._ext"</span><span class="p">)],</span>
|
||||
<span class="n">cmdclass</span><span class="o">=</span><span class="p">{</span><span class="s2">"build_ext"</span><span class="p">:</span> <span class="n">extension</span><span class="o">.</span><span class="n">CMakeBuild</span><span class="p">},</span>
|
||||
<span class="n">packages</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"mlx_sample_extensions"</span><span class="p">],</span>
|
||||
<span class="n">package_dir</span> <span class="o">=</span> <span class="p">{</span><span class="s2">""</span><span class="p">:</span> <span class="s2">"mlx_sample_extensions"</span><span class="p">},</span>
|
||||
<span class="n">package_data</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"mlx_sample_extensions"</span> <span class="p">:</span> <span class="p">[</span><span class="s2">"*.so"</span><span class="p">,</span> <span class="s2">"*.dylib"</span><span class="p">,</span> <span class="s2">"*.metallib"</span><span class="p">]},</span>
|
||||
<span class="n">packages</span><span class="o">=</span><span class="p">[</span><span class="s2">"mlx_sample_extensions"</span><span class="p">],</span>
|
||||
<span class="n">package_data</span><span class="o">=</span><span class="p">{</span><span class="s2">"mlx_sample_extensions"</span><span class="p">:</span> <span class="p">[</span><span class="s2">"*.so"</span><span class="p">,</span> <span class="s2">"*.dylib"</span><span class="p">,</span> <span class="s2">"*.metallib"</span><span class="p">]},</span>
|
||||
<span class="n">extras_require</span><span class="o">=</span><span class="p">{</span><span class="s2">"dev"</span><span class="p">:[]},</span>
|
||||
<span class="n">zip_safe</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">python_requires</span><span class="o">=</span><span class="s2">">=3.7"</span><span class="p">,</span>
|
||||
<span class="n">python_requires</span><span class="o">=</span><span class="s2">">=3.8"</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
@@ -1557,34 +1528,36 @@ build utilities defined in <code class="xref py py-mod docutils literal notransl
|
||||
<p>We treat <code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> as the package directory
|
||||
even though it only contains a <code class="docutils literal notranslate"><span class="pre">__init__.py</span></code> to ensure the following:</p>
|
||||
<ul class="simple">
|
||||
<li><p><code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> is always imported before importing <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx_sample_extensions</span></code></p></li>
|
||||
<li><p><code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> must be imported before importing <code class="xref py py-mod docutils literal notranslate"><span class="pre">_ext</span></code></p></li>
|
||||
<li><p>The C++ extension library and the Metal library are co-located with the Python
|
||||
bindings and copied together if the package is installed</p></li>
|
||||
</ul>
|
||||
</div>
|
||||
<p>You can build inplace for development using
|
||||
<p>To build the package, first install the build dependencies with <code class="docutils literal notranslate"><span class="pre">pip</span> <span class="pre">install</span>
|
||||
<span class="pre">-r</span> <span class="pre">requirements.txt</span></code>. You can then build in place for development using
|
||||
<code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">setup.py</span> <span class="pre">build_ext</span> <span class="pre">-j8</span> <span class="pre">--inplace</span></code> (in <code class="docutils literal notranslate"><span class="pre">extensions/</span></code>)</p>
|
||||
<p>This will result in a directory structure as follows:</p>
|
||||
<p>This results in the directory structure:</p>
|
||||
<div class="line-block">
|
||||
<div class="line">extensions</div>
|
||||
<div class="line">├── mlx_sample_extensions</div>
|
||||
<div class="line">│ ├── __init__.py</div>
|
||||
<div class="line">│ ├── libmlx_ext.dylib # C++ extension library</div>
|
||||
<div class="line">│ ├── mlx_ext.metallib # Metal library</div>
|
||||
<div class="line">│ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding</div>
|
||||
<div class="line">│ └── _ext.cpython-3x-darwin.so # Python Binding</div>
|
||||
<div class="line">…</div>
|
||||
</div>
|
||||
<p>When you try to install using the command <code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">-m</span> <span class="pre">pip</span> <span class="pre">install</span> <span class="pre">.</span></code>
|
||||
(in <code class="docutils literal notranslate"><span class="pre">extensions/</span></code>), the package will be installed with the same structure as
|
||||
<code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> and the C++ and metal library will be
|
||||
copied along with the python binding since they are specified as <code class="docutils literal notranslate"><span class="pre">package_data</span></code>.</p>
|
||||
<p>When you try to install using the command <code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">-m</span> <span class="pre">pip</span> <span class="pre">install</span> <span class="pre">.</span></code> (in
|
||||
<code class="docutils literal notranslate"><span class="pre">extensions/</span></code>), the package will be installed with the same structure as
|
||||
<code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> and the C++ and Metal libraries will be
|
||||
copied along with the Python binding since they are specified as
|
||||
<code class="docutils literal notranslate"><span class="pre">package_data</span></code>.</p>
|
||||
</section>
|
||||
</section>
|
||||
<section id="usage">
|
||||
<h2>Usage<a class="headerlink" href="#usage" title="Link to this heading">#</a></h2>
|
||||
<p>After installing the extension as described above, you should be able to simply
|
||||
import the python package and play with it as you would any other MLX operation!</p>
|
||||
<p>Let’s looks at a simple script and it’s results!</p>
|
||||
import the Python package and play with it as you would any other MLX operation.</p>
|
||||
<p>Let’s look at a simple script and its results:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<span class="kn">from</span> <span class="nn">mlx_sample_extensions</span> <span class="kn">import</span> <span class="n">axpby</span>
|
||||
|
||||
@@ -1606,7 +1579,7 @@ import the python package and play with it as you would any other MLX operation!
|
||||
<section id="results">
|
||||
<h3>Results<a class="headerlink" href="#results" title="Link to this heading">#</a></h3>
|
||||
<p>Let’s run a quick benchmark and see how our new <code class="docutils literal notranslate"><span class="pre">axpby</span></code> operation compares
|
||||
with the naive <code class="xref py py-meth docutils literal notranslate"><span class="pre">simple_axpby()</span></code> we defined at first on the CPU.</p>
|
||||
with the naive <code class="xref py py-meth docutils literal notranslate"><span class="pre">simple_axpby()</span></code> we first defined on the CPU.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<span class="kn">from</span> <span class="nn">mlx_sample_extensions</span> <span class="kn">import</span> <span class="n">axpby</span>
|
||||
<span class="kn">import</span> <span class="nn">time</span>
|
||||
@@ -1624,7 +1597,7 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
|
||||
<span class="n">alpha</span> <span class="o">=</span> <span class="mf">4.0</span>
|
||||
<span class="n">beta</span> <span class="o">=</span> <span class="mf">2.0</span>
|
||||
|
||||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
|
||||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">bench</span><span class="p">(</span><span class="n">f</span><span class="p">):</span>
|
||||
<span class="c1"># Warm up</span>
|
||||
@@ -1646,21 +1619,18 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Simple axpby: </span><span class="si">{</span><span class="n">simple_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> s | Custom axpby: </span><span class="si">{</span><span class="n">custom_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> s"</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Results:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">Simple</span> <span class="n">axpby</span><span class="p">:</span> <span class="mf">0.114</span> <span class="n">s</span> <span class="o">|</span> <span class="n">Custom</span> <span class="n">axpby</span><span class="p">:</span> <span class="mf">0.109</span> <span class="n">s</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We see some modest improvements right away!</p>
|
||||
<p>The results are <code class="docutils literal notranslate"><span class="pre">Simple</span> <span class="pre">axpby:</span> <span class="pre">0.114</span> <span class="pre">s</span> <span class="pre">|</span> <span class="pre">Custom</span> <span class="pre">axpby:</span> <span class="pre">0.109</span> <span class="pre">s</span></code>. We see
|
||||
modest improvements right away!</p>
|
||||
<p>This operation can now be used to build other operations, in
|
||||
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and also as a part of graph transformations like
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">grad()</span></code>!</p>
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">grad()</span></code>.</p>
|
||||
</section>
|
||||
</section>
|
||||
<section id="scripts">
|
||||
<h2>Scripts<a class="headerlink" href="#scripts" title="Link to this heading">#</a></h2>
|
||||
<div class="admonition-download-the-code admonition">
|
||||
<p class="admonition-title">Download the code</p>
|
||||
<p>The full example code is available in <a class="reference external" href="code">mlx</a>.</p>
|
||||
<p>The full example code is available in <a class="reference external" href="https://github.com/ml-explore/mlx/tree/main/examples/extensions/">mlx</a>.</p>
|
||||
</div>
|
||||
</section>
|
||||
</section>
|
||||
@@ -1713,12 +1683,12 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#operations-and-primitives">Operations and Primitives</a><ul class="visible nav section-nav flex-column">
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#operations">Operations</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitives">Primitives</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#using-the-primitives">Using the Primitives</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#using-the-primitive">Using the Primitive</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-primitive">Implementing the Primitive</a><ul class="visible nav section-nav flex-column">
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-cpu-backend">Implementing the CPU Backend</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-gpu-backend">Implementing the GPU Backend</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-cpu-back-end">Implementing the CPU Back-end</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-gpu-back-end">Implementing the GPU Back-end</a></li>
|
||||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitive-transforms">Primitive Transforms</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
|
54
docs/build/html/dev/metal_debugger.html
vendored
54
docs/build/html/dev/metal_debugger.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
|
||||
|
||||
<title>Metal Debugger — MLX 0.9.0 documentation</title>
|
||||
<title>Metal Debugger — MLX 0.10.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=5b4479735964841361fd" />
|
||||
<script src="../_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=5b4479735964841361fd"></script>
|
||||
|
||||
<script src="../_static/documentation_options.js?v=2a76c96f"></script>
|
||||
<script src="../_static/documentation_options.js?v=cb265169"></script>
|
||||
<script src="../_static/doctools.js?v=888ff710"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=efea14e4"></script>
|
||||
@@ -130,8 +130,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.9.0 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.9.0 documentation - Home"/>`);</script>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.10.0 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.10.0 documentation - Home"/>`);</script>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -285,6 +285,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erf.html">mlx.core.erf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erfinv.html">mlx.core.erfinv</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exp.html">mlx.core.exp</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expm1.html">mlx.core.expm1</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expand_dims.html">mlx.core.expand_dims</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eye.html">mlx.core.eye</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.flatten.html">mlx.core.flatten</a></li>
|
||||
@@ -317,6 +318,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.max.html">mlx.core.max</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.maximum.html">mlx.core.maximum</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.mean.html">mlx.core.mean</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.meshgrid.html">mlx.core.meshgrid</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.min.html">mlx.core.min</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.minimum.html">mlx.core.minimum</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.moveaxis.html">mlx.core.moveaxis</a></li>
|
||||
@@ -351,6 +353,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.square.html">mlx.core.square</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.squeeze.html">mlx.core.squeeze</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stack.html">mlx.core.stack</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.std.html">mlx.core.std</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stop_gradient.html">mlx.core.stop_gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.subtract.html">mlx.core.subtract</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sum.html">mlx.core.sum</a></li>
|
||||
@@ -378,6 +381,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.gumbel.html">mlx.core.random.gumbel</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.key.html">mlx.core.random.key</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.normal.html">mlx.core.random.normal</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.multivariate_normal.html">mlx.core.random.multivariate_normal</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.randint.html">mlx.core.random.randint</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.seed.html">mlx.core.random.seed</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.split.html">mlx.core.random.split</a></li>
|
||||
@@ -431,6 +435,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.get_cache_memory.html">mlx.core.metal.get_cache_memory</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.set_memory_limit.html">mlx.core.metal.set_memory_limit</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.set_cache_limit.html">mlx.core.metal.set_cache_limit</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.start_capture.html">mlx.core.metal.start_capture</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.stop_capture.html">mlx.core.metal.stop_capture</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
@@ -773,25 +779,39 @@ document.write(`
|
||||
<section id="metal-debugger">
|
||||
<h1>Metal Debugger<a class="headerlink" href="#metal-debugger" title="Link to this heading">#</a></h1>
|
||||
<p>Profiling is a key step for performance optimization. You can build MLX with
|
||||
the <code class="docutils literal notranslate"><span class="pre">MLX_METAL_DEBUG</span></code> option to improve the Metal debugging and optimization
|
||||
workflow. The <code class="docutils literal notranslate"><span class="pre">MLX_METAL_DEBUG</span></code> debug option:</p>
|
||||
the <code class="docutils literal notranslate"><span class="pre">MLX_METAL_DEBUG</span></code> option to improve the Metal debugging and
|
||||
optimization workflow. The <code class="docutils literal notranslate"><span class="pre">MLX_METAL_DEBUG</span></code> debug option:</p>
|
||||
<ul class="simple">
|
||||
<li><p>Records source during Metal compilation, for later inspection while
|
||||
debugging.</p></li>
|
||||
<li><p>Labels Metal objects such as command queues, improving capture readability.</p></li>
|
||||
</ul>
|
||||
<p>The <code class="docutils literal notranslate"><span class="pre">metal::start_capture</span></code> function initiates a capture of all MLX GPU work.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="kt">int</span><span class="w"> </span><span class="nf">main</span><span class="p">()</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="n">metal</span><span class="o">::</span><span class="n">start_capture</span><span class="p">(</span><span class="s">"/Users/Jane/Developer/MLX.gputrace"</span><span class="p">);</span>
|
||||
<p>To build with debugging enabled in Python prepend
|
||||
<code class="docutils literal notranslate"><span class="pre">CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"</span></code> to the build call.</p>
|
||||
<p>The <a class="reference internal" href="../python/_autosummary/mlx.core.metal.start_capture.html#mlx.core.metal.start_capture" title="mlx.core.metal.start_capture"><code class="xref py py-func docutils literal notranslate"><span class="pre">metal.start_capture()</span></code></a> function initiates a capture of all MLX GPU
|
||||
work.</p>
|
||||
<div class="admonition note">
|
||||
<p class="admonition-title">Note</p>
|
||||
<p>To capture a GPU trace you must run the application with
|
||||
<code class="docutils literal notranslate"><span class="pre">MTL_CAPTURE_ENABLED=1</span></code>.</p>
|
||||
</div>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">arange</span><span class="p">(</span><span class="mf">10.f</span><span class="p">,</span><span class="w"> </span><span class="mf">20.f</span><span class="p">,</span><span class="w"> </span><span class="mf">1.f</span><span class="p">,</span><span class="w"> </span><span class="n">float32</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">b</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">arange</span><span class="p">(</span><span class="mf">30.f</span><span class="p">,</span><span class="w"> </span><span class="mf">40.f</span><span class="p">,</span><span class="w"> </span><span class="mf">1.f</span><span class="p">,</span><span class="w"> </span><span class="n">float32</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">c</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">add</span><span class="p">(</span><span class="n">a</span><span class="p">,</span><span class="w"> </span><span class="n">b</span><span class="p">);</span>
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">))</span>
|
||||
<span class="n">b</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">))</span>
|
||||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||||
|
||||
<span class="w"> </span><span class="n">eval</span><span class="p">(</span><span class="n">c</span><span class="p">);</span>
|
||||
<span class="n">trace_file</span> <span class="o">=</span> <span class="s2">"mlx_trace.gputrace"</span>
|
||||
|
||||
<span class="w"> </span><span class="n">metal</span><span class="o">::</span><span class="n">stop_capture</span><span class="p">();</span>
|
||||
<span class="p">}</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">mx</span><span class="o">.</span><span class="n">metal</span><span class="o">.</span><span class="n">start_capture</span><span class="p">(</span><span class="n">trace_file</span><span class="p">):</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="s2">"Make sure to run with MTL_CAPTURE_ENABLED=1 and "</span>
|
||||
<span class="sa">f</span><span class="s2">"that the path </span><span class="si">{</span><span class="n">trace_file</span><span class="si">}</span><span class="s2"> does not already exist."</span><span class="p">)</span>
|
||||
<span class="n">exit</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
|
||||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
||||
|
||||
<span class="n">mx</span><span class="o">.</span><span class="n">metal</span><span class="o">.</span><span class="n">stop_capture</span><span class="p">()</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>You can open and replay the GPU trace in Xcode. The <code class="docutils literal notranslate"><span class="pre">Dependencies</span></code> view
|
||||
@@ -800,8 +820,8 @@ documentation</a> for more information.</p>
|
||||
<img alt="../_images/capture.png" class="dark-light" src="../_images/capture.png" />
|
||||
<section id="xcode-workflow">
|
||||
<h2>Xcode Workflow<a class="headerlink" href="#xcode-workflow" title="Link to this heading">#</a></h2>
|
||||
<p>You can skip saving to a path by running within Xcode. First, generate an Xcode
|
||||
project using CMake.</p>
|
||||
<p>You can skip saving to a path by running within Xcode. First, generate an
|
||||
Xcode project using CMake.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">mkdir</span> <span class="n">build</span> <span class="o">&&</span> <span class="n">cd</span> <span class="n">build</span>
|
||||
<span class="n">cmake</span> <span class="o">..</span> <span class="o">-</span><span class="n">DMLX_METAL_DEBUG</span><span class="o">=</span><span class="n">ON</span> <span class="o">-</span><span class="n">G</span> <span class="n">Xcode</span>
|
||||
<span class="nb">open</span> <span class="n">mlx</span><span class="o">.</span><span class="n">xcodeproj</span>
|
||||
|
Reference in New Issue
Block a user