mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
docs update
This commit is contained in:

committed by
CircleCI Docs

parent
8bea0a4eb8
commit
d03b91923e
100
docs/build/html/dev/extensions.html
vendored
100
docs/build/html/dev/extensions.html
vendored
@@ -9,7 +9,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
|
||||
|
||||
<title>Developer Documentation — MLX 0.0.6 documentation</title>
|
||||
<title>Developer Documentation — MLX 0.0.7 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -133,8 +133,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.6 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.6 documentation - Home"/>`);</script>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.7 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.7 documentation - Home"/>`);</script>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -277,12 +277,14 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.quantize.html">mlx.core.quantize</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.quantized_matmul.html">mlx.core.quantized_matmul</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reciprocal.html">mlx.core.reciprocal</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.repeat.html">mlx.core.repeat</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reshape.html">mlx.core.reshape</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.round.html">mlx.core.round</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.rsqrt.html">mlx.core.rsqrt</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save.html">mlx.core.save</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez.html">mlx.core.savez</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez_compressed.html">mlx.core.savez_compressed</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save_safetensors.html">mlx.core.save_safetensors</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sigmoid.html">mlx.core.sigmoid</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
|
||||
@@ -302,6 +304,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.take_along_axis.html">mlx.core.take_along_axis</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tan.html">mlx.core.tan</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tanh.html">mlx.core.tanh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tensordot.html">mlx.core.tensordot</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.transpose.html">mlx.core.transpose</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
|
||||
@@ -350,11 +353,35 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfftn.html">mlx.core.fft.irfftn</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-7"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/linalg.html">Linear Algebra</a><input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-7"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.norm.html">mlx.core.linalg.norm</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-8"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.Module.html">mlx.nn.Module</a></li>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/layers.html">Layers</a><input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-8"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/module.html">Module</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.training.html">mlx.nn.Module.training</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply.html">mlx.nn.Module.apply</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply_to_modules.html">mlx.nn.Module.apply_to_modules</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.children.html">mlx.nn.Module.children</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.eval.html">mlx.nn.Module.eval</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.filter_and_map.html">mlx.nn.Module.filter_and_map</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.freeze.html">mlx.nn.Module.freeze</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.leaf_modules.html">mlx.nn.Module.leaf_modules</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.load_weights.html">mlx.nn.Module.load_weights</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.modules.html">mlx.nn.Module.modules</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.named_modules.html">mlx.nn.Module.named_modules</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.parameters.html">mlx.nn.Module.parameters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.save_weights.html">mlx.nn.Module.save_weights</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.train.html">mlx.nn.Module.train</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.trainable_parameters.html">mlx.nn.Module.trainable_parameters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.unfreeze.html">mlx.nn.Module.unfreeze</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.update.html">mlx.nn.Module.update</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.update_modules.html">mlx.nn.Module.update_modules</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/layers.html">Layers</a><input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-10"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
|
||||
@@ -362,19 +389,27 @@
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.BatchNorm.html">mlx.nn.BatchNorm</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.InstanceNorm.html">mlx.nn.InstanceNorm</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout.html">mlx.nn.Dropout</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout2d.html">mlx.nn.Dropout2d</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout3d.html">mlx.nn.Dropout3d</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Transformer.html">mlx.nn.Transformer</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ALiBi.html">mlx.nn.ALiBi</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html">mlx.nn.SinusoidalPositionalEncoding</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/functions.html">Functions</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/functions.html">Functions</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
|
||||
@@ -386,7 +421,7 @@
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/losses.html">Loss Functions</a><input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-10"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/losses.html">Loss Functions</a><input class="toctree-checkbox" id="toctree-checkbox-12" name="toctree-checkbox-12" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-12"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html">mlx.nn.losses.binary_cross_entropy</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html">mlx.nn.losses.cross_entropy</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
|
||||
@@ -395,11 +430,14 @@
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html">mlx.nn.losses.smooth_l1_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html">mlx.nn.losses.triplet_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html">mlx.nn.losses.hinge_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.html">mlx.nn.losses.huber_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/optimizers.html">Optimizers</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/optimizers.html">Optimizers</a><input class="toctree-checkbox" id="toctree-checkbox-13" name="toctree-checkbox-13" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-13"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.OptimizerState.html">mlx.optimizers.OptimizerState</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.html">mlx.optimizers.Optimizer</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.SGD.html">mlx.optimizers.SGD</a></li>
|
||||
@@ -412,7 +450,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Lion.html">mlx.optimizers.Lion</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-12" name="toctree-checkbox-12" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-12"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-14" name="toctree-checkbox-14" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-14"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_flatten.html">mlx.utils.tree_flatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_unflatten.html">mlx.utils.tree_unflatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map.html">mlx.utils.tree_map</a></li>
|
||||
@@ -638,7 +676,7 @@ and GPU implementations.</p>
|
||||
<section id="introducing-the-example">
|
||||
<h2>Introducing the Example<a class="headerlink" href="#introducing-the-example" title="Permalink to this heading">#</a></h2>
|
||||
<p>Let’s say that you would like an operation that takes in two arrays,
|
||||
<code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, scales them both by some coefficents <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>
|
||||
<code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, scales them both by some coefficients <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>
|
||||
respectively, and then adds them together to get the result
|
||||
<code class="docutils literal notranslate"><span class="pre">z</span> <span class="pre">=</span> <span class="pre">alpha</span> <span class="pre">*</span> <span class="pre">x</span> <span class="pre">+</span> <span class="pre">beta</span> <span class="pre">*</span> <span class="pre">y</span></code>. Well, you can very easily do that by just
|
||||
writing out a function as follows:</p>
|
||||
@@ -682,7 +720,7 @@ operations in the Python API (<a class="reference internal" href="../python/ops.
|
||||
and two scalars, <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>. This is how we would define it in the
|
||||
C++ API:</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/**</span>
|
||||
<span class="cm">* Scale and sum two vectors elementwise</span>
|
||||
<span class="cm">* Scale and sum two vectors element-wise</span>
|
||||
<span class="cm">* z = alpha * x + beta * y</span>
|
||||
<span class="cm">*</span>
|
||||
<span class="cm">* Follow numpy style broadcasting between x and y</span>
|
||||
@@ -833,7 +871,7 @@ data type, shape, the <code class="xref py py-class docutils literal notranslate
|
||||
</div>
|
||||
<p>This operation now handles the following:</p>
|
||||
<ol class="arabic simple">
|
||||
<li><p>Upcast inputs and resolve the the output data type.</p></li>
|
||||
<li><p>Upcast inputs and resolve the output data type.</p></li>
|
||||
<li><p>Broadcast the inputs and resolve the output shape.</p></li>
|
||||
<li><p>Construct the primitive <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> using the given stream, <code class="docutils literal notranslate"><span class="pre">alpha</span></code>, and <code class="docutils literal notranslate"><span class="pre">beta</span></code>.</p></li>
|
||||
<li><p>Construct the output <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> using the primitive and the inputs.</p></li>
|
||||
@@ -883,14 +921,14 @@ pointwise. This is captured in the templated function <code class="xref py py-me
|
||||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha_</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta_</span><span class="p">);</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Do the elementwise operation for each output</span>
|
||||
<span class="w"> </span><span class="c1">// Do the element-wise operation for each output</span>
|
||||
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="kt">size_t</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o"><</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span><span class="w"> </span><span class="n">out_idx</span><span class="o">++</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="c1">// Map linear indices to offsets in x and y</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">());</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// We allocate the output to be contiguous and regularly strided</span>
|
||||
<span class="w"> </span><span class="c1">// (defaults to row major) and hence it doesn't need additonal mapping</span>
|
||||
<span class="w"> </span><span class="c1">// (defaults to row major) and hence it doesn't need additional mapping</span>
|
||||
<span class="w"> </span><span class="n">out_ptr</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="p">}</span>
|
||||
<span class="p">}</span>
|
||||
@@ -902,7 +940,7 @@ for all incoming floating point arrays. Accordingly, we add dispatches for
|
||||
if we encounter an unexpected type.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Fall back implementation for evaluation on CPU */</span>
|
||||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="c1">// Check the inputs (registered in the op while contructing the out array)</span>
|
||||
<span class="w"> </span><span class="c1">// Check the inputs (registered in the op while constructing the out array)</span>
|
||||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||||
@@ -1071,7 +1109,7 @@ each data type.</p>
|
||||
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">float32</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">float16</span><span class="p">,</span><span class="w"> </span><span class="n">half</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">bflot16</span><span class="p">,</span><span class="w"> </span><span class="n">bfloat16_t</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">bfloat16</span><span class="p">,</span><span class="w"> </span><span class="n">bfloat16_t</span><span class="p">);</span>
|
||||
<span class="n">instantiate_axpby</span><span class="p">(</span><span class="n">complex64</span><span class="p">,</span><span class="w"> </span><span class="n">complex64_t</span><span class="p">);</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
@@ -1120,7 +1158,7 @@ below.</p>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">setComputePipelineState</span><span class="p">(</span><span class="n">kernel</span><span class="p">);</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Kernel parameters are registered with buffer indices corresponding to</span>
|
||||
<span class="w"> </span><span class="c1">// those in the kernel decelaration at axpby.metal</span>
|
||||
<span class="w"> </span><span class="c1">// those in the kernel declaration at axpby.metal</span>
|
||||
<span class="w"> </span><span class="kt">int</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">ndim</span><span class="p">();</span>
|
||||
<span class="w"> </span><span class="kt">size_t</span><span class="w"> </span><span class="n">nelem</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span>
|
||||
|
||||
@@ -1151,7 +1189,7 @@ below.</p>
|
||||
<span class="w"> </span><span class="c1">// Fix the 3D size of the launch grid (in terms of threads)</span>
|
||||
<span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="w"> </span><span class="n">grid_dims</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="p">(</span><span class="n">nelem</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// Launch the grid with the given number of threads divded among</span>
|
||||
<span class="w"> </span><span class="c1">// Launch the grid with the given number of threads divided among</span>
|
||||
<span class="w"> </span><span class="c1">// the given threadgroups</span>
|
||||
<span class="w"> </span><span class="n">compute_encoder</span><span class="o">-></span><span class="n">dispatchThreads</span><span class="p">(</span><span class="n">grid_dims</span><span class="p">,</span><span class="w"> </span><span class="n">group_dims</span><span class="p">);</span>
|
||||
<span class="p">}</span>
|
||||
@@ -1164,7 +1202,7 @@ to give us the active metal compute command encoder instead of building a
|
||||
new one and calling <code class="xref py py-meth docutils literal notranslate"><span class="pre">compute_encoder->end_encoding()</span></code> at the end.
|
||||
MLX keeps adding kernels (compute pipelines) to the active command encoder
|
||||
until some specified limit is hit or the compute encoder needs to be flushed
|
||||
for synchronization. MLX also handles enqueuing and commiting the associated
|
||||
for synchronization. MLX also handles enqueuing and committing the associated
|
||||
command buffers as needed. We suggest taking a deeper dive into
|
||||
<code class="xref py py-class docutils literal notranslate"><span class="pre">metal::Device</span></code> if you would like to study this routine further.</p>
|
||||
</section>
|
||||
@@ -1180,8 +1218,8 @@ us the following <code class="xref py py-meth docutils literal notranslate"><spa
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
<span class="w"> </span><span class="c1">// Forward mode diff that pushes along the tangents</span>
|
||||
<span class="w"> </span><span class="c1">// The jvp transform on the the primitive can built with ops</span>
|
||||
<span class="w"> </span><span class="c1">// that are scheduled on the same stream as the primtive</span>
|
||||
<span class="w"> </span><span class="c1">// The jvp transform on the primitive can built with ops</span>
|
||||
<span class="w"> </span><span class="c1">// that are scheduled on the same stream as the primitive</span>
|
||||
|
||||
<span class="w"> </span><span class="c1">// If argnums = {0}, we only push along x in which case the</span>
|
||||
<span class="w"> </span><span class="c1">// jvp is just the tangent scaled by alpha</span>
|
||||
@@ -1218,7 +1256,7 @@ us the following <code class="xref py py-meth docutils literal notranslate"><spa
|
||||
</div>
|
||||
<p>Finally, you need not have a transformation fully defined to start using your
|
||||
own <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.</p>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Vectorize primitve along given axis */</span>
|
||||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Vectorize primitive along given axis */</span>
|
||||
<span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">array</span><span class="p">,</span><span class="w"> </span><span class="kt">int</span><span class="o">></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vmap</span><span class="p">(</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||||
@@ -1245,7 +1283,7 @@ own <code class="xref py py-class docutils literal notranslate"><span class="pre
|
||||
</div>
|
||||
<ul class="simple">
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/axpby/</span></code> defines the C++ extension library</p></li>
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> sets out the strucutre for the
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> sets out the structure for the
|
||||
associated python package</p></li>
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/bindings.cpp</span></code> provides python bindings for our operation</p></li>
|
||||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/CMakeLists.txt</span></code> holds CMake rules to build the library and
|
||||
@@ -1272,7 +1310,7 @@ are already provided, adding our <code class="xref py py-meth docutils literal n
|
||||
<span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">kw_only</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="s">"stream"</span><span class="n">_a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">py</span><span class="o">::</span><span class="n">none</span><span class="p">(),</span>
|
||||
<span class="w"> </span><span class="sa">R</span><span class="s">"</span><span class="dl">pbdoc(</span>
|
||||
<span class="s"> Scale and sum two vectors elementwise</span>
|
||||
<span class="s"> Scale and sum two vectors element-wise</span>
|
||||
<span class="s"> ``z = alpha * x + beta * y``</span>
|
||||
|
||||
<span class="s"> Follows numpy style broadcasting between ``x`` and ``y``</span>
|
||||
@@ -1405,7 +1443,7 @@ bindings and copied together if the package is installed</p></li>
|
||||
<div class="line">…</div>
|
||||
</div>
|
||||
<p>When you try to install using the command <code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">-m</span> <span class="pre">pip</span> <span class="pre">install</span> <span class="pre">.</span></code>
|
||||
(in <code class="docutils literal notranslate"><span class="pre">extensions/</span></code>), the package will be installed with the same strucutre as
|
||||
(in <code class="docutils literal notranslate"><span class="pre">extensions/</span></code>), the package will be installed with the same structure as
|
||||
<code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> and the C++ and metal library will be
|
||||
copied along with the python binding since they are specified as <code class="docutils literal notranslate"><span class="pre">package_data</span></code>.</p>
|
||||
</section>
|
||||
@@ -1482,7 +1520,7 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
|
||||
</div>
|
||||
<p>We see some modest improvements right away!</p>
|
||||
<p>This operation is now good to be used to build other operations,
|
||||
in <a class="reference internal" href="../python/_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and also as a part of graph
|
||||
in <a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and also as a part of graph
|
||||
transformations such as <code class="xref py py-meth docutils literal notranslate"><span class="pre">grad()</span></code> and <code class="xref py py-meth docutils literal notranslate"><span class="pre">simplify()</span></code>!</p>
|
||||
</section>
|
||||
</section>
|
||||
|
Reference in New Issue
Block a user