Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-17 09:18:12 +08:00)
Commit: rebase
@@ -8,7 +8,7 @@
 <meta charset="utf-8" />
-<meta name="viewport" content="width=device-width, initial-scale=1.0" />
+<meta name="viewport" content="width=device-width, initial-scale=1" />
-<title>Custom Metal Kernels — MLX 0.23.2 documentation</title>
+<title>Custom Metal Kernels — MLX 0.24.0 documentation</title>
@@ -36,7 +36,7 @@
 <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
 <script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
-<script src="../_static/documentation_options.js?v=9900918c"></script>
+<script src="../_static/documentation_options.js?v=ae1d10b0"></script>
 <script src="../_static/doctools.js?v=9a2dae69"></script>
 <script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -137,8 +137,8 @@
-<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.23.2 documentation - Home"/>
-<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.23.2 documentation - Home"/>`);</script>
+<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.0 documentation - Home"/>
+<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.0 documentation - Home"/>`);</script>
 </a></div>
docs/build/html/dev/extensions.html (vendored, 221 lines changed)
@@ -8,7 +8,7 @@
 <meta charset="utf-8" />
-<meta name="viewport" content="width=device-width, initial-scale=1.0" />
+<meta name="viewport" content="width=device-width, initial-scale=1" />
-<title>Custom Extensions in MLX — MLX 0.23.2 documentation</title>
+<title>Custom Extensions in MLX — MLX 0.24.0 documentation</title>
@@ -36,7 +36,7 @@
 <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
 <script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
-<script src="../_static/documentation_options.js?v=9900918c"></script>
+<script src="../_static/documentation_options.js?v=ae1d10b0"></script>
 <script src="../_static/doctools.js?v=9a2dae69"></script>
 <script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -137,8 +137,8 @@
-<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.23.2 documentation - Home"/>
-<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.23.2 documentation - Home"/>`);</script>
+<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.0 documentation - Home"/>
+<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.0 documentation - Home"/>`);</script>
 </a></div>
@@ -942,12 +942,12 @@ You can do that in MLX directly:</p>
 </div>
 <p>This function performs that operation while leaving the implementation and
 function transformations to MLX.</p>
-<p>However you may need to customize the underlying implementation, perhaps to
-make it faster or for custom differentiation. In this tutorial we will go
-through adding custom extensions. It will cover:</p>
+<p>However, you may want to customize the underlying implementation, perhaps to
+make it faster. In this tutorial we will go through adding custom extensions.
+It will cover:</p>
 <ul class="simple">
 <li><p>The structure of the MLX library.</p></li>
-<li><p>Implementing a CPU operation that redirects to <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a> when appropriate.</p></li>
+<li><p>Implementing a CPU operation.</p></li>
 <li><p>Implementing a GPU operation using metal.</p></li>
 <li><p>Adding the <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code> function transformation.</p></li>
 <li><p>Building a custom extension and binding it to python.</p></li>
@@ -962,14 +962,14 @@ more detail.</p>
 <h3>Operations<a class="headerlink" href="#operations" title="Link to this heading">#</a></h3>
 <p>Operations are the front-end functions that operate on arrays. They are defined
 in the C++ API (<a class="reference internal" href="../cpp/ops.html#cpp-ops"><span class="std std-ref">Operations</span></a>), and the Python API (<a class="reference internal" href="../python/ops.html#ops"><span class="std std-ref">Operations</span></a>) binds them.</p>
-<p>We would like an operation, <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> that takes in two arrays <code class="docutils literal notranslate"><span class="pre">x</span></code> and
+<p>We would like an operation <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> that takes in two arrays, <code class="docutils literal notranslate"><span class="pre">x</span></code> and
 <code class="docutils literal notranslate"><span class="pre">y</span></code>, and two scalars, <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>. This is how to define it in
 C++:</p>
 <div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/**</span>
 <span class="cm">* Scale and sum two vectors element-wise</span>
 <span class="cm">* z = alpha * x + beta * y</span>
 <span class="cm">*</span>
-<span class="cm">* Follow numpy style broadcasting between x and y</span>
+<span class="cm">* Use NumPy-style broadcasting between x and y</span>
 <span class="cm">* Inputs are upcasted to floats if needed</span>
 <span class="cm">**/</span>
 <span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
