mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
rebase
This commit is contained in:
@@ -986,13 +986,13 @@ We will prioritize including it.</p>
|
||||
<span class="n">ys</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">4096</span><span class="p">))</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">naive_add</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="p">[</span><span class="n">xs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">ys</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">xs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])]</span>
|
||||
<span class="k">return</span> <span class="p">[</span><span class="n">xs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">ys</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">xs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Instead you can use <a class="reference internal" href="../python/_autosummary/mlx.core.vmap.html#mlx.core.vmap" title="mlx.core.vmap"><code class="xref py py-func docutils literal notranslate"><span class="pre">vmap()</span></code></a> to automatically vectorize the addition:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Vectorize over the second dimension of x and the</span>
|
||||
<span class="c1"># first dimension of y</span>
|
||||
<span class="n">vmap_add</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
|
||||
<span class="n">vmap_add</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>The <code class="docutils literal notranslate"><span class="pre">in_axes</span></code> parameter can be used to specify which dimensions of the
|
||||
|
2
docs/build/html/usage/indexing.html
vendored
2
docs/build/html/usage/indexing.html
vendored
@@ -922,7 +922,7 @@ undefined behavior.</p></li>
|
||||
from the GPU. Performing bounds checking for array indices before launching the
|
||||
kernel would be extremely inefficient.</p>
|
||||
<p>Indexing with boolean masks is something that MLX may support in the future. In
|
||||
general, MLX has limited support for operations for which outputs
|
||||
general, MLX has limited support for operations for which output
|
||||
<em>shapes</em> are dependent on input <em>data</em>. Other examples of these types of
|
||||
operations which MLX does not yet support include <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.nonzero.html#numpy.nonzero" title="(in NumPy v2.1)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.nonzero()</span></code></a> and the
|
||||
single input version of <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.where.html#numpy.where" title="(in NumPy v2.1)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.where()</span></code></a>.</p>
|
||||
|
2
docs/build/html/usage/lazy_evaluation.html
vendored
2
docs/build/html/usage/lazy_evaluation.html
vendored
@@ -952,7 +952,7 @@ stochastic gradient descent). A natural and usually efficient place to use
|
||||
</div>
|
||||
<p>An important behavior to be aware of is when the graph will be implicitly
|
||||
evaluated. Anytime you <code class="docutils literal notranslate"><span class="pre">print</span></code> an array, convert it to an
|
||||
<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.1)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a>, or otherwise access it’s memory via <a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#memoryview" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">memoryview</span></code></a>,
|
||||
<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.1)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a>, or otherwise access its memory via <a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#memoryview" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">memoryview</span></code></a>,
|
||||
the graph will be evaluated. Saving arrays via <a class="reference internal" href="../python/_autosummary/mlx.core.save.html#mlx.core.save" title="mlx.core.save"><code class="xref py py-func docutils literal notranslate"><span class="pre">save()</span></code></a> (or any other MLX
|
||||
saving functions) will also evaluate the array.</p>
|
||||
<p>Calling <a class="reference internal" href="../python/_autosummary/mlx.core.array.item.html#mlx.core.array.item" title="mlx.core.array.item"><code class="xref py py-func docutils literal notranslate"><span class="pre">array.item()</span></code></a> on a scalar array will also evaluate it. In the
|
||||
|
Reference in New Issue
Block a user