docs update

Awni Hannun
2024-01-03 20:14:05 -08:00
committed by CircleCI Docs
parent 8bea0a4eb8
commit d03b91923e
330 changed files with 40495 additions and 3726 deletions


@@ -9,7 +9,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<title>Developer Documentation &#8212; MLX 0.0.6 documentation</title>
<title>Developer Documentation &#8212; MLX 0.0.7 documentation</title>
@@ -133,8 +133,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.6 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.6 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.7 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.7 documentation - Home"/>`);</script>
</a></div>
@@ -277,12 +277,14 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.quantize.html">mlx.core.quantize</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.quantized_matmul.html">mlx.core.quantized_matmul</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reciprocal.html">mlx.core.reciprocal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.repeat.html">mlx.core.repeat</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reshape.html">mlx.core.reshape</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.round.html">mlx.core.round</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.rsqrt.html">mlx.core.rsqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save.html">mlx.core.save</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez.html">mlx.core.savez</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez_compressed.html">mlx.core.savez_compressed</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save_safetensors.html">mlx.core.save_safetensors</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sigmoid.html">mlx.core.sigmoid</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
@@ -302,6 +304,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.take_along_axis.html">mlx.core.take_along_axis</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tan.html">mlx.core.tan</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tanh.html">mlx.core.tanh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tensordot.html">mlx.core.tensordot</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.transpose.html">mlx.core.transpose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
@@ -350,11 +353,35 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfftn.html">mlx.core.fft.irfftn</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-7"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/linalg.html">Linear Algebra</a><input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-7"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.norm.html">mlx.core.linalg.norm</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-8"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.Module.html">mlx.nn.Module</a></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/layers.html">Layers</a><input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-8"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/module.html">Module</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.training.html">mlx.nn.Module.training</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply.html">mlx.nn.Module.apply</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply_to_modules.html">mlx.nn.Module.apply_to_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.children.html">mlx.nn.Module.children</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.eval.html">mlx.nn.Module.eval</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.filter_and_map.html">mlx.nn.Module.filter_and_map</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.freeze.html">mlx.nn.Module.freeze</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.leaf_modules.html">mlx.nn.Module.leaf_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.load_weights.html">mlx.nn.Module.load_weights</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.modules.html">mlx.nn.Module.modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.named_modules.html">mlx.nn.Module.named_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.parameters.html">mlx.nn.Module.parameters</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.save_weights.html">mlx.nn.Module.save_weights</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.train.html">mlx.nn.Module.train</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.trainable_parameters.html">mlx.nn.Module.trainable_parameters</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.unfreeze.html">mlx.nn.Module.unfreeze</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.update.html">mlx.nn.Module.update</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.update_modules.html">mlx.nn.Module.update_modules</a></li>
</ul>
</li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/layers.html">Layers</a><input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-10"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
@@ -362,19 +389,27 @@
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.BatchNorm.html">mlx.nn.BatchNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.InstanceNorm.html">mlx.nn.InstanceNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout.html">mlx.nn.Dropout</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout2d.html">mlx.nn.Dropout2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout3d.html">mlx.nn.Dropout3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Transformer.html">mlx.nn.Transformer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ALiBi.html">mlx.nn.ALiBi</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html">mlx.nn.SinusoidalPositionalEncoding</a></li>
</ul>
</li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/functions.html">Functions</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/functions.html">Functions</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
@@ -386,7 +421,7 @@
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
</ul>
</li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/losses.html">Loss Functions</a><input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-10"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/losses.html">Loss Functions</a><input class="toctree-checkbox" id="toctree-checkbox-12" name="toctree-checkbox-12" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-12"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html">mlx.nn.losses.binary_cross_entropy</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html">mlx.nn.losses.cross_entropy</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
@@ -395,11 +430,14 @@
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html">mlx.nn.losses.smooth_l1_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html">mlx.nn.losses.triplet_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html">mlx.nn.losses.hinge_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.html">mlx.nn.losses.huber_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/optimizers.html">Optimizers</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/optimizers.html">Optimizers</a><input class="toctree-checkbox" id="toctree-checkbox-13" name="toctree-checkbox-13" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-13"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.OptimizerState.html">mlx.optimizers.OptimizerState</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.html">mlx.optimizers.Optimizer</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.SGD.html">mlx.optimizers.SGD</a></li>
@@ -412,7 +450,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Lion.html">mlx.optimizers.Lion</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-12" name="toctree-checkbox-12" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-12"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-14" name="toctree-checkbox-14" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-14"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_flatten.html">mlx.utils.tree_flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_unflatten.html">mlx.utils.tree_unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map.html">mlx.utils.tree_map</a></li>
@@ -638,7 +676,7 @@ and GPU implementations.</p>
<section id="introducing-the-example">
<h2>Introducing the Example<a class="headerlink" href="#introducing-the-example" title="Permalink to this heading">#</a></h2>
<p>Let's say that you would like an operation that takes in two arrays,
<code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, scales them both by some coefficents <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>
<code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, scales them both by some coefficients <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>
respectively, and then adds them together to get the result
<code class="docutils literal notranslate"><span class="pre">z</span> <span class="pre">=</span> <span class="pre">alpha</span> <span class="pre">*</span> <span class="pre">x</span> <span class="pre">+</span> <span class="pre">beta</span> <span class="pre">*</span> <span class="pre">y</span></code>. Well, you can very easily do that by just
writing out a function as follows:</p>
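The composed function described above is just a scaled sum. As a pure-Python sketch (the real version would compose MLX array ops such as `mx.multiply` and `mx.add`; the name `simple_axpby` is illustrative, not from the source):

```python
def simple_axpby(x, y, alpha, beta):
    """Compute z = alpha * x + beta * y element-wise.

    Pure-Python stand-in for the composed MLX version, which would
    build the same result out of existing multiply/add ops.
    """
    return [alpha * xi + beta * yi for xi, yi in zip(x, y)]

print(simple_axpby([1.0, 2.0], [3.0, 4.0], alpha=2.0, beta=0.5))
```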
@@ -682,7 +720,7 @@ operations in the Python API (<a class="reference internal" href="../python/ops.
and two scalars, <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>. This is how we would define it in the
C++ API:</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/**</span>
<span class="cm">* Scale and sum two vectors elementwise</span>
<span class="cm">* Scale and sum two vectors element-wise</span>
<span class="cm">* z = alpha * x + beta * y</span>
<span class="cm">*</span>
<span class="cm">* Follow numpy style broadcasting between x and y</span>
@@ -833,7 +871,7 @@ data type, shape, the <code class="xref py py-class docutils literal notranslate
</div>
<p>This operation now handles the following:</p>
<ol class="arabic simple">
<li><p>Upcast inputs and resolve the the output data type.</p></li>
<li><p>Upcast inputs and resolve the output data type.</p></li>
<li><p>Broadcast the inputs and resolve the output shape.</p></li>
<li><p>Construct the primitive <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> using the given stream, <code class="docutils literal notranslate"><span class="pre">alpha</span></code>, and <code class="docutils literal notranslate"><span class="pre">beta</span></code>.</p></li>
<li><p>Construct the output <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> using the primitive and the inputs.</p></li>
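Step 2 above resolves the output shape with NumPy-style broadcasting: trailing dimensions are aligned, and each pair must be equal or contain a 1. A minimal sketch of that rule (a hypothetical helper, not the MLX implementation):

```python
from itertools import zip_longest

def broadcast_shapes(a, b):
    """NumPy-style broadcast of two shapes: align dimensions from the
    right; each aligned pair must be equal or contain a 1, and the
    output takes the larger extent."""
    out = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {a} and {b} do not broadcast")
        out.append(max(da, db))
    return tuple(reversed(out))

print(broadcast_shapes((3, 1), (1, 4)))  # (3, 4)
```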
@@ -883,14 +921,14 @@ pointwise. This is captured in the templated function <code class="xref py py-me
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">(</span><span class="n">alpha_</span><span class="p">);</span>
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">(</span><span class="n">beta_</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Do the elementwise operation for each output</span>
<span class="w"> </span><span class="c1">// Do the element-wise operation for each output</span>
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="kt">size_t</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">&lt;</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span><span class="w"> </span><span class="n">out_idx</span><span class="o">++</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Map linear indices to offsets in x and y</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span>
<span class="w"> </span><span class="c1">// We allocate the output to be contiguous and regularly strided</span>
<span class="w"> </span><span class="c1">// (defaults to row major) and hence it doesn&#39;t need additonal mapping</span>
<span class="w"> </span><span class="c1">// (defaults to row major) and hence it doesn&#39;t need additional mapping</span>
<span class="w"> </span><span class="n">out_ptr</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span>
<span class="w"> </span><span class="p">}</span>
<span class="p">}</span>
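The key trick in the loop above is `elem_to_loc`, which maps a row-major linear output index to a memory offset for an input with arbitrary strides; a broadcast dimension simply carries stride 0, so the same source element is read repeatedly. A pure-Python sketch of that mapping:

```python
def elem_to_loc(index, shape, strides):
    """Map a row-major linear index to a memory offset for an array
    with the given shape and per-dimension (element) strides.

    Peeling digits off the linear index from the innermost dimension
    outward recovers the multi-dimensional coordinate, and each
    coordinate is weighted by that dimension's stride."""
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (index % dim) * stride
        index //= dim
    return loc
```

For a contiguous (2, 3) array the strides are (3, 1) and the mapping is the identity; a row broadcast over the first axis has strides (0, 1), so every row reads the same three elements.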
@@ -902,7 +940,7 @@ for all incoming floating point arrays. Accordingly, we add dispatches for
if we encounter an unexpected type.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Fall back implementation for evaluation on CPU */</span>
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&amp;</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Check the inputs (registered in the op while contructing the out array)</span>
<span class="w"> </span><span class="c1">// Check the inputs (registered in the op while constructing the out array)</span>
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
<span class="w"> </span><span class="k">auto</span><span class="o">&amp;</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
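The fallback eval dispatches on the output dtype, calling the templated implementation for each supported floating point type and throwing otherwise. The shape of that dispatch, sketched in Python (names and the simplified single-implementation body are illustrative, not the C++ source):

```python
SUPPORTED_DTYPES = {"float32", "float16", "bfloat16", "complex64"}

def eval_axpby_fallback(dtype, x, y, alpha, beta):
    """Stand-in for the dtype switch in Axpby::eval: supported types
    run the element-wise kernel; anything else is an error."""
    if dtype not in SUPPORTED_DTYPES:
        raise TypeError(
            f"Axpby is only supported for floating point types, got {dtype}")
    return [alpha * xi + beta * yi for xi, yi in zip(x, y)]
```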
@@ -1071,7 +1109,7 @@ each data type.</p>
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">float32</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="p">);</span>
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">float16</span><span class="p">,</span><span class="w"> </span><span class="n">half</span><span class="p">);</span>
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">bflot16</span><span class="p">,</span><span class="w"> </span><span class="n">bfloat16_t</span><span class="p">);</span>
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">bfloat16</span><span class="p">,</span><span class="w"> </span><span class="n">bfloat16_t</span><span class="p">);</span>
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">complex64</span><span class="p">,</span><span class="w"> </span><span class="n">complex64_t</span><span class="p">);</span>
</pre></div>
</div>
@@ -1120,7 +1158,7 @@ below.</p>
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-&gt;</span><span class="n">setComputePipelineState</span><span class="p">(</span><span class="n">kernel</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Kernel parameters are registered with buffer indices corresponding to</span>
<span class="w"> </span><span class="c1">// those in the kernel decelaration at axpby.metal</span>
<span class="w"> </span><span class="c1">// those in the kernel declaration at axpby.metal</span>
<span class="w"> </span><span class="kt">int</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">ndim</span><span class="p">();</span>
<span class="w"> </span><span class="kt">size_t</span><span class="w"> </span><span class="n">nelem</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span>
@@ -1151,7 +1189,7 @@ below.</p>
<span class="w"> </span><span class="c1">// Fix the 3D size of the launch grid (in terms of threads)</span>
<span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="w"> </span><span class="n">grid_dims</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="p">(</span><span class="n">nelem</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Launch the grid with the given number of threads divded among</span>
<span class="w"> </span><span class="c1">// Launch the grid with the given number of threads divided among</span>
<span class="w"> </span><span class="c1">// the given threadgroups</span>
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-&gt;</span><span class="n">dispatchThreads</span><span class="p">(</span><span class="n">grid_dims</span><span class="p">,</span><span class="w"> </span><span class="n">group_dims</span><span class="p">);</span>
<span class="p">}</span>
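For a simple element-wise kernel like this, the launch geometry is one thread per output element, with the threadgroup size capped by the pipeline's limit (Metal's `maxTotalThreadsPerThreadgroup`). A sketch of that choice (a hypothetical helper, not MLX code):

```python
def launch_geometry(nelem, max_threads_per_group):
    """Pick a 1-D grid and threadgroup size for an element-wise kernel:
    the grid covers every output element, and the threadgroup uses as
    many threads as the pipeline allows (or fewer for tiny outputs)."""
    group = min(nelem, max_threads_per_group)
    return (nelem, 1, 1), (group, 1, 1)
```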
@@ -1164,7 +1202,7 @@ to give us the active metal compute command encoder instead of building a
new one and calling <code class="xref py py-meth docutils literal notranslate"><span class="pre">compute_encoder-&gt;end_encoding()</span></code> at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and commiting the associated
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
<code class="xref py py-class docutils literal notranslate"><span class="pre">metal::Device</span></code> if you would like to study this routine further.</p>
</section>
@@ -1180,8 +1218,8 @@ us the following <code class="xref py py-meth docutils literal notranslate"><spa
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Forward mode diff that pushes along the tangents</span>
<span class="w"> </span><span class="c1">// The jvp transform on the the primitive can built with ops</span>
<span class="w"> </span><span class="c1">// that are scheduled on the same stream as the primtive</span>
<span class="w"> </span><span class="c1">// The jvp transform on the primitive can built with ops</span>
<span class="w"> </span><span class="c1">// that are scheduled on the same stream as the primitive</span>
<span class="w"> </span><span class="c1">// If argnums = {0}, we only push along x in which case the</span>
<span class="w"> </span><span class="c1">// jvp is just the tangent scaled by alpha</span>
@@ -1218,7 +1256,7 @@ us the following <code class="xref py py-meth docutils literal notranslate"><spa
</div>
<p>Finally, you need not have a transformation fully defined to start using your
own <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Vectorize primitve along given axis */</span>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Vectorize primitive along given axis */</span>
<span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="n">array</span><span class="p">,</span><span class="w"> </span><span class="kt">int</span><span class="o">&gt;</span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vmap</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">array</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
@@ -1245,7 +1283,7 @@ own <code class="xref py py-class docutils literal notranslate"><span class="pre
</div>
<ul class="simple">
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/axpby/</span></code> defines the C++ extension library</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> sets out the strucutre for the
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> sets out the structure for the
associated python package</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/bindings.cpp</span></code> provides python bindings for our operation</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/CMakeLists.txt</span></code> holds CMake rules to build the library and
@@ -1272,7 +1310,7 @@ are already provided, adding our <code class="xref py py-meth docutils literal n
<span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">kw_only</span><span class="p">(),</span>
<span class="w"> </span><span class="s">&quot;stream&quot;</span><span class="n">_a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">none</span><span class="p">(),</span>
<span class="w"> </span><span class="sa">R</span><span class="s">&quot;</span><span class="dl">pbdoc(</span>
<span class="s"> Scale and sum two vectors elementwise</span>
<span class="s"> Scale and sum two vectors element-wise</span>
<span class="s"> ``z = alpha * x + beta * y``</span>
<span class="s"> Follows numpy style broadcasting between ``x`` and ``y``</span>
@@ -1405,7 +1443,7 @@ bindings and copied together if the package is installed</p></li>
<div class="line"></div>
</div>
<p>When you try to install using the command <code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">-m</span> <span class="pre">pip</span> <span class="pre">install</span> <span class="pre">.</span></code>
(in <code class="docutils literal notranslate"><span class="pre">extensions/</span></code>), the package will be installed with the same strucutre as
(in <code class="docutils literal notranslate"><span class="pre">extensions/</span></code>), the package will be installed with the same structure as
<code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> and the C++ and metal library will be
copied along with the python binding since they are specified as <code class="docutils literal notranslate"><span class="pre">package_data</span></code>.</p>
</section>
@@ -1482,7 +1520,7 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
</div>
<p>We see some modest improvements right away!</p>
<p>This operation can now be used to build other operations,
in <a class="reference internal" href="../python/_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and also as a part of graph
in <a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and also as a part of graph
transformations such as <code class="xref py py-meth docutils literal notranslate"><span class="pre">grad()</span></code> and <code class="xref py py-meth docutils literal notranslate"><span class="pre">simplify()</span></code>!</p>
</section>
</section>