@@ -981,7 +981,7 @@ C++:</p>
 <span class="p">);</span>
 </pre></div>
 </div>
-<p>The simplest way to this operation is in terms of existing operations:</p>
+<p>The simplest way to implement this is with existing operations:</p>
 <div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
 <span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
 <span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
@@ -1062,9 +1062,6 @@ more concrete:</p>
 <span class="w"> </span><span class="k">private</span><span class="o">:</span>
 <span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">;</span>
 <span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
-
-<span class="w"> </span><span class="cm">/** Fall back implementation for evaluation on CPU */</span>
-<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">);</span>
 <span class="p">};</span>
 </pre></div>
 </div>
@@ -1093,7 +1090,7 @@ inputs that are passed to the primitive.</p>
 <span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">promoted_dtype</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">promote_types</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">dtype</span><span class="p">(),</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">dtype</span><span class="p">());</span>

 <span class="w"> </span><span class="c1">// Upcast to float32 for non-floating point inputs x and y</span>
-<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">out_dtype</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">is_floating_point</span><span class="p">(</span><span class="n">promoted_dtype</span><span class="p">)</span>
+<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">out_dtype</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">issubdtype</span><span class="p">(</span><span class="n">promoted_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">float32</span><span class="p">)</span>
 <span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">promoted_dtype</span>
 <span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">promote_types</span><span class="p">(</span><span class="n">promoted_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">float32</span><span class="p">);</span>
@@ -1139,153 +1136,87 @@ of these functions to allocate memory as needed.</p>
 </div>
 <section id="implementing-the-cpu-back-end">
 <h3>Implementing the CPU Back-end<a class="headerlink" href="#implementing-the-cpu-back-end" title="Link to this heading">#</a></h3>
-<p>Let’s start by implementing a naive and generic version of
-<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>. We declared this as a private member function of
-<code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> earlier called <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval()</span></code>.</p>
-<p>Our naive method will go over each element of the output array, find the
+<p>Let’s start by implementing <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>.</p>
+<p>The method will go over each element of the output array, find the
 corresponding input elements of <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> and perform the operation
 point-wise. This is captured in the templated function <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby_impl()</span></code>.</p>
 <div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span>
 <span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl</span><span class="p">(</span>
-<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
-<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span>
-<span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">,</span>
-<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
-<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
-<span class="w"> </span><span class="c1">// We only allocate memory when we are ready to fill the output</span>
-<span class="w"> </span><span class="c1">// malloc_or_wait synchronously allocates available memory</span>
-<span class="w"> </span><span class="c1">// There may be a wait executed here if the allocation is requested</span>
-<span class="w"> </span><span class="c1">// under memory-pressured conditions</span>
-<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
+<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
+<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span>
+<span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">,</span>
+<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
+<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span>
+<span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">Stream</span><span class="w"> </span><span class="n">stream</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
+<span class="w"> </span><span class="c1">// Allocate the output with `malloc_or_wait` which synchronously allocates</span>
+<span class="w"> </span><span class="c1">// memory, potentially waiting if the system is under memory pressure</span>
+<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">mx</span><span class="o">::</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
-<span class="w"> </span><span class="c1">// Collect input and output data pointers</span>
-<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span>
-<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span>
-<span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">out_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span>
+<span class="w"> </span><span class="c1">// Get the CPU command encoder and register input and output arrays</span>
+<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">encoder</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">cpu</span><span class="o">::</span><span class="n">get_command_encoder</span><span class="p">(</span><span class="n">stream</span><span class="p">);</span>
+<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">set_input_array</span><span class="p">(</span><span class="n">x</span><span class="p">);</span>
+<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">set_input_array</span><span class="p">(</span><span class="n">y</span><span class="p">);</span>
+<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">set_output_array</span><span class="p">(</span><span class="n">out</span><span class="p">);</span>
+
+<span class="w"> </span><span class="c1">// Launch the CPU kernel</span>
+<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">dispatch</span><span class="p">([</span><span class="n">x_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(),</span>
+<span class="w"> </span><span class="n">y_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(),</span>
+<span class="w"> </span><span class="n">out_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(),</span>
+<span class="w"> </span><span class="n">size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span>
+<span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span>
+<span class="w"> </span><span class="n">x_strides</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span>
+<span class="w"> </span><span class="n">y_strides</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span>
+<span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
+<span class="w"> </span><span class="n">beta_</span><span class="p">]()</span><span class="w"> </span><span class="p">{</span>
 <span class="w"> </span><span class="c1">// Cast alpha and beta to the relevant types</span>
 <span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha_</span><span class="p">);</span>
 <span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta_</span><span class="p">);</span>
 <span class="w"> </span><span class="c1">// Do the element-wise operation for each output</span>
