mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
docs up
This commit is contained in:

committed by
CircleCI Docs

parent
2aeb6df29c
commit
7534da7269
482
docs/build/html/dev/extensions.html
vendored
482
docs/build/html/dev/extensions.html
vendored
@@ -147,9 +147,12 @@
|
||||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Usage</span></p>
|
||||
<ul class="nav bd-sidenav">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/quick_start.html">Quick Start Guide</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/lazy_evaluation.html">Lazy Evaluation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/unified_memory.html">Unified Memory</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/indexing.html">Indexing Arrays</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/saving_and_loading.html">Saving and Loading Arrays</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
|
||||
</ul>
|
||||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
||||
<ul class="nav bd-sidenav">
|
||||
@@ -237,6 +240,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cosh.html">mlx.core.cosh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.dequantize.html">mlx.core.dequantize</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.divide.html">mlx.core.divide</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.divmod.html">mlx.core.divmod</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.equal.html">mlx.core.equal</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erf.html">mlx.core.erf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erfinv.html">mlx.core.erfinv</a></li>
|
||||
@@ -250,6 +254,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.greater.html">mlx.core.greater</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.greater_equal.html">mlx.core.greater_equal</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.identity.html">mlx.core.identity</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.inner.html">mlx.core.inner</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linspace.html">mlx.core.linspace</a></li>
|
||||
@@ -260,6 +265,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log1p.html">mlx.core.log1p</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logaddexp.html">mlx.core.logaddexp</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_not.html">mlx.core.logical_not</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_and.html">mlx.core.logical_and</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_or.html">mlx.core.logical_or</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logsumexp.html">mlx.core.logsumexp</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.matmul.html">mlx.core.matmul</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.max.html">mlx.core.max</a></li>
|
||||
@@ -272,6 +279,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.negative.html">mlx.core.negative</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ones.html">mlx.core.ones</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ones_like.html">mlx.core.ones_like</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.outer.html">mlx.core.outer</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.partition.html">mlx.core.partition</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.pad.html">mlx.core.pad</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.prod.html">mlx.core.prod</a></li>
|
||||
@@ -285,6 +293,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save.html">mlx.core.save</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez.html">mlx.core.savez</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez_compressed.html">mlx.core.savez_compressed</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save_gguf.html">mlx.core.save_gguf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save_safetensors.html">mlx.core.save_safetensors</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sigmoid.html">mlx.core.sigmoid</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||||
@@ -434,6 +443,7 @@
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html">mlx.nn.losses.hinge_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.html">mlx.nn.losses.huber_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html">mlx.nn.losses.cosine_similarity_loss</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
@@ -726,33 +736,33 @@ C++ API:</p>
|
||||
<span class="cm">*</span>
|
||||
<span class="cm">* Follow numpy style broadcasting between x and y</span>
|
||||
<span class="cm">* Inputs are upcasted to floats if needed</span>
|
||||
<span class="cm">**/</span>
|
||||
<span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
|
||||
<span class="cm">**/</span><span class="w"></span>
|
||||
<span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for x</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for y</span>
|
||||
<span class="w"> </span><span class="n">StreamOrDevice</span><span class="w"> </span><span class="n">s</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="p">{}</span><span class="w"> </span><span class="c1">// Stream on which to schedule the operation</span>
|
||||
<span class="p">);</span>
|
||||
<span class="p">);</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>This operation itself can call other operations within it if needed. So, the
|
||||
simplest way to go about implementing this operation would be do so in terms
|
||||
of existing operations.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for x</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for y</span>
|
||||
<span class="w"> </span><span class="n">StreamOrDevice</span><span class="w"> </span><span class="n">s</span><span class="w"> </span><span class="cm">/* = {} */</span><span class="w"> </span><span class="c1">// Stream on which to schedule the operation</span>
|
||||
<span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// Scale x and y on the provided stream</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">ax</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">multiply</span><span class="p">(</span><span class="n">array</span><span class="p">(</span><span class="n">alpha</span><span class="p">),</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">by</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">multiply</span><span class="p">(</span><span class="n">array</span><span class="p">(</span><span class="n">beta</span><span class="p">),</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">ax</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">multiply</span><span class="p">(</span><span class="n">array</span><span class="p">(</span><span class="n">alpha</span><span class="p">),</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">by</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">multiply</span><span class="p">(</span><span class="n">array</span><span class="p">(</span><span class="n">beta</span><span class="p">),</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Add and return</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">add</span><span class="p">(</span><span class="n">ax</span><span class="p">,</span><span class="w"> </span><span class="n">by</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">add</span><span class="p">(</span><span class="n">ax</span><span class="p">,</span><span class="w"> </span><span class="n">by</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>However, as we discussed earlier, this is not our goal. The operations themselves
|
||||
@@ -768,10 +778,10 @@ a <code class="xref py py-class docutils literal notranslate"><span class="pre">
|
||||
on the CPU or GPU, and how it acts under transformations such as <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and
|
||||
<code class="docutils literal notranslate"><span class="pre">jvp</span></code>. These words on their own can be a bit abstract, so lets take a step
|
||||
back and go to our example to give ourselves a more concrete image.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Axpby</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="k">public</span><span class="w"> </span><span class="n">Primitive</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">public</span><span class="o">:</span>
|
||||
<span class="w"> </span><span class="k">explicit</span><span class="w"> </span><span class="n">Axpby</span><span class="p">(</span><span class="n">Stream</span><span class="w"> </span><span class="n">stream</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">)</span>
|
||||
<span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">Primitive</span><span class="p">(</span><span class="n">stream</span><span class="p">),</span><span class="w"> </span><span class="n">alpha_</span><span class="p">(</span><span class="n">alpha</span><span class="p">),</span><span class="w"> </span><span class="n">beta_</span><span class="p">(</span><span class="n">beta</span><span class="p">){};</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Axpby</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="k">public</span><span class="w"> </span><span class="n">Primitive</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">public</span><span class="o">:</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">explicit</span><span class="w"> </span><span class="n">Axpby</span><span class="p">(</span><span class="n">Stream</span><span class="w"> </span><span class="n">stream</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">)</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">Primitive</span><span class="p">(</span><span class="n">stream</span><span class="p">),</span><span class="w"> </span><span class="n">alpha_</span><span class="p">(</span><span class="n">alpha</span><span class="p">),</span><span class="w"> </span><span class="n">beta_</span><span class="p">(</span><span class="n">beta</span><span class="p">){};</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="cm">/**</span>
|
||||
<span class="cm"> * A primitive must know how to evaluate itself on the CPU/GPU</span>
|
||||
@@ -779,47 +789,47 @@ back and go to our example to give ourselves a more concrete image.</p>
|
||||
<span class="cm"> *</span>
|
||||
<span class="cm"> * To avoid unnecessary allocations, the evaluation function</span>
|
||||
<span class="cm"> * is responsible for allocating space for the array.</span>
|
||||
<span class="cm"> */</span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_cpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_gpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
<span class="cm"> */</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_cpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_gpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="cm">/** The Jacobian-vector product. */</span>
|
||||
<span class="w"> </span><span class="n">array</span><span class="w"> </span><span class="nf">jvp</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="cm">/** The Jacobian-vector product. */</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">array</span><span class="w"> </span><span class="nf">jvp</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="cm">/** The vector-Jacobian product. */</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">vjp</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="cm">/** The vector-Jacobian product. */</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">vjp</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="cm">/**</span>
|
||||
<span class="cm"> * The primitive must know how to vectorize itself across</span>
|
||||
<span class="cm"> * the given axes. The output is a pair containing the array</span>
|
||||
<span class="cm"> * representing the vectorized computation and the axis which</span>
|
||||
<span class="cm"> * corresponds to the output vectorized dimension.</span>
|
||||
<span class="cm"> */</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">array</span><span class="p">,</span><span class="w"> </span><span class="kt">int</span><span class="o">></span><span class="w"> </span><span class="n">vmap</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
<span class="cm"> */</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">array</span><span class="p">,</span><span class="w"> </span><span class="kt">int</span><span class="o">></span><span class="w"> </span><span class="n">vmap</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="cm">/** Print the primitive. */</span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">print</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">ostream</span><span class="o">&</span><span class="w"> </span><span class="n">os</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="n">os</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"Axpby"</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="w"> </span><span class="cm">/** Print the primitive. */</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">print</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">ostream</span><span class="o">&</span><span class="w"> </span><span class="n">os</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">os</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"Axpby"</span><span class="p">;</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="cm">/** Equivalence check **/</span>
|
||||
<span class="w"> </span><span class="kt">bool</span><span class="w"> </span><span class="nf">is_equivalent</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">Primitive</span><span class="o">&</span><span class="w"> </span><span class="n">other</span><span class="p">)</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="cm">/** Equivalence check **/</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">bool</span><span class="w"> </span><span class="nf">is_equivalent</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">Primitive</span><span class="o">&</span><span class="w"> </span><span class="n">other</span><span class="p">)</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="k">override</span><span class="p">;</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="k">private</span><span class="o">:</span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="k">private</span><span class="o">:</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">;</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="cm">/** Fall back implementation for evaluation on CPU */</span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">);</span>
|
||||
<span class="p">};</span>
|
||||
<span class="w"> </span><span class="cm">/** Fall back implementation for evaluation on CPU */</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="p">};</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>The <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> class derives from the base <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> class and
|
||||
@@ -836,38 +846,38 @@ the computation graph. An <code class="xref py py-class docutils literal notrans
|
||||
data type, shape, the <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> that computes it, and the
|
||||
<code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> inputs that are passed to the primitive.</p>
|
||||
<p>Let’s re-implement our operation now in terms of our <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> primitive.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for x</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for y</span>
|
||||
<span class="w"> </span><span class="n">StreamOrDevice</span><span class="w"> </span><span class="n">s</span><span class="w"> </span><span class="cm">/* = {} */</span><span class="w"> </span><span class="c1">// Stream on which to schedule the operation</span>
|
||||
<span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// Promote dtypes between x and y as needed</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">promoted_dtype</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">promote_types</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">dtype</span><span class="p">(),</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">dtype</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">promoted_dtype</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">promote_types</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">dtype</span><span class="p">(),</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">dtype</span><span class="p">());</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Upcast to float32 for non-floating point inputs x and y</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">out_dtype</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">is_floating_point</span><span class="p">(</span><span class="n">promoted_dtype</span><span class="p">)</span>
|
||||
<span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">promoted_dtype</span>
|
||||
<span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">promote_types</span><span class="p">(</span><span class="n">promoted_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">float32</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">out_dtype</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">is_floating_point</span><span class="p">(</span><span class="n">promoted_dtype</span><span class="p">)</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">promoted_dtype</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">promote_types</span><span class="p">(</span><span class="n">promoted_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">float32</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Cast x and y up to the determined dtype (on the same stream s)</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_casted</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">astype</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">out_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_casted</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">astype</span><span class="p">(</span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_casted</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">astype</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">out_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_casted</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">astype</span><span class="p">(</span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Broadcast the shapes of x and y (on the same stream s)</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">broadcasted_inputs</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">broadcast_arrays</span><span class="p">({</span><span class="n">x_casted</span><span class="p">,</span><span class="w"> </span><span class="n">y_casted</span><span class="p">},</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">out_shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">broadcasted_inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">shape</span><span class="p">();</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">broadcasted_inputs</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">broadcast_arrays</span><span class="p">({</span><span class="n">x_casted</span><span class="p">,</span><span class="w"> </span><span class="n">y_casted</span><span class="p">},</span><span class="w"> </span><span class="n">s</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">out_shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">broadcasted_inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">shape</span><span class="p">();</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Construct the array as the output of the Axpby primitive</span>
|
||||
<span class="w"> </span><span class="c1">// with the broadcasted and upcasted arrays as inputs</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">array</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="cm">/* const std::vector<int>& shape = */</span><span class="w"> </span><span class="n">out_shape</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="cm">/* Dtype dtype = */</span><span class="w"> </span><span class="n">out_dtype</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="cm">/* std::unique_ptr<Primitive> primitive = */</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">make_unique</span><span class="o"><</span><span class="n">Axpby</span><span class="o">></span><span class="p">(</span><span class="n">to_stream</span><span class="p">(</span><span class="n">s</span><span class="p">),</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="n">beta</span><span class="p">),</span>
|
||||
<span class="w"> </span><span class="cm">/* const std::vector<array>& inputs = */</span><span class="w"> </span><span class="n">broadcasted_inputs</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* const std::vector<int>& shape = */</span><span class="w"> </span><span class="n">out_shape</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* Dtype dtype = */</span><span class="w"> </span><span class="n">out_dtype</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* std::unique_ptr<Primitive> primitive = */</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">make_unique</span><span class="o"><</span><span class="n">Axpby</span><span class="o">></span><span class="p">(</span><span class="n">to_stream</span><span class="p">(</span><span class="n">s</span><span class="p">),</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="n">beta</span><span class="p">),</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* const std::vector<array>& inputs = */</span><span class="w"> </span><span class="n">broadcasted_inputs</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>This operation now handles the following:</p>
|
||||
@@ -900,66 +910,66 @@ of these functions to allocate memory as needed</p>
|
||||
<p>Our naive method will go over each element of the output array, find the
|
||||
corresponding input elements of <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> and perform the operation
|
||||
pointwise. This is captured in the templated function <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby_impl()</span></code>.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span><span class="w"></span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// We only allocate memory when we are ready to fill the output</span>
|
||||
<span class="w"> </span><span class="c1">// malloc_or_wait synchronously allocates available memory</span>
|
||||
<span class="w"> </span><span class="c1">// There may be a wait executed here if the allocation is requested</span>
|
||||
<span class="w"> </span><span class="c1">// under memory-pressured conditions</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Collect input and output data pointers</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">out_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">out_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Cast alpha and beta to the relevant types</span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha_</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta_</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Do the element-wise operation for each output</span>
|
||||
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="kt">size_t</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o"><</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span><span class="w"> </span><span class="n">out_idx</span><span class="o">++</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="kt">size_t</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o"><</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span><span class="w"> </span><span class="n">out_idx</span><span class="o">++</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// Map linear indices to offsets in x and y</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// We allocate the output to be contiguous and regularly strided</span>
|
||||
<span class="w"> </span><span class="c1">// (defaults to row major) and hence it doesn't need additional mapping</span>
|
||||
<span class="w"> </span><span class="n">out_ptr</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="p">}</span>
|
||||
<span class="w"> </span><span class="n">out_ptr</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Now, we would like our implementation to be able to do this pointwise operation
|
||||
for all incoming floating point arrays. Accordingly, we add dispatches for
|
||||
<code class="docutils literal notranslate"><span class="pre">float32</span></code>, <code class="docutils literal notranslate"><span class="pre">float16</span></code>, <code class="docutils literal notranslate"><span class="pre">bfloat16</span></code> and <code class="docutils literal notranslate"><span class="pre">complex64</span></code>. We throw an error
|
||||
if we encounter an unexpected type.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Fall back implementation for evaluation on CPU */</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Fall back implementation for evaluation on CPU */</span><span class="w"></span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// Check the inputs (registered in the op while constructing the out array)</span>
|
||||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Dispatch to the correct dtype</span>
|
||||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float32</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float16</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">float16_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">bfloat16</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">bfloat16_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">complex64</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">complex64_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="s">"Axpby is only supported for floating point types."</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="p">}</span>
|
||||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float32</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float16</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">float16_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">bfloat16</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">bfloat16_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">complex64</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">complex64_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="s">"Axpby is only supported for floating point types."</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We have a fallback implementation! Now, to do what we are really here to do.
|
||||
@@ -980,13 +990,13 @@ of <code class="docutils literal notranslate"><span class="pre">y</span></code>
|
||||
<p>Let’s write out an implementation that uses Accelerate in the right conditions.
|
||||
It must simply allocate data for the output, copy elements of <code class="docutils literal notranslate"><span class="pre">y</span></code> into it,
|
||||
and then call the <code class="xref py py-meth docutils literal notranslate"><span class="pre">catlas_saxpby()</span></code> from accelerate.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl_accelerate</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span><span class="w"></span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl_accelerate</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// Accelerate library provides catlas_saxpby which does</span>
|
||||
<span class="w"> </span><span class="c1">// Y = (alpha * X) + (beta * Y) in place</span>
|
||||
<span class="w"> </span><span class="c1">// To use it, we first copy the data in y over to the output array</span>
|
||||
@@ -996,54 +1006,54 @@ and then call the <code class="xref py py-meth docutils literal notranslate"><sp
|
||||
<span class="w"> </span><span class="c1">// The data in the output array is allocated to match the strides in y</span>
|
||||
<span class="w"> </span><span class="c1">// such that x, y, and out are contiguous in the same mode and</span>
|
||||
<span class="w"> </span><span class="c1">// no transposition is needed</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">data_size</span><span class="p">()</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">itemsize</span><span class="p">()),</span>
|
||||
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data_size</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">flags</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">data_size</span><span class="p">()</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">itemsize</span><span class="p">()),</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data_size</span><span class="p">(),</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">flags</span><span class="p">());</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// We then copy over the elements using the contiguous vector specialization</span>
|
||||
<span class="w"> </span><span class="n">copy_inplace</span><span class="p">(</span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">CopyType</span><span class="o">::</span><span class="n">Vector</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">copy_inplace</span><span class="p">(</span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">CopyType</span><span class="o">::</span><span class="n">Vector</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Get x and y pointers for catlas_saxpby</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">();</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha_</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta_</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Call the inplace accelerate operator</span>
|
||||
<span class="w"> </span><span class="n">catlas_saxpby</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="cm">/* N = */</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="cm">/* ALPHA = */</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="cm">/* X = */</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="cm">/* INCX = */</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="cm">/* BETA = */</span><span class="w"> </span><span class="n">beta</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="cm">/* Y = */</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="cm">/* INCY = */</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
<span class="w"> </span><span class="n">catlas_saxpby</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* N = */</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* ALPHA = */</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* X = */</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* INCX = */</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* BETA = */</span><span class="w"> </span><span class="n">beta</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* Y = */</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="cm">/* INCY = */</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Great! But what about the inputs that do not fit the criteria for accelerate?
|
||||
Luckily, we can always just direct back to <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval()</span></code>.</p>
|
||||
<p>With this in mind, lets finally implement our <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Evaluate primitive on CPU using accelerate specializations */</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_cpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Evaluate primitive on CPU using accelerate specializations */</span><span class="w"></span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_cpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Accelerate specialization for contiguous single precision float arrays</span>
|
||||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float32</span><span class="w"> </span><span class="o">&&</span>
|
||||
<span class="w"> </span><span class="p">((</span><span class="n">x</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">row_contiguous</span><span class="w"> </span><span class="o">&&</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">row_contiguous</span><span class="p">)</span><span class="w"> </span><span class="o">||</span>
|
||||
<span class="w"> </span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">col_contiguous</span><span class="w"> </span><span class="o">&&</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">col_contiguous</span><span class="p">)))</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="n">axpby_impl_accelerate</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">float32</span><span class="w"> </span><span class="o">&&</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">((</span><span class="n">x</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">row_contiguous</span><span class="w"> </span><span class="o">&&</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">row_contiguous</span><span class="p">)</span><span class="w"> </span><span class="o">||</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">col_contiguous</span><span class="w"> </span><span class="o">&&</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">flags</span><span class="p">().</span><span class="n">col_contiguous</span><span class="p">)))</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">axpby_impl_accelerate</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="p">;</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Fall back to common backend if specializations are not available</span>
|
||||
<span class="w"> </span><span class="n">eval</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
<span class="w"> </span><span class="n">eval</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We have now hit a milestone! Just this much is enough to run the operation
|
||||
@@ -1069,26 +1079,26 @@ all GPU kernels in MLX are written using metal.</p>
|
||||
as there are elements in the output. Each thread will pick the element it needs
|
||||
from <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, do the pointwise operation, and then update its assigned
|
||||
element in the output.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span>
|
||||
<span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">0</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">1</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">2</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="o">&</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">3</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="o">&</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">4</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int</span><span class="o">*</span><span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">5</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">size_t</span><span class="o">*</span><span class="w"> </span><span class="n">x_strides</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">6</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">size_t</span><span class="o">*</span><span class="w"> </span><span class="n">y_strides</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">7</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int</span><span class="o">&</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">8</span><span class="p">)]],</span>
|
||||
<span class="w"> </span><span class="n">uint</span><span class="w"> </span><span class="n">index</span><span class="w"> </span><span class="p">[[</span><span class="n">thread_position_in_grid</span><span class="p">]])</span><span class="w"> </span><span class="p">{</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span><span class="w"></span>
|
||||
<span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">0</span><span class="p">)]],</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">1</span><span class="p">)]],</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">2</span><span class="p">)]],</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="o">&</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">3</span><span class="p">)]],</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="o">&</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">4</span><span class="p">)]],</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int</span><span class="o">*</span><span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">5</span><span class="p">)]],</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">size_t</span><span class="o">*</span><span class="w"> </span><span class="n">x_strides</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">6</span><span class="p">)]],</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">size_t</span><span class="o">*</span><span class="w"> </span><span class="n">y_strides</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">7</span><span class="p">)]],</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int</span><span class="o">&</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">8</span><span class="p">)]],</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">uint</span><span class="w"> </span><span class="n">index</span><span class="w"> </span><span class="p">[[</span><span class="n">thread_position_in_grid</span><span class="p">]])</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// Convert linear indices to offsets in array</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">index</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">x_strides</span><span class="p">,</span><span class="w"> </span><span class="n">ndim</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">index</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">y_strides</span><span class="p">,</span><span class="w"> </span><span class="n">ndim</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">index</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">x_strides</span><span class="p">,</span><span class="w"> </span><span class="n">ndim</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">index</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">y_strides</span><span class="p">,</span><span class="w"> </span><span class="n">ndim</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Do the operation and update the output</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">[</span><span class="n">index</span><span class="p">]</span><span class="w"> </span><span class="o">=</span>
|
||||
<span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha</span><span class="p">)</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta</span><span class="p">)</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span>
|
||||
<span class="p">}</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">[</span><span class="n">index</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha</span><span class="p">)</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta</span><span class="p">)</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We then need to instantiate this template for all floating point types and give
|
||||
@@ -1108,10 +1118,10 @@ each data type.</p>
|
||||
<span class="cp"> constant const int& ndim [[buffer(8)]], \</span>
|
||||
<span class="cp"> uint index [[thread_position_in_grid]]);</span>
|
||||
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">float32</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">float16</span><span class="p">,</span><span class="w"> </span><span class="n">half</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">bfloat16</span><span class="p">,</span><span class="w"> </span><span class="n">bfloat16_t</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">complex64</span><span class="p">,</span><span class="w"> </span><span class="n">complex64_t</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">float32</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">float16</span><span class="p">,</span><span class="w"> </span><span class="n">half</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">bfloat16</span><span class="p">,</span><span class="w"> </span><span class="n">bfloat16_t</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">complex64</span><span class="p">,</span><span class="w"> </span><span class="n">complex64_t</span><span class="p">);</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>This kernel will be compiled into a metal library <code class="docutils literal notranslate"><span class="pre">mlx_ext.metallib</span></code> as we
|
||||
@@ -1127,73 +1137,73 @@ go over this process in more detail later.</p>
|
||||
<p>The logic to determine the kernel, set the inputs, resolve the grid dimensions
|
||||
and dispatch it to the GPU are contained in <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> as shown
|
||||
below.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Evaluate primitive on GPU */</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_gpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Evaluate primitive on GPU */</span><span class="w"></span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_gpu</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// Prepare inputs</span>
|
||||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Each primitive carries the stream it should execute on</span>
|
||||
<span class="w"> </span><span class="c1">// and each stream carries its device identifiers</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">s</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">stream</span><span class="p">();</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">s</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">stream</span><span class="p">();</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// We get the needed metal device using the stream</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">d</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">metal</span><span class="o">::</span><span class="n">device</span><span class="p">(</span><span class="n">s</span><span class="p">.</span><span class="n">device</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">d</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">metal</span><span class="o">::</span><span class="n">device</span><span class="p">(</span><span class="n">s</span><span class="p">.</span><span class="n">device</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Allocate output memory</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
|
||||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc_or_wait</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Resolve name of kernel (corresponds to axpby.metal)</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">ostringstream</span><span class="w"> </span><span class="n">kname</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="n">kname</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"axpby_"</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"general_"</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="n">type_to_name</span><span class="p">(</span><span class="n">out</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">ostringstream</span><span class="w"> </span><span class="n">kname</span><span class="p">;</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">kname</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"axpby_"</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"general_"</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="n">type_to_name</span><span class="p">(</span><span class="n">out</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Make sure the metal library is available and look for it</span>
|
||||
<span class="w"> </span><span class="c1">// in the same folder as this executable if needed</span>
|
||||
<span class="w"> </span><span class="n">d</span><span class="p">.</span><span class="n">register_library</span><span class="p">(</span><span class="s">"mlx_ext"</span><span class="p">,</span><span class="w"> </span><span class="n">metal</span><span class="o">::</span><span class="n">get_colocated_mtllib_path</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">d</span><span class="p">.</span><span class="n">register_library</span><span class="p">(</span><span class="s">"mlx_ext"</span><span class="p">,</span><span class="w"> </span><span class="n">metal</span><span class="o">::</span><span class="n">get_colocated_mtllib_path</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Make a kernel from this metal library</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">kernel</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">d</span><span class="p">.</span><span class="n">get_kernel</span><span class="p">(</span><span class="n">kname</span><span class="p">.</span><span class="n">str</span><span class="p">(),</span><span class="w"> </span><span class="s">"mlx_ext"</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">kernel</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">d</span><span class="p">.</span><span class="n">get_kernel</span><span class="p">(</span><span class="n">kname</span><span class="p">.</span><span class="n">str</span><span class="p">(),</span><span class="w"> </span><span class="s">"mlx_ext"</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Prepare to encode kernel</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">compute_encoder</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">d</span><span class="p">.</span><span class="n">get_command_encoder</span><span class="p">(</span><span class="n">s</span><span class="p">.</span><span class="n">index</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setComputePipelineState</span><span class="p">(</span><span class="n">kernel</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">compute_encoder</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">d</span><span class="p">.</span><span class="n">get_command_encoder</span><span class="p">(</span><span class="n">s</span><span class="p">.</span><span class="n">index</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setComputePipelineState</span><span class="p">(</span><span class="n">kernel</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Kernel parameters are registered with buffer indices corresponding to</span>
|
||||
<span class="w"> </span><span class="c1">// those in the kernel declaration at axpby.metal</span>
|
||||
<span class="w"> </span><span class="kt">int</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">ndim</span><span class="p">();</span>
|
||||
<span class="w"> </span><span class="kt">size_t</span><span class="w"> </span><span class="n">nelem</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span>
|
||||
<span class="w"> </span><span class="kt">int</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">ndim</span><span class="p">();</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="kt">size_t</span><span class="w"> </span><span class="n">nelem</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Encode input arrays to kernel</span>
|
||||
<span class="w"> </span><span class="n">set_array_buffer</span><span class="p">(</span><span class="n">compute_encoder</span><span class="p">,</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="mi">0</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">set_array_buffer</span><span class="p">(</span><span class="n">compute_encoder</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">set_array_buffer</span><span class="p">(</span><span class="n">compute_encoder</span><span class="p">,</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="mi">0</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">set_array_buffer</span><span class="p">(</span><span class="n">compute_encoder</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Encode output arrays to kernel</span>
|
||||
<span class="w"> </span><span class="n">set_array_buffer</span><span class="p">(</span><span class="n">compute_encoder</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">set_array_buffer</span><span class="p">(</span><span class="n">compute_encoder</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Encode alpha and beta</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="o">&</span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">),</span><span class="w"> </span><span class="mi">3</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="o">&</span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">),</span><span class="w"> </span><span class="mi">4</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="o">&</span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">),</span><span class="w"> </span><span class="mi">3</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="o">&</span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">),</span><span class="w"> </span><span class="mi">4</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Encode shape, strides and ndim</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">().</span><span class="n">data</span><span class="p">(),</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">int</span><span class="p">),</span><span class="w"> </span><span class="mi">5</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">().</span><span class="n">data</span><span class="p">(),</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">size_t</span><span class="p">),</span><span class="w"> </span><span class="mi">6</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">().</span><span class="n">data</span><span class="p">(),</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">size_t</span><span class="p">),</span><span class="w"> </span><span class="mi">7</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="o">&</span><span class="n">ndim</span><span class="p">,</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">int</span><span class="p">),</span><span class="w"> </span><span class="mi">8</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">().</span><span class="n">data</span><span class="p">(),</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">int</span><span class="p">),</span><span class="w"> </span><span class="mi">5</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">().</span><span class="n">data</span><span class="p">(),</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">size_t</span><span class="p">),</span><span class="w"> </span><span class="mi">6</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">().</span><span class="n">data</span><span class="p">(),</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">size_t</span><span class="p">),</span><span class="w"> </span><span class="mi">7</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setBytes</span><span class="p">(</span><span class="o">&</span><span class="n">ndim</span><span class="p">,</span><span class="w"> </span><span class="k">sizeof</span><span class="p">(</span><span class="kt">int</span><span class="p">),</span><span class="w"> </span><span class="mi">8</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// We launch 1 thread for each input and make sure that the number of</span>
|
||||
<span class="w"> </span><span class="c1">// threads in any given threadgroup is not higher than the max allowed</span>
|
||||
<span class="w"> </span><span class="kt">size_t</span><span class="w"> </span><span class="n">tgp_size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">min</span><span class="p">(</span><span class="n">nelem</span><span class="p">,</span><span class="w"> </span><span class="n">kernel</span><span class="o">-></span><span class="n">maxTotalThreadsPerThreadgroup</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="kt">size_t</span><span class="w"> </span><span class="n">tgp_size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">min</span><span class="p">(</span><span class="n">nelem</span><span class="p">,</span><span class="w"> </span><span class="n">kernel</span><span class="o">-></span><span class="n">maxTotalThreadsPerThreadgroup</span><span class="p">());</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Fix the 3D size of each threadgroup (in terms of threads)</span>
|
||||
<span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="w"> </span><span class="n">group_dims</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="p">(</span><span class="n">tgp_size</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="w"> </span><span class="n">group_dims</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="p">(</span><span class="n">tgp_size</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Fix the 3D size of the launch grid (in terms of threads)</span>
|
||||
<span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="w"> </span><span class="n">grid_dims</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="p">(</span><span class="n">nelem</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="w"> </span><span class="n">grid_dims</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="p">(</span><span class="n">nelem</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Launch the grid with the given number of threads divided among</span>
|
||||
<span class="w"> </span><span class="c1">// the given threadgroups</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">dispatchThreads</span><span class="p">(</span><span class="n">grid_dims</span><span class="p">,</span><span class="w"> </span><span class="n">group_dims</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">dispatchThreads</span><span class="p">(</span><span class="n">grid_dims</span><span class="p">,</span><span class="w"> </span><span class="n">group_dims</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We can now call the <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> operation on both the CPU and the GPU!</p>
|
||||
@@ -1213,11 +1223,11 @@ command buffers as needed. We suggest taking a deeper dive into
|
||||
transformations in a <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>. These transformations can be built on
|
||||
top of our operations, including the one we just defined now. Which then gives
|
||||
us the following <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::jvp()</span></code> and <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vjp()</span></code> implementations.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** The Jacobian-vector product. */</span>
|
||||
<span class="n">array</span><span class="w"> </span><span class="nf">Axpby::jvp</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** The Jacobian-vector product. */</span><span class="w"></span>
|
||||
<span class="n">array</span><span class="w"> </span><span class="nf">Axpby::jvp</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// Forward mode diff that pushes along the tangents</span>
|
||||
<span class="w"> </span><span class="c1">// The jvp transform on the primitive can built with ops</span>
|
||||
<span class="w"> </span><span class="c1">// that are scheduled on the same stream as the primitive</span>
|
||||
@@ -1226,43 +1236,43 @@ us the following <code class="xref py py-meth docutils literal notranslate"><spa
|
||||
<span class="w"> </span><span class="c1">// jvp is just the tangent scaled by alpha</span>
|
||||
<span class="w"> </span><span class="c1">// Similarly, if argnums = {1}, the jvp is just the tangent</span>
|
||||
<span class="w"> </span><span class="c1">// scaled by beta</span>
|
||||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">argnums</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">></span><span class="w"> </span><span class="mi">1</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">argnums</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">alpha_</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dtype</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">argnums</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">></span><span class="w"> </span><span class="mi">1</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">argnums</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">alpha_</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dtype</span><span class="p">());</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// If, argnums = {0, 1}, we take contributions from both</span>
|
||||
<span class="w"> </span><span class="c1">// which gives us jvp = tangent_x * alpha + tangent_y * beta</span>
|
||||
<span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby</span><span class="p">(</span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="p">}</span>
|
||||
<span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby</span><span class="p">(</span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** The vector-Jacobian product. */</span>
|
||||
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vjp</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** The vector-Jacobian product. */</span><span class="w"></span>
|
||||
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vjp</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="c1">// Reverse mode diff</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">vjps</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="k">auto</span><span class="w"> </span><span class="n">arg</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">arg</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">alpha_</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">cotan</span><span class="p">.</span><span class="n">dtype</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="n">vjps</span><span class="p">.</span><span class="n">push_back</span><span class="p">(</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">()));</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">vjps</span><span class="p">;</span>
|
||||
<span class="p">}</span>
|
||||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">vjps</span><span class="p">;</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="k">auto</span><span class="w"> </span><span class="n">arg</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">arg</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">alpha_</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">cotan</span><span class="p">.</span><span class="n">dtype</span><span class="p">());</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">vjps</span><span class="p">.</span><span class="n">push_back</span><span class="p">(</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">cotan</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">()));</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="p">}</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">vjps</span><span class="p">;</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Finally, you need not have a transformation fully defined to start using your
|
||||
own <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Vectorize primitive along given axis */</span>
|
||||
<span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">array</span><span class="p">,</span><span class="w"> </span><span class="kt">int</span><span class="o">></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vmap</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span><span class="s">"Axpby has no vmap implementation."</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Vectorize primitive along given axis */</span><span class="w"></span>
|
||||
<span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">array</span><span class="p">,</span><span class="w"> </span><span class="kt">int</span><span class="o">></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vmap</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span><span class="s">"Axpby has no vmap implementation."</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</section>
|
||||
@@ -1297,20 +1307,20 @@ the python package</p></li>
|
||||
<p>We use <a class="reference external" href="https://pybind11.readthedocs.io/en/stable/">PyBind11</a> to build a Python API for the C++ library. Since bindings
|
||||
for all needed components such as <cite>mlx.core.array</cite>, <cite>mlx.core.stream</cite>, etc.
|
||||
are already provided, adding our <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> becomes very simple!</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">PYBIND11_MODULE</span><span class="p">(</span><span class="n">mlx_sample_extensions</span><span class="p">,</span><span class="w"> </span><span class="n">m</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">doc</span><span class="p">()</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s">"Sample C++ and metal extensions for MLX"</span><span class="p">;</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">PYBIND11_MODULE</span><span class="p">(</span><span class="n">mlx_sample_extensions</span><span class="p">,</span><span class="w"> </span><span class="n">m</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">doc</span><span class="p">()</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s">"Sample C++ and metal extensions for MLX"</span><span class="p">;</span><span class="w"></span>
|
||||
|
||||
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="s">"axpby"</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="o">&</span><span class="n">axpby</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="s">"x"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="s">"y"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">pos_only</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="s">"alpha"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="s">"beta"</span><span class="n">_a</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">kw_only</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="s">"stream"</span><span class="n">_a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">none</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="sa">R</span><span class="s">"</span><span class="dl">pbdoc(</span>
|
||||
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="s">"axpby"</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="o">&</span><span class="n">axpby</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="s">"x"</span><span class="n">_a</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="s">"y"</span><span class="n">_a</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">pos_only</span><span class="p">(),</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="s">"alpha"</span><span class="n">_a</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="s">"beta"</span><span class="n">_a</span><span class="p">,</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">kw_only</span><span class="p">(),</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="s">"stream"</span><span class="n">_a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">none</span><span class="p">(),</span><span class="w"></span>
|
||||
<span class="w"> </span><span class="sa">R</span><span class="s">"</span><span class="dl">pbdoc(</span><span class="s"></span>
|
||||
<span class="s"> Scale and sum two vectors element-wise</span>
|
||||
<span class="s"> ``z = alpha * x + beta * y``</span>
|
||||
|
||||
@@ -1325,8 +1335,8 @@ are already provided, adding our <code class="xref py py-meth docutils literal n
|
||||
|
||||
<span class="s"> Returns:</span>
|
||||
<span class="s"> array: ``alpha * x + beta * y``</span>
|
||||
<span class="s"> </span><span class="dl">)pbdoc</span><span class="s">"</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
<span class="s"> </span><span class="dl">)pbdoc</span><span class="s">"</span><span class="p">);</span><span class="w"></span>
|
||||
<span class="p">}</span><span class="w"></span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Most of the complexity in the above example comes from additional bells and
|
||||
@@ -1463,7 +1473,7 @@ import the python package and play with it as you would any other MLX operation!
|
||||
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c shape: </span><span class="si">{</span><span class="n">c</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c dtype: </span><span class="si">{</span><span class="n">c</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c correctness: </span><span class="si">{</span><span class="n">mx</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">c</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mf">6.0</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c correctness: </span><span class="si">{</span><span class="n">mx</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">c</span> <span class="o">==</span> <span class="mf">6.0</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Output:</p>
|
||||
|
Reference in New Issue
Block a user