docs update

This commit is contained in:
Awni Hannun
2024-01-17 17:15:29 -08:00
committed by CircleCI Docs
parent d9d0777c2e
commit 30ea2df988
611 changed files with 15484 additions and 9815 deletions

View File

@@ -9,7 +9,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<title>Layers &#8212; MLX 0.0.7 documentation</title>
<title>Layers &#8212; MLX 0.0.9 documentation</title>
@@ -46,7 +46,7 @@
<script>DOCUMENTATION_OPTIONS.pagename = 'python/nn/layers';</script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="mlx.nn.Sequential" href="_autosummary/mlx.nn.Sequential.html" />
<link rel="next" title="mlx.nn.ALiBi" href="_autosummary/mlx.nn.ALiBi.html" />
<link rel="prev" title="mlx.nn.Module.update_modules" href="_autosummary/mlx.nn.Module.update_modules.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
@@ -134,8 +134,8 @@
<img src="../../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.7 documentation - Home"/>
<script>document.write(`<img src="../../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.7 documentation - Home"/>`);</script>
<img src="../../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.9 documentation - Home"/>
<script>document.write(`<img src="../../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.9 documentation - Home"/>`);</script>
</a></div>
@@ -152,6 +152,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../usage/unified_memory.html">Unified Memory</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../usage/indexing.html">Indexing Arrays</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../usage/saving_and_loading.html">Saving and Loading Arrays</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../usage/function_transforms.html">Function Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../usage/using_streams.html">Using Streams</a></li>
</ul>
@@ -256,6 +257,10 @@
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.greater_equal.html">mlx.core.greater_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.identity.html">mlx.core.identity</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.inner.html">mlx.core.inner</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.isnan.html">mlx.core.isnan</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.isinf.html">mlx.core.isinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.linspace.html">mlx.core.linspace</a></li>
@@ -327,16 +332,16 @@
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../random.html">Random</a><input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-4"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.seed.html">mlx.core.random.seed</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.key.html">mlx.core.random.key</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.split.html">mlx.core.random.split</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.bernoulli.html">mlx.core.random.bernoulli</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.categorical.html">mlx.core.random.categorical</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.gumbel.html">mlx.core.random.gumbel</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.key.html">mlx.core.random.key</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.normal.html">mlx.core.random.normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.randint.html">mlx.core.random.randint</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.uniform.html">mlx.core.random.uniform</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.seed.html">mlx.core.random.seed</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.split.html">mlx.core.random.split</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.truncated_normal.html">mlx.core.random.truncated_normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../_autosummary/mlx.core.random.uniform.html">mlx.core.random.uniform</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../transforms.html">Transforms</a><input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-5"><i class="fa-solid fa-chevron-down"></i></label><ul>
@@ -392,59 +397,59 @@
</ul>
</li>
<li class="toctree-l2 current active has-children"><a class="current reference internal" href="#">Layers</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-10"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.ALiBi.html">mlx.nn.ALiBi</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.BatchNorm.html">mlx.nn.BatchNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.BatchNorm.html">mlx.nn.BatchNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.InstanceNorm.html">mlx.nn.InstanceNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Dropout.html">mlx.nn.Dropout</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Dropout2d.html">mlx.nn.Dropout2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Dropout3d.html">mlx.nn.Dropout3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Transformer.html">mlx.nn.Transformer</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.InstanceNorm.html">mlx.nn.InstanceNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.ALiBi.html">mlx.nn.ALiBi</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.SinusoidalPositionalEncoding.html">mlx.nn.SinusoidalPositionalEncoding</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Transformer.html">mlx.nn.Transformer</a></li>
</ul>
</li>
<li class="toctree-l2 has-children"><a class="reference internal" href="functions.html">Functions</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.prelu.html">mlx.nn.prelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.selu.html">mlx.nn.selu</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.silu.html">mlx.nn.silu</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.step.html">mlx.nn.step</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.selu.html">mlx.nn.selu</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
</ul>
</li>
<li class="toctree-l2 has-children"><a class="reference internal" href="losses.html">Loss Functions</a><input class="toctree-checkbox" id="toctree-checkbox-12" name="toctree-checkbox-12" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-12"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html">mlx.nn.losses.binary_cross_entropy</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html">mlx.nn.losses.cosine_similarity_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.cross_entropy.html">mlx.nn.losses.cross_entropy</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.hinge_loss.html">mlx.nn.losses.hinge_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.huber_loss.html">mlx.nn.losses.huber_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.l1_loss.html">mlx.nn.losses.l1_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.mse_loss.html">mlx.nn.losses.mse_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html">mlx.nn.losses.smooth_l1_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.triplet_loss.html">mlx.nn.losses.triplet_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.hinge_loss.html">mlx.nn.losses.hinge_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.huber_loss.html">mlx.nn.losses.huber_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html">mlx.nn.losses.cosine_similarity_loss</a></li>
</ul>
</li>
</ul>
@@ -645,59 +650,17 @@ document.write(`
<span id="id1"></span><h1>Layers<a class="headerlink" href="#layers" title="Permalink to this heading">#</a></h1>
<table class="autosummary longtable table autosummary">
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Sequential.html#mlx.nn.Sequential" title="mlx.nn.Sequential"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Sequential</span></code></a>(*modules)</p></td>
<td><p>A layer that calls the passed callables in order.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.ReLU.html#mlx.nn.ReLU" title="mlx.nn.ReLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">ReLU</span></code></a>()</p></td>
<td><p>Applies the Rectified Linear Unit.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.PReLU.html#mlx.nn.PReLU" title="mlx.nn.PReLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">PReLU</span></code></a>([num_parameters, init])</p></td>
<td><p>Applies the element-wise parametric ReLU.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.GELU.html#mlx.nn.GELU" title="mlx.nn.GELU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">GELU</span></code></a>([approx])</p></td>
<td><p>Applies the Gaussian Error Linear Units.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.SiLU.html#mlx.nn.SiLU" title="mlx.nn.SiLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">SiLU</span></code></a>()</p></td>
<td><p>Applies the Sigmoid Linear Unit.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Step.html#mlx.nn.Step" title="mlx.nn.Step"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Step</span></code></a>([threshold])</p></td>
<td><p>Applies the Step Activation Function.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.SELU.html#mlx.nn.SELU" title="mlx.nn.SELU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">SELU</span></code></a>()</p></td>
<td><p>Applies the Scaled Exponential Linear Unit.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Mish.html#mlx.nn.Mish" title="mlx.nn.Mish"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Mish</span></code></a>()</p></td>
<td><p>Applies the Mish function, element-wise.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Embedding.html#mlx.nn.Embedding" title="mlx.nn.Embedding"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Embedding</span></code></a>(num_embeddings, dims)</p></td>
<td><p>Implements a simple lookup table that maps each input integer to a high-dimensional vector.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Linear.html#mlx.nn.Linear" title="mlx.nn.Linear"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Linear</span></code></a>(input_dims, output_dims[, bias])</p></td>
<td><p>Applies an affine transformation to the input.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.QuantizedLinear.html#mlx.nn.QuantizedLinear" title="mlx.nn.QuantizedLinear"><code class="xref py py-obj docutils literal notranslate"><span class="pre">QuantizedLinear</span></code></a>(input_dims, output_dims[, ...])</p></td>
<td><p>Applies an affine transformation to the input using a quantized weight matrix.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Conv1d.html#mlx.nn.Conv1d" title="mlx.nn.Conv1d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Conv1d</span></code></a>(in_channels, out_channels, kernel_size)</p></td>
<td><p>Applies a 1-dimensional convolution over the multi-channel input sequence.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Conv2d.html#mlx.nn.Conv2d" title="mlx.nn.Conv2d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Conv2d</span></code></a>(in_channels, out_channels, kernel_size)</p></td>
<td><p>Applies a 2-dimensional convolution over the multi-channel input image.</p></td>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.ALiBi.html#mlx.nn.ALiBi" title="mlx.nn.ALiBi"><code class="xref py py-obj docutils literal notranslate"><span class="pre">ALiBi</span></code></a>()</p></td>
<td><p></p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.BatchNorm.html#mlx.nn.BatchNorm" title="mlx.nn.BatchNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">BatchNorm</span></code></a>(num_features[, eps, momentum, ...])</p></td>
<td><p>Applies Batch Normalization over a 2D or 3D input.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.LayerNorm.html#mlx.nn.LayerNorm" title="mlx.nn.LayerNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">LayerNorm</span></code></a>(dims[, eps, affine])</p></td>
<td><p>Applies layer normalization [1] on the inputs.</p></td>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Conv1d.html#mlx.nn.Conv1d" title="mlx.nn.Conv1d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Conv1d</span></code></a>(in_channels, out_channels, kernel_size)</p></td>
<td><p>Applies a 1-dimensional convolution over the multi-channel input sequence.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.RMSNorm.html#mlx.nn.RMSNorm" title="mlx.nn.RMSNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">RMSNorm</span></code></a>(dims[, eps])</p></td>
<td><p>Applies Root Mean Square normalization [1] to the inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.GroupNorm.html#mlx.nn.GroupNorm" title="mlx.nn.GroupNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">GroupNorm</span></code></a>(num_groups, dims[, eps, affine, ...])</p></td>
<td><p>Applies Group Normalization [1] to the inputs.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.InstanceNorm.html#mlx.nn.InstanceNorm" title="mlx.nn.InstanceNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">InstanceNorm</span></code></a>(dims[, eps, affine])</p></td>
<td><p>Applies instance normalization [1] on the inputs.</p></td>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Conv2d.html#mlx.nn.Conv2d" title="mlx.nn.Conv2d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Conv2d</span></code></a>(in_channels, out_channels, kernel_size)</p></td>
<td><p>Applies a 2-dimensional convolution over the multi-channel input image.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Dropout.html#mlx.nn.Dropout" title="mlx.nn.Dropout"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Dropout</span></code></a>([p])</p></td>
<td><p>Randomly zero a portion of the elements during training.</p></td>
@@ -708,21 +671,63 @@ document.write(`
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Dropout3d.html#mlx.nn.Dropout3d" title="mlx.nn.Dropout3d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Dropout3d</span></code></a>([p])</p></td>
<td><p>Apply 3D channel-wise dropout during training.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Transformer.html#mlx.nn.Transformer" title="mlx.nn.Transformer"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Transformer</span></code></a>(dims, num_heads, ...)</p></td>
<td><p>Implements a standard Transformer model.</p></td>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Embedding.html#mlx.nn.Embedding" title="mlx.nn.Embedding"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Embedding</span></code></a>(num_embeddings, dims)</p></td>
<td><p>Implements a simple lookup table that maps each input integer to a high-dimensional vector.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.GELU.html#mlx.nn.GELU" title="mlx.nn.GELU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">GELU</span></code></a>([approx])</p></td>
<td><p>Applies the Gaussian Error Linear Units.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.GroupNorm.html#mlx.nn.GroupNorm" title="mlx.nn.GroupNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">GroupNorm</span></code></a>(num_groups, dims[, eps, affine, ...])</p></td>
<td><p>Applies Group Normalization [1] to the inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.InstanceNorm.html#mlx.nn.InstanceNorm" title="mlx.nn.InstanceNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">InstanceNorm</span></code></a>(dims[, eps, affine])</p></td>
<td><p>Applies instance normalization [1] on the inputs.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.LayerNorm.html#mlx.nn.LayerNorm" title="mlx.nn.LayerNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">LayerNorm</span></code></a>(dims[, eps, affine])</p></td>
<td><p>Applies layer normalization [1] on the inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Linear.html#mlx.nn.Linear" title="mlx.nn.Linear"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Linear</span></code></a>(input_dims, output_dims[, bias])</p></td>
<td><p>Applies an affine transformation to the input.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Mish.html#mlx.nn.Mish" title="mlx.nn.Mish"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Mish</span></code></a>()</p></td>
<td><p>Applies the Mish function, element-wise.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.MultiHeadAttention.html#mlx.nn.MultiHeadAttention" title="mlx.nn.MultiHeadAttention"><code class="xref py py-obj docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code></a>(dims, num_heads[, ...])</p></td>
<td><p>Implements the scaled dot product attention with multiple heads.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.ALiBi.html#mlx.nn.ALiBi" title="mlx.nn.ALiBi"><code class="xref py py-obj docutils literal notranslate"><span class="pre">ALiBi</span></code></a>()</p></td>
<td><p></p></td>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.PReLU.html#mlx.nn.PReLU" title="mlx.nn.PReLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">PReLU</span></code></a>([num_parameters, init])</p></td>
<td><p>Applies the element-wise parametric ReLU.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.RoPE.html#mlx.nn.RoPE" title="mlx.nn.RoPE"><code class="xref py py-obj docutils literal notranslate"><span class="pre">RoPE</span></code></a>(dims[, traditional, base, scale])</p></td>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.QuantizedLinear.html#mlx.nn.QuantizedLinear" title="mlx.nn.QuantizedLinear"><code class="xref py py-obj docutils literal notranslate"><span class="pre">QuantizedLinear</span></code></a>(input_dims, output_dims[, ...])</p></td>
<td><p>Applies an affine transformation to the input using a quantized weight matrix.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.RMSNorm.html#mlx.nn.RMSNorm" title="mlx.nn.RMSNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">RMSNorm</span></code></a>(dims[, eps])</p></td>
<td><p>Applies Root Mean Square normalization [1] to the inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.ReLU.html#mlx.nn.ReLU" title="mlx.nn.ReLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">ReLU</span></code></a>()</p></td>
<td><p>Applies the Rectified Linear Unit.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.RoPE.html#mlx.nn.RoPE" title="mlx.nn.RoPE"><code class="xref py py-obj docutils literal notranslate"><span class="pre">RoPE</span></code></a>(dims[, traditional, base, scale])</p></td>
<td><p>Implements the rotary positional encoding.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.SELU.html#mlx.nn.SELU" title="mlx.nn.SELU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">SELU</span></code></a>()</p></td>
<td><p>Applies the Scaled Exponential Linear Unit.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Sequential.html#mlx.nn.Sequential" title="mlx.nn.Sequential"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Sequential</span></code></a>(*modules)</p></td>
<td><p>A layer that calls the passed callables in order.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.SiLU.html#mlx.nn.SiLU" title="mlx.nn.SiLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">SiLU</span></code></a>()</p></td>
<td><p>Applies the Sigmoid Linear Unit.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.SinusoidalPositionalEncoding.html#mlx.nn.SinusoidalPositionalEncoding" title="mlx.nn.SinusoidalPositionalEncoding"><code class="xref py py-obj docutils literal notranslate"><span class="pre">SinusoidalPositionalEncoding</span></code></a>(dims[, ...])</p></td>
<td><p>Implements sinusoidal positional encoding.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Step.html#mlx.nn.Step" title="mlx.nn.Step"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Step</span></code></a>([threshold])</p></td>
<td><p>Applies the Step Activation Function.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Transformer.html#mlx.nn.Transformer" title="mlx.nn.Transformer"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Transformer</span></code></a>(dims, num_heads, ...)</p></td>
<td><p>Implements a standard Transformer model.</p></td>
</tr>
</tbody>
</table>
</section>
@@ -748,11 +753,11 @@ document.write(`
</div>
</a>
<a class="right-next"
href="_autosummary/mlx.nn.Sequential.html"
href="_autosummary/mlx.nn.ALiBi.html"
title="next page">
<div class="prev-next-info">
<p class="prev-next-subtitle">next</p>
<p class="prev-next-title">mlx.nn.Sequential</p>
<p class="prev-next-title">mlx.nn.ALiBi</p>
</div>
<i class="fa-solid fa-angle-right"></i>
</a>