-<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="kt">size_t</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o"><</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span><span class="w"> </span><span class="n">out_idx</span><span class="o">++</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
-<span class="w"> </span><span class="c1">// Map linear indices to offsets in x and y</span>
-<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span>
-<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span>
+<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="kt">size_t</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o"><</span><span class="w"> </span><span class="n">size</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="o">++</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
+<span class="w"> </span><span class="c1">// Map linear indices to offsets in x and y</span>
+<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">x_strides</span><span class="p">);</span>
+<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">y_strides</span><span class="p">);</span>
-<span class="w"> </span><span class="c1">// We allocate the output to be contiguous and regularly strided</span>
-<span class="w"> </span><span class="c1">// (defaults to row major) and hence it doesn't need additional mapping</span>
-<span class="w"> </span><span class="n">out_ptr</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span>
+<span class="w"> </span><span class="c1">// We allocate the output to be contiguous and regularly strided</span>
+<span class="w"> </span><span class="c1">// (defaults to row major) and hence it doesn't need additional mapping</span>
+<span class="w"> </span><span class="n">out_ptr</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span>
 <span class="w"> </span><span class="p">}</span>
+<span class="w"> </span><span class="p">});</span>
 <span class="p">}</span>
 </pre></div>
 </div>
 <p>Our implementation should work for all incoming floating point arrays.
 Accordingly, we add dispatches for <code class="docutils literal notranslate"><span class="pre">float32</span></code>, <code class="docutils literal notranslate"><span class="pre">float16</span></code>, <code class="docutils literal notranslate"><span class="pre">bfloat16</span></code> and
 <code class="docutils literal notranslate"><span class="pre">complex64</span></code>. We throw an error if we encounter an unexpected type.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Fall back implementation for evaluation on CPU */</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_cpu</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Dispatch to the correct dtype</span>
