Commit 0438cc01cd (parent ca8beb747c)
Author: CircleCI Docs
Date: 2025-03-20 22:37:22 +00:00

858 changed files with 18494 additions and 17475 deletions


@@ -8,7 +8,7 @@
 <meta charset="utf-8" />
-<meta name="viewport" content="width=device-width, initial-scale=1.0" />
+<meta name="viewport" content="width=device-width, initial-scale=1" />
-<title>Custom Metal Kernels &#8212; MLX 0.23.2 documentation</title>
+<title>Custom Metal Kernels &#8212; MLX 0.24.0 documentation</title>
@@ -36,7 +36,7 @@
 <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
 <script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
-<script src="../_static/documentation_options.js?v=9900918c"></script>
+<script src="../_static/documentation_options.js?v=ae1d10b0"></script>
 <script src="../_static/doctools.js?v=9a2dae69"></script>
 <script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -137,8 +137,8 @@
-<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.23.2 documentation - Home"/>
-<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.23.2 documentation - Home"/>`);</script>
+<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.0 documentation - Home"/>
+<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.0 documentation - Home"/>`);</script>
 </a></div>


@@ -8,7 +8,7 @@
 <meta charset="utf-8" />
-<meta name="viewport" content="width=device-width, initial-scale=1.0" />
+<meta name="viewport" content="width=device-width, initial-scale=1" />
-<title>Custom Extensions in MLX &#8212; MLX 0.23.2 documentation</title>
+<title>Custom Extensions in MLX &#8212; MLX 0.24.0 documentation</title>
@@ -36,7 +36,7 @@
 <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
 <script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
-<script src="../_static/documentation_options.js?v=9900918c"></script>
+<script src="../_static/documentation_options.js?v=ae1d10b0"></script>
 <script src="../_static/doctools.js?v=9a2dae69"></script>
 <script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -137,8 +137,8 @@
-<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.23.2 documentation - Home"/>
-<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.23.2 documentation - Home"/>`);</script>
+<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.0 documentation - Home"/>
+<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.0 documentation - Home"/>`);</script>
 </a></div>
@@ -942,12 +942,12 @@ You can do that in MLX directly:
 This function performs that operation while leaving the implementation and
 function transformations to MLX.
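For reference, the direct MLX version referred to above is just a composition of existing ops; the benchmark at the end of this page defines the same function as `simple_axpby`:

```python
import mlx.core as mx

def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
    # Composed entirely of existing MLX ops, so gradients and other
    # function transformations come for free.
    return alpha * x + beta * y
```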
-However you may need to customize the underlying implementation, perhaps to
-make it faster or for custom differentiation. In this tutorial we will go
-through adding custom extensions. It will cover:
+However, you may want to customize the underlying implementation, perhaps to
+make it faster. In this tutorial we will go through adding custom extensions.
+It will cover:
 * The structure of the MLX library.
-* Implementing a CPU operation that redirects to Accelerate when appropriate.
+* Implementing a CPU operation.
 * Implementing a GPU operation using Metal.
 * Adding the `vjp` and `jvp` function transformations.
 * Building a custom extension and binding it to Python.
@@ -962,14 +962,14 @@ more detail.

### Operations

 Operations are the front-end functions that operate on arrays. They are defined
 in the C++ API (Operations), and the Python API (Operations) binds them.
-We would like an operation, `axpby()` that takes in two arrays `x` and
+We would like an operation `axpby()` that takes in two arrays, `x` and
 `y`, and two scalars, `alpha` and `beta`. This is how to define it in
 C++:
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/**</span>
<span class="cm">* Scale and sum two vectors element-wise</span>
<span class="cm">* z = alpha * x + beta * y</span>
<span class="cm">*</span>
<span class="cm">* Follow numpy style broadcasting between x and y</span>
<span class="cm">* Use NumPy-style broadcasting between x and y</span>
<span class="cm">* Inputs are upcasted to floats if needed</span>
<span class="cm">**/</span>
<span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
@@ -981,7 +981,7 @@ C++:</p>
<span class="p">);</span>
</pre></div>
</div>
<p>The simplest way to this operation is in terms of existing operations:</p>
<p>The simplest way to implement this is with existing operations:</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
@@ -1062,9 +1062,6 @@ more concrete:

```diff
  private:
   float alpha_;
   float beta_;
-  /** Fall back implementation for evaluation on CPU */
-  void eval(const std::vector<array>& inputs, array& out);
 };
```
@@ -1093,7 +1090,7 @@ inputs that are passed to the primitive.

```diff
   auto promoted_dtype = promote_types(x.dtype(), y.dtype());
   // Upcast to float32 for non-floating point inputs x and y
-  auto out_dtype = is_floating_point(promoted_dtype)
+  auto out_dtype = issubdtype(promoted_dtype, float32)
       ? promoted_dtype
       : promote_types(promoted_dtype, float32);
```
@@ -1139,153 +1136,87 @@ of these functions to allocate memory as needed.

### Implementing the CPU Back-end

-Let's start by implementing a naive and generic version of
-`Axpby::eval_cpu()`. We declared this as a private member function of
-`Axpby` earlier called `Axpby::eval()`.
-Our naive method will go over each element of the output array, find the
+Let's start by implementing `Axpby::eval_cpu()`.
+The method will go over each element of the output array, find the
 corresponding input elements of `x` and `y` and perform the operation
 point-wise. This is captured in the templated function `axpby_impl()`.
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o">&lt;</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">&gt;</span>
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">y</span><span class="p">,</span>
<span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="p">,</span>
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// We only allocate memory when we are ready to fill the output</span>
<span class="w"> </span><span class="c1">// malloc_or_wait synchronously allocates available memory</span>
<span class="w"> </span><span class="c1">// There may be a wait executed here if the allocation is requested</span>
<span class="w"> </span><span class="c1">// under memory-pressured conditions</span>
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">y</span><span class="p">,</span>
<span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="p">,</span>
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span>
<span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">Stream</span><span class="w"> </span><span class="n">stream</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Allocate the output with `malloc_or_wait` which synchronously allocates</span>
<span class="w"> </span><span class="c1">// memory, potentially waiting if the system is under memory pressure</span>
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">mx</span><span class="o">::</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
<span class="w"> </span><span class="c1">// Collect input and output data pointers</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">();</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">();</span>
<span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">out_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">data</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">();</span>
<span class="w"> </span><span class="c1">// Get the CPU command encoder and register input and output arrays</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">encoder</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">cpu</span><span class="o">::</span><span class="n">get_command_encoder</span><span class="p">(</span><span class="n">stream</span><span class="p">);</span>
<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">set_input_array</span><span class="p">(</span><span class="n">x</span><span class="p">);</span>
<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">set_input_array</span><span class="p">(</span><span class="n">y</span><span class="p">);</span>
<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">set_output_array</span><span class="p">(</span><span class="n">out</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Launch the CPU kernel</span>
<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">dispatch</span><span class="p">([</span><span class="n">x_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">(),</span>
<span class="w"> </span><span class="n">y_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">(),</span>
<span class="w"> </span><span class="n">out_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">data</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">(),</span>
<span class="w"> </span><span class="n">size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span>
<span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span>
<span class="w"> </span><span class="n">x_strides</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span>
<span class="w"> </span><span class="n">y_strides</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span>
<span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
<span class="w"> </span><span class="n">beta_</span><span class="p">]()</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Cast alpha and beta to the relevant types</span>
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">(</span><span class="n">alpha_</span><span class="p">);</span>
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">(</span><span class="n">beta_</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Do the element-wise operation for each output</span>
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="kt">size_t</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">&lt;</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span><span class="w"> </span><span class="n">out_idx</span><span class="o">++</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Map linear indices to offsets in x and y</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span>
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="kt">size_t</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">&lt;</span><span class="w"> </span><span class="n">size</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="o">++</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Map linear indices to offsets in x and y</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">x_strides</span><span class="p">);</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">y_strides</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// We allocate the output to be contiguous and regularly strided</span>
<span class="w"> </span><span class="c1">// (defaults to row major) and hence it doesn&#39;t need additional mapping</span>
<span class="w"> </span><span class="n">out_ptr</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span>
<span class="w"> </span><span class="c1">// We allocate the output to be contiguous and regularly strided</span>
<span class="w"> </span><span class="c1">// (defaults to row major) and hence it doesn&#39;t need additional mapping</span>
<span class="w"> </span><span class="n">out_ptr</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span>
<span class="w"> </span><span class="p">}</span>
<span class="w"> </span><span class="p">});</span>
<span class="p">}</span>
</pre></div>
</div>
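For intuition, `elem_to_loc` converts a linear output index into a memory offset using an array's shape and strides, which is what makes the kernel broadcasting-aware. A rough Python sketch of the idea (illustrative only, not MLX's actual implementation):

```python
def elem_to_loc(idx: int, shape: list[int], strides: list[int]) -> int:
    # Walk dimensions from innermost to outermost, peeling off the
    # coordinate for each dimension and weighting it by that stride.
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (idx % dim) * stride
        idx //= dim
    return loc
```

A broadcast dimension simply carries a stride of 0, so every output index along it maps to the same input element.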
 Our implementation should work for all incoming floating point arrays.
 Accordingly, we add dispatches for `float32`, `float16`, `bfloat16` and
 `complex64`. We throw an error if we encounter an unexpected type.

```diff
-/** Fall back implementation for evaluation on CPU */
-void Axpby::eval(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs) {
-  auto& x = inputs[0];
-  auto& y = inputs[1];
-  auto& out = outputs[0];
+void Axpby::eval_cpu(
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& outputs) {
+  auto& x = inputs[0];
+  auto& y = inputs[1];
+  auto& out = outputs[0];
-  // Dispatch to the correct dtype
-  if (out.dtype() == float32) {
-    return axpby_impl<float>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == float16) {
-    return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == bfloat16) {
-    return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == complex64) {
-    return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
-  } else {
-    throw std::runtime_error(
-        "[Axpby] Only supports floating point types.");
-  }
-}
```
-This is good as a fallback implementation. We can use the `axpby` routine
-provided by the Accelerate framework for a faster implementation in certain
-cases:
-1. Accelerate does not provide implementations of `axpby` for half precision
-   floats. We can only use it for `float32` types.
-2. Accelerate assumes the inputs `x` and `y` are contiguous and all
-   elements have fixed strides between them. We only direct to Accelerate
-   if both `x` and `y` are row contiguous or column contiguous.
-3. Accelerate performs the routine `Y = (alpha * X) + (beta * Y)` in-place.
-   MLX expects to write the output to a new array. We must copy the elements
-   of `y` into the output and use that as an input to `axpby`.
-Let's write an implementation that uses Accelerate in the right conditions.
-It allocates data for the output, copies `y` into it, and then calls
-`catlas_saxpby()` from Accelerate.

```diff
-template <typename T>
-void axpby_impl_accelerate(
-    const array& x,
-    const array& y,
-    array& out,
-    float alpha_,
-    float beta_) {
-  // Accelerate library provides catlas_saxpby which does
-  // Y = (alpha * X) + (beta * Y) in place
-  // To use it, we first copy the data in y over to the output array
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  // We then copy over the elements using the contiguous vector specialization
-  copy_inplace(y, out, CopyType::Vector);
-  // Get x and y pointers for catlas_saxpby
-  const T* x_ptr = x.data<T>();
-  T* y_ptr = out.data<T>();
-  T alpha = static_cast<T>(alpha_);
-  T beta = static_cast<T>(beta_);
-  // Call the inplace accelerate operator
-  catlas_saxpby(
-      /* N = */ out.size(),
-      /* ALPHA = */ alpha,
-      /* X = */ x_ptr,
-      /* INCX = */ 1,
-      /* BETA = */ beta,
-      /* Y = */ y_ptr,
-      /* INCY = */ 1);
-}
```
-For inputs that do not fit the criteria for Accelerate, we fall back to
-`Axpby::eval()`. With this in mind, let's finish our `Axpby::eval_cpu()`.

```diff
-/** Evaluate primitive on CPU using accelerate specializations */
-void Axpby::eval_cpu(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs) {
-  assert(inputs.size() == 2);
-  auto& x = inputs[0];
-  auto& y = inputs[1];
-  auto& out = outputs[0];
-  // Accelerate specialization for contiguous single precision float arrays
-  if (out.dtype() == float32 &&
-      ((x.flags().row_contiguous && y.flags().row_contiguous) ||
-       (x.flags().col_contiguous && y.flags().col_contiguous))) {
-    axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
-    return;
-  }
-  // Fall back to common back-end if specializations are not available
-  eval(inputs, outputs);
+  // Dispatch to the correct dtype
+  if (out.dtype() == mx::float32) {
+    return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
+  } else if (out.dtype() == mx::float16) {
+    return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
+  } else if (out.dtype() == mx::bfloat16) {
+    return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
+  } else if (out.dtype() == mx::complex64) {
+    return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
+  } else {
+    throw std::runtime_error(
+        "Axpby is only supported for floating point types.");
+  }
 }
```
 Just this much is enough to run the operation `axpby()` on a CPU stream! If
 you do not plan on running the operation on the GPU or using transforms on
 computation graphs that contain `Axpby`, you can stop implementing the
-primitive here and enjoy the speed-ups you get from the Accelerate library.
+primitive here.
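As a quick sanity check that the CPU path works end to end, here is a sketch that assumes the extension has been built and is importable as `mlx_sample_extensions`, as in the benchmark below:

```python
import mlx.core as mx
from mlx_sample_extensions import axpby

mx.set_default_device(mx.cpu)  # force evaluation on a CPU stream

x = mx.ones((3, 4))
y = mx.ones((3, 4))
z = axpby(x, y, 4.0, 2.0)  # z = 4.0 * x + 2.0 * y
print(z)  # every element should be 6.0
```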
### Implementing the GPU Back-end

@@ -1688,18 +1619,16 @@ import the Python package and play with it as you would any other MLX operation.

### Results

 Let's run a quick benchmark and see how our new `axpby` operation compares
-with the naive `simple_axpby()` we first defined on the CPU.
+with the naive `simple_axpby()` we first defined.
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">mlx_sample_extensions</span><span class="w"> </span><span class="kn">import</span> <span class="n">axpby</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
<span class="n">mx</span><span class="o">.</span><span class="n">set_default_device</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">simple_axpby</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
<span class="k">return</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">y</span>
<span class="n">M</span> <span class="o">=</span> <span class="mi">256</span>
<span class="n">N</span> <span class="o">=</span> <span class="mi">512</span>
<span class="n">M</span> <span class="o">=</span> <span class="mi">4096</span>
<span class="n">N</span> <span class="o">=</span> <span class="mi">4096</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">))</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">))</span>
@@ -1710,25 +1639,25 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
<span class="k">def</span><span class="w"> </span><span class="nf">bench</span><span class="p">(</span><span class="n">f</span><span class="p">):</span>
<span class="c1"># Warm up</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
<span class="c1"># Timed run</span>
<span class="n">s</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5000</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
<span class="n">e</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="k">return</span> <span class="n">e</span> <span class="o">-</span> <span class="n">s</span>
<span class="k">return</span> <span class="mi">1000</span> <span class="o">*</span> <span class="p">(</span><span class="n">e</span> <span class="o">-</span> <span class="n">s</span><span class="p">)</span> <span class="o">/</span> <span class="mi">100</span>
<span class="n">simple_time</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">simple_axpby</span><span class="p">)</span>
<span class="n">custom_time</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">axpby</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Simple axpby: </span><span class="si">{</span><span class="n">simple_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> s | Custom axpby: </span><span class="si">{</span><span class="n">custom_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> s&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Simple axpby: </span><span class="si">{</span><span class="n">simple_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms | Custom axpby: </span><span class="si">{</span><span class="n">custom_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms&quot;</span><span class="p">)</span>
</pre></div>
</div>
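<p>For convenience, here is the post-change version of the listing above collected into a single runnable script. This is a minimal sketch: it assumes the <code class="docutils literal notranslate"><span class="pre">mlx_sample_extensions</span></code> package built in this tutorial is installed, and the <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code> values are placeholders, since their definitions fall outside the hunks shown above.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import time

import mlx.core as mx
from mlx_sample_extensions import axpby  # assumes the extension is built and installed

mx.set_default_device(mx.cpu)

def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -&gt; mx.array:
    # Reference implementation composed from existing MLX ops.
    return alpha * x + beta * y

M = 4096
N = 4096
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
alpha = 4.0  # placeholder value; the diff elides the actual definition
beta = 2.0   # placeholder value; the diff elides the actual definition
mx.eval(x, y)  # materialize the inputs so their creation is not timed below

def bench(f):
    # Warm up: build and evaluate the graph a few times first.
    for _ in range(5):
        mx.eval(f(x, y, alpha, beta))
    # Timed run: average milliseconds per call over 100 iterations.
    s = time.time()
    for _ in range(100):
        mx.eval(f(x, y, alpha, beta))
    e = time.time()
    return 1000 * (e - s) / 100

simple_time = bench(simple_axpby)
custom_time = bench(axpby)

print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
</pre></div>
</div>
<p>Note that the <code class="docutils literal notranslate"><span class="pre">mx.eval</span></code> calls are what force computation; MLX is lazy, so without them the loops would only build graphs and the timings would measure nothing.</p>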
<p>The results are <code class="docutils literal notranslate"><span class="pre">Simple</span> <span class="pre">axpby:</span> <span class="pre">0.114</span> <span class="pre">s</span> <span class="pre">|</span> <span class="pre">Custom</span> <span class="pre">axpby:</span> <span class="pre">0.109</span> <span class="pre">s</span></code>. We see
<p>The results are <code class="docutils literal notranslate"><span class="pre">Simple</span> <span class="pre">axpby:</span> <span class="pre">1.559</span> <span class="pre">ms</span> <span class="pre">|</span> <span class="pre">Custom</span> <span class="pre">axpby:</span> <span class="pre">0.774</span> <span class="pre">ms</span></code>. We see
a roughly 2x speedup right away!</p>
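<p>As a sanity check on these machine-dependent figures: 1.559 / 0.774 &#8776; 2.0, so the custom primitive is about twice as fast on this workload. Assuming the default float32 dtype, each call with <code class="docutils literal notranslate"><span class="pre">M</span> <span class="pre">=</span> <span class="pre">N</span> <span class="pre">=</span> <span class="pre">4096</span></code> touches three 64 MiB buffers (the two inputs and the output), about 200 MB in total, so 0.774 ms corresponds to an effective memory bandwidth on the order of 260 GB/s, in line with what one would expect from a memory-bound elementwise kernel.</p>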
<p>This operation can now be used to build other operations, in
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and as part of graph transformations like

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Metal Debugger &#8212; MLX 0.23.2 documentation</title>
<title>Metal Debugger &#8212; MLX 0.24.0 documentation</title>
@@ -36,7 +36,7 @@
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/documentation_options.js?v=9900918c"></script>
<script src="../_static/documentation_options.js?v=ae1d10b0"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -137,8 +137,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.23.2 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.23.2 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.0 documentation - Home"/>`);</script>
</a></div>

View File

@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Using MLX in C++ &#8212; MLX 0.23.2 documentation</title>
<title>Using MLX in C++ &#8212; MLX 0.24.0 documentation</title>
@@ -36,7 +36,7 @@
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/documentation_options.js?v=9900918c"></script>
<script src="../_static/documentation_options.js?v=ae1d10b0"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -136,8 +136,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.23.2 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.23.2 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.0 documentation - Home"/>`);</script>
</a></div>