This commit is contained in:
CircleCI Docs
2024-11-05 19:54:16 +00:00
parent 3addf172d9
commit 98e590e52d
51 changed files with 2277 additions and 1802 deletions

View File

@@ -986,13 +986,13 @@ We will prioritize including it.</p>
<span class="n">ys</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">4096</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">naive_add</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">):</span>
<span class="k">return</span> <span class="p">[</span><span class="n">xs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">ys</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">xs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])]</span>
<span class="k">return</span> <span class="p">[</span><span class="n">xs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">ys</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">xs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
</pre></div>
</div>
<p>Instead you can use <a class="reference internal" href="../python/_autosummary/mlx.core.vmap.html#mlx.core.vmap" title="mlx.core.vmap"><code class="xref py py-func docutils literal notranslate"><span class="pre">vmap()</span></code></a> to automatically vectorize the addition:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Vectorize over the second dimension of x and the</span>
<span class="c1"># first dimension of y</span>
<span class="n">vmap_add</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="n">vmap_add</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
</pre></div>
</div>
<p>The <code class="docutils literal notranslate"><span class="pre">in_axes</span></code> parameter can be used to specify which dimensions of the