|
||||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float32</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float16</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">float16_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">bfloat16</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">bfloat16_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">complex64</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">complex64_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span>
<span class="w"> </span><span class="s">"[Axpby] Only supports floating point types."</span><span class="p">);</span>
<span class="w"> </span><span class="p">}</span>
<span class="p">}</span>
</pre></div>
</div>
<p>This is good as a fallback implementation. We can use the <code class="docutils literal notranslate"><span class="pre">axpby</span></code> routine
provided by the <a class="reference external" href="https://developer.apple.com/documentation/accelerate/blas?language=objc">Accelerate</a> framework for a faster implementation in certain
cases:</p>
<ol class="arabic simple">
<li><p>Accelerate does not provide implementations of <code class="docutils literal notranslate"><span class="pre">axpby</span></code> for half precision
floats. We can only use it for <code class="docutils literal notranslate"><span class="pre">float32</span></code> types.</p></li>
<li><p>Accelerate assumes the inputs <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> are row contiguous or column contiguous.</p></li>
<li><p>Accelerate performs the routine <code class="docutils literal notranslate"><span class="pre">Y</span> <span class="pre">=</span> <span class="pre">(alpha</span> <span class="pre">*</span> <span class="pre">X)</span> <span class="pre">+</span> <span class="pre">(beta</span> <span class="pre">*</span> <span class="pre">Y)</span></code> in-place.
MLX expects to write the output to a new array. We must copy the elements
of <code class="docutils literal notranslate"><span class="pre">y</span></code> into the output and use that as an input to <code class="docutils literal notranslate"><span class="pre">axpby</span></code>.</p></li>
</ol>
<p>Let’s write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies <code class="docutils literal notranslate"><span class="pre">y</span></code> into it, and then calls the
<code class="xref py py-func docutils literal notranslate"><span class="pre">catlas_saxpby()</span></code> from Accelerate.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span>
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl_accelerate</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span>
<span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">,</span>
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Accelerate library provides catlas_saxpby which does</span>
<span class="w"> </span><span class="c1">// Y = (alpha * X) + (beta * Y) in place</span>
<span class="w"> </span><span class="c1">// To use it, we first copy the data in y over to the output array</span>
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
<span class="w"> </span><span class="c1">// We then copy over the elements using the contiguous vector specialization</span>
<span class="w"> </span><span class="n">copy_inplace</span><span class="p">(</span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">CopyType</span><span class="o">::</span><span class="n">Vector</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Get x and y pointers for catlas_saxpby</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span>
<span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span>
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha_</span><span class="p">);</span>
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta_</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Call the inplace accelerate operator</span>
<span class="w"> </span><span class="n">catlas_saxpby</span><span class="p">(</span>
<span class="w"> </span><span class="cm">/* N = */</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span>
<span class="w"> </span><span class="cm">/* ALPHA = */</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span>
<span class="w"> </span><span class="cm">/* X = */</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">,</span>
<span class="w"> </span><span class="cm">/* INCX = */</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span>
<span class="w"> </span><span class="cm">/* BETA = */</span><span class="w"> </span><span class="n">beta</span><span class="p">,</span>
<span class="w"> </span><span class="cm">/* Y = */</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">,</span>
<span class="w"> </span><span class="cm">/* INCY = */</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span>
<span class="p">}</span>
</pre></div>
</div>
<p>For inputs that do not fit the criteria for Accelerate, we fall back to
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval()</span></code>. With this in mind, let’s finish our
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Evaluate primitive on CPU using accelerate specializations */</span>
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_cpu</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
<span class="w"> </span><span class="c1">// Accelerate specialization for contiguous single precision float arrays</span>
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float32</span><span class="w"> </span><span class="o">&&</span>
<span class="w"> </span><span class="p">((</span><span class="n">x</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">row_contiguous</span><span class="w"> </span><span class="o">&&</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">row_contiguous</span><span class="p">)</span><span class="w"> </span><span class="o">||</span>
<span class="w"> </span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">col_contiguous</span><span class="w"> </span><span class="o">&&</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">col_contiguous</span><span class="p">)))</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="n">axpby_impl_accelerate</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
<span class="w"> </span><span class="k">return</span><span class="p">;</span>
<span class="w"> </span><span class="p">}</span>
<span class="w"> </span><span class="c1">// Fall back to common back-end if specializations are not available</span>
<span class="w"> </span><span class="n">eval</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">outputs</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Dispatch to the correct dtype</span>
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">float32</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">float16</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">mx</span><span class="o">::</span><span class="n">float16_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">bfloat16</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">mx</span><span class="o">::</span><span class="n">bfloat16_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">complex64</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">mx</span><span class="o">::</span><span class="n">complex64_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span>
<span class="w"> </span><span class="s">"Axpby is only supported for floating point types."</span><span class="p">);</span>
<span class="w"> </span><span class="p">}</span>
<span class="p">}</span>
</pre></div>
</div>
<p>Just this much is enough to run the operation <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code>, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.</p>
primitive here.</p>
</section>
<section id="implementing-the-gpu-back-end">
<h3>Implementing the GPU Back-end<a class="headerlink" href="#implementing-the-gpu-back-end" title="Link to this heading">#</a></h3>
@@ -1688,18 +1619,16 @@ import the Python package and play with it as you would any other MLX operation.
<section id="results">
<h3>Results<a class="headerlink" href="#results" title="Link to this heading">#</a></h3>
<p>Let’s run a quick benchmark and see how our new <code class="docutils literal notranslate"><span class="pre">axpby</span></code> operation compares
with the naive <code class="xref py py-meth docutils literal notranslate"><span class="pre">simple_axpby()</span></code> we first defined on the CPU.</p>
with the naive <code class="xref py py-meth docutils literal notranslate"><span class="pre">simple_axpby()</span></code> we first defined.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">mlx_sample_extensions</span><span class="w"> </span><span class="kn">import</span> <span class="n">axpby</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
<span class="n">mx</span><span class="o">.</span><span class="n">set_default_device</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">simple_axpby</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
<span class="k">return</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">y</span>
<span class="n">M</span> <span class="o">=</span> <span class="mi">256</span>
<span class="n">N</span> <span class="o">=</span> <span class="mi">512</span>
<span class="n">M</span> <span class="o">=</span> <span class="mi">4096</span>
<span class="n">N</span> <span class="o">=</span> <span class="mi">4096</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">))</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">))</span>
@@ -1710,25 +1639,25 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
<span class="k">def</span><span class="w"> </span><span class="nf">bench</span><span class="p">(</span><span class="n">f</span><span class="p">):</span>
<span class="c1"># Warm up</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
<span class="c1"># Timed run</span>
<span class="n">s</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5000</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
<span class="n">e</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="k">return</span> <span class="n">e</span> <span class="o">-</span> <span class="n">s</span>
<span class="k">return</span> <span class="mi">1000</span> <span class="o">*</span> <span class="p">(</span><span class="n">e</span> <span class="o">-</span> <span class="n">s</span><span class="p">)</span> <span class="o">/</span> <span class="mi">100</span>
<span class="n">simple_time</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">simple_axpby</span><span class="p">)</span>
<span class="n">custom_time</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">axpby</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Simple axpby: </span><span class="si">{</span><span class="n">simple_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> s | Custom axpby: </span><span class="si">{</span><span class="n">custom_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> s"</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Simple axpby: </span><span class="si">{</span><span class="n">simple_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms | Custom axpby: </span><span class="si">{</span><span class="n">custom_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms"</span><span class="p">)</span>
</pre></div>
</div>
<p>The results are <code class="docutils literal notranslate"><span class="pre">Simple</span> <span class="pre">axpby:</span> <span class="pre">0.114</span> <span class="pre">s</span> <span class="pre">|</span> <span class="pre">Custom</span> <span class="pre">axpby:</span> <span class="pre">0.109</span> <span class="pre">s</span></code>. We see
<p>The results are <code class="docutils literal notranslate"><span class="pre">Simple</span> <span class="pre">axpby:</span> <span class="pre">1.559</span> <span class="pre">ms</span> <span class="pre">|</span> <span class="pre">Custom</span> <span class="pre">axpby:</span> <span class="pre">0.774</span> <span class="pre">ms</span></code>. We see
modest improvements right away!</p>
<p>This operation is now good to be used to build other operations, in
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and also as a part of graph transformations like
8
docs/build/html/dev/metal_debugger.html
vendored
@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Metal Debugger — MLX 0.23.2 documentation</title>
<title>Metal Debugger — MLX 0.24.0 documentation</title>
@@ -36,7 +36,7 @@
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/documentation_options.js?v=9900918c"></script>
<script src="../_static/documentation_options.js?v=ae1d10b0"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -137,8 +137,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.23.2 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.23.2 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.0 documentation - Home"/>`);</script>
</a></div>
8
docs/build/html/dev/mlx_in_cpp.html
vendored
@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Using MLX in C++ — MLX 0.23.2 documentation</title>
<title>Using MLX in C++ — MLX 0.24.0 documentation</title>
@@ -36,7 +36,7 @@
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/documentation_options.js?v=9900918c"></script>
<script src="../_static/documentation_options.js?v=ae1d10b0"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
@@ -136,8 +136,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.23.2 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.23.2 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.24.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.24.0 documentation - Home"/>`);</script>
</a></div>