Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-18 10:26:56 +08:00.

Commit 0250e203f6 ("docs"), parent f75712551d, committed by CircleCI Docs.

Changed file: docs/build/html/python/nn.html (vendored), 326 lines changed.
@@ -226,6 +226,7 @@
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.argsort.html">mlx.core.argsort</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.array_equal.html">mlx.core.array_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.broadcast_to.html">mlx.core.broadcast_to</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.ceil.html">mlx.core.ceil</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.concatenate.html">mlx.core.concatenate</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.convolve.html">mlx.core.convolve</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.conv1d.html">mlx.core.conv1d</a></li>
@@ -239,6 +240,8 @@
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.exp.html">mlx.core.exp</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.expand_dims.html">mlx.core.expand_dims</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.eye.html">mlx.core.eye</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.floor.html">mlx.core.floor</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.flatten.html">mlx.core.flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.full.html">mlx.core.full</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.greater.html">mlx.core.greater</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.greater_equal.html">mlx.core.greater_equal</a></li>
@@ -259,6 +262,7 @@
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.mean.html">mlx.core.mean</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.min.html">mlx.core.min</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.minimum.html">mlx.core.minimum</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.moveaxis.html">mlx.core.moveaxis</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.multiply.html">mlx.core.multiply</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.negative.html">mlx.core.negative</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.ones.html">mlx.core.ones</a></li>
@@ -282,14 +286,19 @@
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.sqrt.html">mlx.core.sqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.square.html">mlx.core.square</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.squeeze.html">mlx.core.squeeze</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.stack.html">mlx.core.stack</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.stop_gradient.html">mlx.core.stop_gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.subtract.html">mlx.core.subtract</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.sum.html">mlx.core.sum</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.swapaxes.html">mlx.core.swapaxes</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.take.html">mlx.core.take</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.take_along_axis.html">mlx.core.take_along_axis</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.tan.html">mlx.core.tan</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.tanh.html">mlx.core.tanh</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.transpose.html">mlx.core.transpose</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.where.html">mlx.core.where</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.zeros.html">mlx.core.zeros</a></li>
@@ -316,6 +325,7 @@
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.jvp.html">mlx.core.jvp</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.vjp.html">mlx.core.vjp</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.vmap.html">mlx.core.vmap</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.simplify.html">mlx.core.simplify</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="fft.html">FFT</a><input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-6"><i class="fa-solid fa-chevron-down"></i></label><ul>
@@ -335,48 +345,63 @@
</li>
<li class="toctree-l1 current active has-children"><a class="current reference internal" href="#">Neural Networks</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-7"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.prelu.html">mlx.nn.prelu</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.silu.html">mlx.nn.silu</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.step.html">mlx.nn.step</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.selu.html">mlx.nn.selu</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.cross_entropy.html">mlx.nn.losses.cross_entropy</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html">mlx.nn.losses.binary_cross_entropy</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.l1_loss.html">mlx.nn.losses.l1_loss</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.mse_loss.html">mlx.nn.losses.mse_loss</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.Module.html">mlx.nn.Module</a></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="nn/layers.html">Layers</a><input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-8"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="optimizers.html">Optimizers</a><input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-8"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2 has-children"><a class="reference internal" href="nn/functions.html">Functions</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.prelu.html">mlx.nn.prelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.silu.html">mlx.nn.silu</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.step.html">mlx.nn.step</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.selu.html">mlx.nn.selu</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
</ul>
</li>
<li class="toctree-l2 has-children"><a class="reference internal" href="nn/losses.html">Loss Functions</a><input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-10"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html">mlx.nn.losses.cross_entropy</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html">mlx.nn.losses.binary_cross_entropy</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.l1_loss.html">mlx.nn.losses.l1_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.mse_loss.html">mlx.nn.losses.mse_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="optimizers.html">Optimizers</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.OptimizerState.html">mlx.optimizers.OptimizerState</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.Optimizer.html">mlx.optimizers.Optimizer</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.SGD.html">mlx.optimizers.SGD</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.RMSprop.html">mlx.optimizers.RMSprop</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.Adagrad.html">mlx.optimizers.Adagrad</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.AdaDelta.html">mlx.optimizers.AdaDelta</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.Adam.html">mlx.optimizers.Adam</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.AdamW.html">mlx.optimizers.AdamW</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.Adamax.html">mlx.optimizers.Adamax</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l1 has-children"><a class="reference internal" href="tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-12" name="toctree-checkbox-12" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-12"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.utils.tree_flatten.html">mlx.utils.tree_flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.utils.tree_unflatten.html">mlx.utils.tree_unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.utils.tree_map.html">mlx.utils.tree_map</a></li>
@@ -556,14 +581,13 @@ document.write(`
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#quick-start-with-neural-networks">Quick Start with Neural Networks</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#the-module-class">The Module Class</a><ul class="visible nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#parameters">Parameters</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#updating-the-parameters">Updating the parameters</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#updating-the-parameters">Updating the Parameters</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#inspecting-modules">Inspecting Modules</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#value-and-grad">Value and grad</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#neural-network-layers">Neural Network Layers</a><ul class="visible nav section-nav flex-column">
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#value-and-grad">Value and Grad</a><ul class="visible nav section-nav flex-column">
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#loss-functions">Loss Functions</a></li>
</ul>
</nav>
</div>
@@ -635,34 +659,61 @@ for finetuning and more.</p>
</section>
<section id="the-module-class">
<span id="module-class"></span><h2>The Module Class<a class="headerlink" href="#the-module-class" title="Permalink to this heading">#</a></h2>
<p>The workhorse of any neural network library is the <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class. In
MLX the <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class is a container of <a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> or
<a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances. Its main function is to provide a way to
<p>The workhorse of any neural network library is the <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class. In
MLX the <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class is a container of <a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> or
<a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances. Its main function is to provide a way to
recursively <strong>access</strong> and <strong>update</strong> its parameters and those of its
submodules.</p>
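For context, the MLP that is printed further down this page is built from exactly this kind of nesting. A minimal sketch, assuming only the standard import of mlx.nn; the class and attribute names (MLP, layers) follow the printed output below, while the dimension arguments are hypothetical:

    import mlx.nn as nn

    class MLP(nn.Module):
        def __init__(self, in_dims: int, hidden_dims: int, out_dims: int):
            super().__init__()
            # Submodules kept in a plain Python list are still discovered recursively.
            self.layers = [
                nn.Linear(in_dims, hidden_dims),
                nn.Linear(hidden_dims, hidden_dims),
                nn.Linear(hidden_dims, out_dims),
            ]

        def __call__(self, x):
            for i, layer in enumerate(self.layers):
                x = layer(x)
                if i < len(self.layers) - 1:
                    x = nn.relu(x)
            return x

    mlp = MLP(2, 128, 10)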
<section id="parameters">
<h3>Parameters<a class="headerlink" href="#parameters" title="Permalink to this heading">#</a></h3>
<p>A parameter of a module is any public member of type <a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> (its
name should not start with <code class="docutils literal notranslate"><span class="pre">_</span></code>). It can be arbitrarily nested in other
<a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances or lists and dictionaries.</p>
<p><a class="reference internal" href="nn/module.html#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.parameters()</span></code></a> can be used to extract a nested dictionary with all
<a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances or lists and dictionaries.</p>
<p><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.parameters()</span></code> can be used to extract a nested dictionary with all
the parameters of a module and its submodules.</p>
<p>A <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> can also keep track of “frozen” parameters.
<a class="reference internal" href="nn/module.html#mlx.nn.Module.trainable_parameters" title="mlx.nn.Module.trainable_parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.trainable_parameters()</span></code></a> returns only the subset of
<a class="reference internal" href="nn/module.html#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.parameters()</span></code></a> that is not frozen. When using
<a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a> the gradients returned will be with respect to these
trainable parameters.</p>
<p>A <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> can also keep track of “frozen” parameters. See the
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.freeze()</span></code> method for more details. <a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a>
the gradients returned will be with respect to these trainable parameters.</p>
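A short sketch of how these calls relate, reusing the mlp instance sketched above; freezing the first layer is only an illustration, not something this page prescribes:

    all_params = mlp.parameters()            # nested dict of every mlx.core.array
    mlp.layers[0].freeze()                   # freeze the first Linear layer
    trainable = mlp.trainable_parameters()   # same structure, minus the frozen arrays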
</section>
<section id="updating-the-parameters">
<h3>Updating the parameters<a class="headerlink" href="#updating-the-parameters" title="Permalink to this heading">#</a></h3>
<h3>Updating the Parameters<a class="headerlink" href="#updating-the-parameters" title="Permalink to this heading">#</a></h3>
<p>MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module’s parameters. This action is
performed by <a class="reference internal" href="nn/module.html#mlx.nn.Module.update" title="mlx.nn.Module.update"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code></a>.</p>
performed by <code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code>.</p>
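A rough sketch of such a bulk update, using the tree_map utility shown later on this page together with the mlp instance from above; the 0.9 scaling is only an illustration:

    from mlx.utils import tree_map

    # Build a new nested parameter dict with the same structure as mlp.parameters() ...
    new_params = tree_map(lambda p: 0.9 * p, mlp.parameters())

    # ... and load the whole dict back into the model in one call.
    mlp.update(new_params)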
</section>
<section id="inspecting-modules">
<h3>Inspecting Modules<a class="headerlink" href="#inspecting-modules" title="Permalink to this heading">#</a></h3>
<p>The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the <code class="docutils literal notranslate"><span class="pre">MLP</span></code> with:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">mlp</span><span class="p">)</span>
</pre></div>
</div>
<p>This will display:</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span>MLP<span class="o">(</span>
<span class="w"> </span><span class="o">(</span>layers.0<span class="o">)</span>:<span class="w"> </span>Linear<span class="o">(</span><span class="nv">input_dims</span><span class="o">=</span><span class="m">2</span>,<span class="w"> </span><span class="nv">output_dims</span><span class="o">=</span><span class="m">128</span>,<span class="w"> </span><span class="nv">bias</span><span class="o">=</span>True<span class="o">)</span>
<span class="w"> </span><span class="o">(</span>layers.1<span class="o">)</span>:<span class="w"> </span>Linear<span class="o">(</span><span class="nv">input_dims</span><span class="o">=</span><span class="m">128</span>,<span class="w"> </span><span class="nv">output_dims</span><span class="o">=</span><span class="m">128</span>,<span class="w"> </span><span class="nv">bias</span><span class="o">=</span>True<span class="o">)</span>
<span class="w"> </span><span class="o">(</span>layers.2<span class="o">)</span>:<span class="w"> </span>Linear<span class="o">(</span><span class="nv">input_dims</span><span class="o">=</span><span class="m">128</span>,<span class="w"> </span><span class="nv">output_dims</span><span class="o">=</span><span class="m">10</span>,<span class="w"> </span><span class="nv">bias</span><span class="o">=</span>True<span class="o">)</span>
<span class="o">)</span>
</pre></div>
</div>
<p>To get more detailed information on the arrays in a <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> you can use
<a class="reference internal" href="_autosummary/mlx.utils.tree_map.html#mlx.utils.tree_map" title="mlx.utils.tree_map"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.utils.tree_map()</span></code></a> on the parameters. For example, to see the shapes of
all the parameters in a <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> do:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">mlx.utils</span> <span class="kn">import</span> <span class="n">tree_map</span>
<span class="n">shapes</span> <span class="o">=</span> <span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">mlp</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
</pre></div>
</div>
<p>As another example, you can count the number of parameters in a <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a>
with:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">mlx.utils</span> <span class="kn">import</span> <span class="n">tree_flatten</span>
<span class="n">num_params</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">size</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">tree_flatten</span><span class="p">(</span><span class="n">mlp</span><span class="o">.</span><span class="n">parameters</span><span class="p">()))</span>
</pre></div>
</div>
</section>
</section>
<section id="value-and-grad">
<h2>Value and grad<a class="headerlink" href="#value-and-grad" title="Permalink to this heading">#</a></h2>
<p>Using a <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> does not preclude using MLX’s high order function
<h2>Value and Grad<a class="headerlink" href="#value-and-grad" title="Permalink to this heading">#</a></h2>
<p>Using a <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> does not preclude using MLX’s high order function
transformations (<a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a>, <a class="reference internal" href="_autosummary/mlx.core.grad.html#mlx.core.grad" title="mlx.core.grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.grad()</span></code></a>, etc.). However,
these function transformations assume pure functions, namely the parameters
should be passed as an argument to the function being transformed.</p>
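A minimal sketch of the resulting pattern, reusing the mlp instance sketched above; loss_fn, X and y are hypothetical placeholders rather than anything defined on this page:

    import mlx.core as mx
    import mlx.nn as nn

    def loss_fn(model, X, y):
        return mx.mean(nn.losses.cross_entropy(model(X), y))

    X = mx.random.normal((4, 2))   # hypothetical batch of inputs
    y = mx.array([3, 1, 0, 2])     # hypothetical integer class targets

    # The wrapper takes care of feeding the model's trainable parameters to the
    # transformation; grads has the same structure as mlp.trainable_parameters().
    loss_and_grad_fn = nn.value_and_grad(mlp, loss_fn)
    loss, grads = loss_and_grad_fn(mlp, X, y)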
@@ -680,7 +731,7 @@ should be passed as an argument to the function being transformed.</p>
computes the gradients with respect to the trainable parameters of the model.</p>
<p>In detail:</p>
<ul class="simple">
<li><p>it wraps the passed function with a function that calls <a class="reference internal" href="nn/module.html#mlx.nn.Module.update" title="mlx.nn.Module.update"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code></a>
<li><p>it wraps the passed function with a function that calls <code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code>
to make sure the model is using the provided parameters.</p></li>
<li><p>it calls <a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a> to transform the function into a function
that also computes the gradients with respect to the passed parameters.</p></li>
@@ -693,124 +744,56 @@ parameters as the first argument to the function returned by
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-obj docutils literal notranslate"><span class="pre">value_and_grad</span></code></a>(model, fn)</p></td>
<td><p>Transform the passed function <code class="docutils literal notranslate"><span class="pre">fn</span></code> to a function that computes the gradients of <code class="docutils literal notranslate"><span class="pre">fn</span></code> wrt the model's trainable parameters and also its value.</p></td>
</tr>
</tbody>
</table>
</section>
<section id="neural-network-layers">
<h2>Neural Network Layers<a class="headerlink" href="#neural-network-layers" title="Permalink to this heading">#</a></h2>
<table class="autosummary longtable table autosummary">
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Embedding.html#mlx.nn.Embedding" title="mlx.nn.Embedding"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Embedding</span></code></a>(num_embeddings, dims)</p></td>
<td><p>Implements a simple lookup table that maps each input integer to a high-dimensional vector.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.ReLU.html#mlx.nn.ReLU" title="mlx.nn.ReLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">ReLU</span></code></a>()</p></td>
<td><p>Applies the Rectified Linear Unit.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.PReLU.html#mlx.nn.PReLU" title="mlx.nn.PReLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">PReLU</span></code></a>([num_parameters, init])</p></td>
<td><p></p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.GELU.html#mlx.nn.GELU" title="mlx.nn.GELU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">GELU</span></code></a>([approx])</p></td>
<td><p>Applies the Gaussian Error Linear Units.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.SiLU.html#mlx.nn.SiLU" title="mlx.nn.SiLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">SiLU</span></code></a>()</p></td>
<td><p>Applies the Sigmoid Linear Unit.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Step.html#mlx.nn.Step" title="mlx.nn.Step"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Step</span></code></a>([threshold])</p></td>
<td><p>Applies the Step Activation Function.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.SELU.html#mlx.nn.SELU" title="mlx.nn.SELU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">SELU</span></code></a>()</p></td>
<td><p>Applies the Scaled Exponential Linear Unit.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Mish.html#mlx.nn.Mish" title="mlx.nn.Mish"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Mish</span></code></a>()</p></td>
<td><p>Applies the Mish function, element-wise.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Linear.html#mlx.nn.Linear" title="mlx.nn.Linear"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Linear</span></code></a>(input_dims, output_dims[, bias])</p></td>
<td><p>Applies an affine transformation to the input.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Conv1d.html#mlx.nn.Conv1d" title="mlx.nn.Conv1d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Conv1d</span></code></a>(in_channels, out_channels, kernel_size)</p></td>
<td><p>Applies a 1-dimensional convolution over the multi-channel input sequence.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Conv2d.html#mlx.nn.Conv2d" title="mlx.nn.Conv2d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Conv2d</span></code></a>(in_channels, out_channels, kernel_size)</p></td>
<td><p>Applies a 2-dimensional convolution over the multi-channel input image.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.LayerNorm.html#mlx.nn.LayerNorm" title="mlx.nn.LayerNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">LayerNorm</span></code></a>(dims[, eps, affine])</p></td>
<td><p>Applies layer normalization [1] on the inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.RMSNorm.html#mlx.nn.RMSNorm" title="mlx.nn.RMSNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">RMSNorm</span></code></a>(dims[, eps])</p></td>
<td><p>Applies Root Mean Square normalization [1] to the inputs.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.GroupNorm.html#mlx.nn.GroupNorm" title="mlx.nn.GroupNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">GroupNorm</span></code></a>(num_groups, dims[, eps, affine, ...])</p></td>
<td><p>Applies Group Normalization [1] to the inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.RoPE.html#mlx.nn.RoPE" title="mlx.nn.RoPE"><code class="xref py py-obj docutils literal notranslate"><span class="pre">RoPE</span></code></a>(dims[, traditional])</p></td>
<td><p>Implements the rotary positional encoding [1].</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.MultiHeadAttention.html#mlx.nn.MultiHeadAttention" title="mlx.nn.MultiHeadAttention"><code class="xref py py-obj docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code></a>(dims, num_heads[, ...])</p></td>
<td><p>Implements the scaled dot product attention with multiple heads.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Sequential.html#mlx.nn.Sequential" title="mlx.nn.Sequential"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Sequential</span></code></a>(*modules)</p></td>
<td><p>A layer that calls the passed callables in order.</p></td>
</tr>
</tbody>
</table>
<p>Layers without parameters (e.g. activation functions) are also provided as
simple functions.</p>
<table class="autosummary longtable table autosummary">
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu.html#mlx.nn.gelu" title="mlx.nn.gelu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">gelu</span></code></a>(x)</p></td>
<td><p>Applies the Gaussian Error Linear Units function.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_approx.html#mlx.nn.gelu_approx" title="mlx.nn.gelu_approx"><code class="xref py py-obj docutils literal notranslate"><span class="pre">gelu_approx</span></code></a>(x)</p></td>
<td><p>An approximation to Gaussian Error Linear Unit.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_fast_approx.html#mlx.nn.gelu_fast_approx" title="mlx.nn.gelu_fast_approx"><code class="xref py py-obj docutils literal notranslate"><span class="pre">gelu_fast_approx</span></code></a>(x)</p></td>
<td><p>A fast approximation to Gaussian Error Linear Unit.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.relu.html#mlx.nn.relu" title="mlx.nn.relu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">relu</span></code></a>(x)</p></td>
<td><p>Applies the Rectified Linear Unit.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.prelu.html#mlx.nn.prelu" title="mlx.nn.prelu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">prelu</span></code></a>(x, alpha)</p></td>
<td><p>Applies the element-wise function:</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.silu.html#mlx.nn.silu" title="mlx.nn.silu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">silu</span></code></a>(x)</p></td>
<td><p>Applies the Sigmoid Linear Unit.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.step.html#mlx.nn.step" title="mlx.nn.step"><code class="xref py py-obj docutils literal notranslate"><span class="pre">step</span></code></a>(x[, threshold])</p></td>
<td><p>Applies the Step Activation Function.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.selu.html#mlx.nn.selu" title="mlx.nn.selu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">selu</span></code></a>(x)</p></td>
<td><p>Applies the Scaled Exponential Linear Unit.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.mish.html#mlx.nn.mish" title="mlx.nn.mish"><code class="xref py py-obj docutils literal notranslate"><span class="pre">mish</span></code></a>(x)</p></td>
<td><p>Applies the Mish function, element-wise.</p></td>
</tr>
</tbody>
</table>
</section>
<section id="loss-functions">
<h2>Loss Functions<a class="headerlink" href="#loss-functions" title="Permalink to this heading">#</a></h2>
<table class="autosummary longtable table autosummary">
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.cross_entropy.html#mlx.nn.losses.cross_entropy" title="mlx.nn.losses.cross_entropy"><code class="xref py py-obj docutils literal notranslate"><span class="pre">losses.cross_entropy</span></code></a>(logits, targets[, ...])</p></td>
<td><p>Computes the cross entropy loss between logits and targets.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html#mlx.nn.losses.binary_cross_entropy" title="mlx.nn.losses.binary_cross_entropy"><code class="xref py py-obj docutils literal notranslate"><span class="pre">losses.binary_cross_entropy</span></code></a>(inputs, targets)</p></td>
<td><p>Computes the binary cross entropy loss between inputs and targets.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.l1_loss.html#mlx.nn.losses.l1_loss" title="mlx.nn.losses.l1_loss"><code class="xref py py-obj docutils literal notranslate"><span class="pre">losses.l1_loss</span></code></a>(predictions, targets[, reduction])</p></td>
<td><p>Computes the L1 loss between predictions and targets.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.mse_loss.html#mlx.nn.losses.mse_loss" title="mlx.nn.losses.mse_loss"><code class="xref py py-obj docutils literal notranslate"><span class="pre">losses.mse_loss</span></code></a>(predictions, targets[, ...])</p></td>
<td><p>Computes the mean squared error loss between predictions and targets.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.nll_loss.html#mlx.nn.losses.nll_loss" title="mlx.nn.losses.nll_loss"><code class="xref py py-obj docutils literal notranslate"><span class="pre">losses.nll_loss</span></code></a>(inputs, targets[, axis, ...])</p></td>
<td><p>Computes the negative log likelihood loss between inputs and targets.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.kl_div_loss.html#mlx.nn.losses.kl_div_loss" title="mlx.nn.losses.kl_div_loss"><code class="xref py py-obj docutils literal notranslate"><span class="pre">losses.kl_div_loss</span></code></a>(inputs, targets[, axis, ...])</p></td>
<td><p>Computes the Kullback-Leibler divergence loss between targets and the inputs.</p></td>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Module</span></code></a>()</p></td>
<td><p>Base class for building neural networks with MLX.</p></td>
</tr>
</tbody>
</table>
<div class="toctree-wrapper compound">
<ul>
<li class="toctree-l1"><a class="reference internal" href="nn/layers.html">Layers</a><ul>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="nn/functions.html">Functions</a><ul>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.prelu.html">mlx.nn.prelu</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.silu.html">mlx.nn.silu</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.step.html">mlx.nn.step</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.selu.html">mlx.nn.selu</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="nn/losses.html">Loss Functions</a><ul>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html">mlx.nn.losses.cross_entropy</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html">mlx.nn.losses.binary_cross_entropy</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.l1_loss.html">mlx.nn.losses.l1_loss</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.mse_loss.html">mlx.nn.losses.mse_loss</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
</ul>
</li>
</ul>
</div>
</section>
</section>
@@ -861,14 +844,13 @@ simple functions.</p>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#quick-start-with-neural-networks">Quick Start with Neural Networks</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#the-module-class">The Module Class</a><ul class="visible nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#parameters">Parameters</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#updating-the-parameters">Updating the parameters</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#updating-the-parameters">Updating the Parameters</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#inspecting-modules">Inspecting Modules</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#value-and-grad">Value and grad</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#neural-network-layers">Neural Network Layers</a><ul class="visible nav section-nav flex-column">
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#value-and-grad">Value and Grad</a><ul class="visible nav section-nav flex-column">
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#loss-functions">Loss Functions</a></li>
</ul>
</nav></div>