mlx/docs/build/html/python/nn.html
Awni Hannun fbd10a48d4 docs
2025-06-04 01:01:47 +00:00


<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Neural Networks &mdash; MLX 0.0.0 documentation</title>
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
<script src="../_static/jquery.js"></script>
<script src="../_static/underscore.js"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="mlx.nn.value_and_grad" href="_autosummary/mlx.nn.value_and_grad.html" />
<link rel="prev" title="mlx.core.fft.irfftn" href="_autosummary/mlx.core.fft.irfftn.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../index.html" class="icon icon-home">
MLX
</a>
<div class="version">
0.0.0
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Install</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../install.html">Build and Install</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Usage</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../quick_start.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../using_streams.html">Using Streams</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../examples/linear_regression.html">Linear Regression</a></li>
<li class="toctree-l1"><a class="reference internal" href="../examples/mlp.html">Multi-Layer Perceptron</a></li>
<li class="toctree-l1"><a class="reference internal" href="../examples/llama-inference.html">LLM inference</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Developer Documentation</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="array.html">Array</a></li>
<li class="toctree-l1"><a class="reference internal" href="devices_and_streams.html">Devices and Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="ops.html">Operations</a></li>
<li class="toctree-l1"><a class="reference internal" href="random.html">Random</a></li>
<li class="toctree-l1"><a class="reference internal" href="transforms.html">Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="fft.html">FFT</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Neural Networks</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#quick-start-with-neural-networks">Quick Start with Neural Networks</a></li>
<li class="toctree-l2"><a class="reference internal" href="#the-module-class">The Module Class</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#parameters">Parameters</a></li>
<li class="toctree-l3"><a class="reference internal" href="#updating-the-parameters">Updating the parameters</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="#value-and-grad">Value and grad</a><ul>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="#neural-network-layers">Neural Network Layers</a><ul>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
<li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.silu.html">mlx.nn.silu</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="optimizers.html">Optimizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="tree_utils.html">Tree Utils</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../cpp/ops.html">Operations</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">MLX</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item active">Neural Networks</li>
<li class="wy-breadcrumbs-aside">
<a href="../_sources/python/nn.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="neural-networks">
<span id="nn"></span><h1>Neural Networks<a class="headerlink" href="#neural-networks" title="Permalink to this heading"></a></h1>
<p>Writing arbitrarily complex neural networks in MLX can be done using only
<a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> and <a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a>. However, this requires the
user to repeatedly rewrite the same simple neural network operations and to
manage all parameter state and initialization manually and explicitly.</p>
<p>The module <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.nn</span></code> solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for fine-tuning, and more.</p>
<section id="quick-start-with-neural-networks">
<h2>Quick Start with Neural Networks<a class="headerlink" href="#quick-start-with-neural-networks" title="Permalink to this heading"></a></h2>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_dims</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">out_dims</span><span class="p">),</span>
        <span class="p">]</span>
    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">):</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">x</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">l</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
<span class="c1"># The model is created with all its parameters but nothing is initialized</span>
<span class="c1"># yet because MLX is lazily evaluated</span>
<span class="n">mlp</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
<span class="c1"># We can access its parameters by calling mlp.parameters()</span>
<span class="n">params</span> <span class="o">=</span> <span class="n">mlp</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="n">params</span><span class="p">[</span><span class="s2">&quot;layers&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">][</span><span class="s2">&quot;weight&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="c1"># Printing a parameter will cause it to be evaluated and thus initialized</span>
<span class="nb">print</span><span class="p">(</span><span class="n">params</span><span class="p">[</span><span class="s2">&quot;layers&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
<span class="c1"># We can also force evaluate all parameters to initialize the model</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">mlp</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
<span class="c1"># A simple loss function.</span>
<span class="c1"># NOTE: It doesn&#39;t matter how it uses the mlp model. It currently captures</span>
<span class="c1"># it from the local scope. It could be a positional argument or a</span>
<span class="c1"># keyword argument.</span>
<span class="k">def</span> <span class="nf">l2_loss</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="n">y_hat</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">y_hat</span> <span class="o">-</span> <span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">square</span><span class="p">()</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="c1"># Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the</span>
<span class="c1"># gradient with respect to `mlp.trainable_parameters()`</span>
<span class="n">loss_and_grad</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">mlp</span><span class="p">,</span> <span class="n">l2_loss</span><span class="p">)</span>
</pre></div>
</div>
</section>
<section id="the-module-class">
<span id="module-class"></span><h2>The Module Class<a class="headerlink" href="#the-module-class" title="Permalink to this heading"></a></h2>
<p>The workhorse of any neural network library is the <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class. In
MLX the <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class is a container of <a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> or
<a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances. Its main function is to provide a way to
recursively <strong>access</strong> and <strong>update</strong> its parameters and those of its
submodules.</p>
<section id="parameters">
<h3>Parameters<a class="headerlink" href="#parameters" title="Permalink to this heading"></a></h3>
<p>A parameter of a module is any public member of type <a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> (its
name should not start with <code class="docutils literal notranslate"><span class="pre">_</span></code>). It can be arbitrarily nested in other
<a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances or lists and dictionaries.</p>
<p><a class="reference internal" href="nn/module.html#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.parameters()</span></code></a> can be used to extract a nested dictionary with all
the parameters of a module and its submodules.</p>
<p>A <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> can also keep track of “frozen” parameters.
<a class="reference internal" href="nn/module.html#mlx.nn.Module.trainable_parameters" title="mlx.nn.Module.trainable_parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.trainable_parameters()</span></code></a> returns only the subset of
<a class="reference internal" href="nn/module.html#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.parameters()</span></code></a> that is not frozen. When using
<a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a> the gradients returned will be with respect to these
trainable parameters.</p>
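<p>The recursive extraction can be pictured with a plain-Python sketch over
nested dictionaries and lists. The <code class="docutils literal notranslate"><span class="pre">parameters</span></code> function below is
illustrative only (it is not the <code class="docutils literal notranslate"><span class="pre">mlx.nn</span></code> implementation), and plain lists
of floats stand in for <code class="docutils literal notranslate"><span class="pre">mlx.core.array</span></code>:</p>

```python
# A plain-Python sketch of how parameter extraction conceptually walks a
# module tree: keys starting with "_" are private and skipped, containers
# are recursed into, and leaves (lists of floats standing in for arrays)
# are kept. Illustrative only.
def parameters(tree):
    if isinstance(tree, dict):
        return {k: parameters(v) for k, v in tree.items()
                if not k.startswith("_")}
    if isinstance(tree, list) and tree and isinstance(tree[0], dict):
        return [parameters(v) for v in tree]
    return tree  # a leaf "array"

module = {
    "weight": [1.0, 2.0],                          # a parameter
    "_cache": [9.9],                               # private member: skipped
    "layers": [{"weight": [0.5], "bias": [0.0]}],  # nested submodule
}
params = parameters(module)
```
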
</section>
<section id="updating-the-parameters">
<h3>Updating the parameters<a class="headerlink" href="#updating-the-parameters" title="Permalink to this heading"></a></h3>
<p>MLX modules allow accessing and updating individual parameters. More often,
however, we need to update large subsets of a module&#39;s parameters at once. This
is done with <a class="reference internal" href="nn/module.html#mlx.nn.Module.update" title="mlx.nn.Module.update"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code></a>.</p>
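<p>The merge can be pictured with a plain-Python analogue over nested
dictionaries and lists. The <code class="docutils literal notranslate"><span class="pre">update</span></code> function below is illustrative only
(it is not the <code class="docutils literal notranslate"><span class="pre">mlx.nn</span></code> implementation), and plain lists of floats stand in
for <code class="docutils literal notranslate"><span class="pre">mlx.core.array</span></code>:</p>

```python
# A plain-Python analogue of what Module.update() does: recursively merge
# a (possibly partial) tree of new values into an existing parameter tree.
# Illustrative only; lists of floats stand in for arrays.
def update(tree, new):
    for k, v in new.items():
        if isinstance(v, dict):
            update(tree[k], v)                 # recurse into a submodule
        elif isinstance(v, list) and v and isinstance(v[0], dict):
            for old_item, new_item in zip(tree[k], v):
                update(old_item, new_item)     # recurse into a list of submodules
        else:
            tree[k] = v                        # replace the leaf "array"

params = {"bias": [0.0], "layers": [{"weight": [1.0]}, {"weight": [2.0]}]}
# A partial update touching only the second layer's weight:
update(params, {"layers": [{}, {"weight": [5.0]}]})
```

<p>Note that the new tree may be partial: entries that are absent (like the
empty dict for the first layer above) leave the existing values untouched.</p>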
</section>
</section>
<section id="value-and-grad">
<h2>Value and grad<a class="headerlink" href="#value-and-grad" title="Permalink to this heading"></a></h2>
<p>Using a <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> does not preclude using MLX&#39;s higher-order function
transformations (<a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a>, <a class="reference internal" href="_autosummary/mlx.core.grad.html#mlx.core.grad" title="mlx.core.grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.grad()</span></code></a>, etc.). However,
these function transformations assume pure functions, namely that the
parameters be passed as arguments to the function being transformed.</p>
<p>There is an easy pattern to achieve this with MLX modules:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">model</span> <span class="o">=</span> <span class="o">...</span>
<span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">other_inputs</span><span class="p">):</span>
    <span class="n">model</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">params</span><span class="p">)</span> <span class="c1"># &lt;---- Necessary to make the model use the passed parameters</span>
    <span class="k">return</span> <span class="n">model</span><span class="p">(</span><span class="n">other_inputs</span><span class="p">)</span>
<span class="n">f</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_parameters</span><span class="p">(),</span> <span class="n">mx</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">10</span><span class="p">,)))</span>
</pre></div>
</div>
<p>However, <a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a> provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.</p>
<p>In detail:</p>
<ul class="simple">
<li><p>It wraps the passed function with a function that calls <a class="reference internal" href="nn/module.html#mlx.nn.Module.update" title="mlx.nn.Module.update"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code></a>
to make sure the model is using the provided parameters.</p></li>
<li><p>It calls <a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a> to transform the function into one
that also computes the gradients with respect to the passed parameters.</p></li>
<li><p>It wraps the returned function with a function that passes the trainable
parameters as the first argument to the function returned by
<a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a>.</p></li>
</ul>
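<p>These three steps can be sketched in plain Python. Everything below is
illustrative: the toy model and the finite-difference stand-in for
<code class="docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code> are not the real mlx implementations, but the
wrapping structure mirrors the steps listed above:</p>

```python
# Illustrative sketch of the wrapping performed by mlx.nn.value_and_grad.
# ToyModel and the finite-difference core_value_and_grad are stand-ins,
# not the actual mlx implementations.

class ToyModel:
    def __init__(self):
        self._params = {"w": 3.0}
    def trainable_parameters(self):
        return dict(self._params)
    def update(self, params):
        self._params.update(params)
    def __call__(self, x):
        return self._params["w"] * x

def core_value_and_grad(f):
    """Stand-in for mx.value_and_grad using finite differences."""
    def inner(params, *args):
        value = f(params, *args)
        eps = 1e-6
        grads = {}
        for k, v in params.items():
            bumped = dict(params)
            bumped[k] = v + eps
            grads[k] = (f(bumped, *args) - value) / eps
        return value, grads
    return inner

def nn_value_and_grad(model, fn):
    # Step 1: wrap fn so it first pushes the params into the model.
    def inner_fn(params, *args):
        model.update(params)
        return fn(*args)
    # Step 2: transform into a function that also computes gradients.
    vag = core_value_and_grad(inner_fn)
    # Step 3: pass the trainable parameters as the first argument.
    def wrapped(*args):
        return vag(model.trainable_parameters(), *args)
    return wrapped

model = ToyModel()
loss = lambda x: (model(x) - 1.0) ** 2   # (w*x - 1)^2
loss_and_grad = nn_value_and_grad(model, loss)
value, grads = loss_and_grad(2.0)        # w=3, x=2: value = 25.0, grads["w"] ~= 20.0
```

<p>The caller never threads the parameters through by hand; the wrapper reads
them from the model and writes them back before evaluating the function.</p>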
<table class="autosummary longtable docutils align-default">
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-obj docutils literal notranslate"><span class="pre">value_and_grad</span></code></a>(model, fn)</p></td>
<td><p>Transform the passed function <code class="docutils literal notranslate"><span class="pre">fn</span></code> to a function that computes the gradients of <code class="docutils literal notranslate"><span class="pre">fn</span></code> wrt the model's trainable parameters and also its value.</p></td>
</tr>
</tbody>
</table>
</section>
<section id="neural-network-layers">
<h2>Neural Network Layers<a class="headerlink" href="#neural-network-layers" title="Permalink to this heading"></a></h2>
<table class="autosummary longtable docutils align-default">
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Embedding.html#mlx.nn.Embedding" title="mlx.nn.Embedding"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Embedding</span></code></a>(num_embeddings, dims)</p></td>
<td><p>Implements a simple lookup table that maps each input integer to a high-dimensional vector.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.ReLU.html#mlx.nn.ReLU" title="mlx.nn.ReLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">ReLU</span></code></a>()</p></td>
<td><p>Applies the Rectified Linear Unit.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.GELU.html#mlx.nn.GELU" title="mlx.nn.GELU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">GELU</span></code></a>([approx])</p></td>
<td><p>Applies the Gaussian Error Linear Units.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.SiLU.html#mlx.nn.SiLU" title="mlx.nn.SiLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">SiLU</span></code></a>()</p></td>
<td><p>Applies the Sigmoid Linear Unit.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Linear.html#mlx.nn.Linear" title="mlx.nn.Linear"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Linear</span></code></a>(input_dims, output_dims[, bias])</p></td>
<td><p>Applies an affine transformation to the input.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Conv1d.html#mlx.nn.Conv1d" title="mlx.nn.Conv1d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Conv1d</span></code></a>(in_channels, out_channels, kernel_size)</p></td>
<td><p>Applies a 1-dimensional convolution over the multi-channel input sequence.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Conv2d.html#mlx.nn.Conv2d" title="mlx.nn.Conv2d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Conv2d</span></code></a>(in_channels, out_channels, kernel_size)</p></td>
<td><p>Applies a 2-dimensional convolution over the multi-channel input image.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.LayerNorm.html#mlx.nn.LayerNorm" title="mlx.nn.LayerNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">LayerNorm</span></code></a>(dims[, eps, affine])</p></td>
<td><p>Applies layer normalization [1] on the inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.RMSNorm.html#mlx.nn.RMSNorm" title="mlx.nn.RMSNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">RMSNorm</span></code></a>(dims[, eps])</p></td>
<td><p>Applies Root Mean Square normalization [1] to the inputs.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.GroupNorm.html#mlx.nn.GroupNorm" title="mlx.nn.GroupNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">GroupNorm</span></code></a>(num_groups, dims[, eps, affine, ...])</p></td>
<td><p>Applies Group Normalization [1] to the inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.RoPE.html#mlx.nn.RoPE" title="mlx.nn.RoPE"><code class="xref py py-obj docutils literal notranslate"><span class="pre">RoPE</span></code></a>(dims[, traditional])</p></td>
<td><p>Implements the rotary positional encoding [1].</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.MultiHeadAttention.html#mlx.nn.MultiHeadAttention" title="mlx.nn.MultiHeadAttention"><code class="xref py py-obj docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code></a>(dims, num_heads[, ...])</p></td>
<td><p>Implements the scaled dot product attention with multiple heads.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Sequential.html#mlx.nn.Sequential" title="mlx.nn.Sequential"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Sequential</span></code></a>(*modules)</p></td>
<td><p>A layer that calls the passed callables in order.</p></td>
</tr>
</tbody>
</table>
<p>Layers without parameters (e.g. activation functions) are also provided as
simple functions.</p>
<table class="autosummary longtable docutils align-default">
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu.html#mlx.nn.gelu" title="mlx.nn.gelu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">gelu</span></code></a>(x)</p></td>
<td><p>Applies the Gaussian Error Linear Units function.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_approx.html#mlx.nn.gelu_approx" title="mlx.nn.gelu_approx"><code class="xref py py-obj docutils literal notranslate"><span class="pre">gelu_approx</span></code></a>(x)</p></td>
<td><p>An approximation to Gaussian Error Linear Unit.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_fast_approx.html#mlx.nn.gelu_fast_approx" title="mlx.nn.gelu_fast_approx"><code class="xref py py-obj docutils literal notranslate"><span class="pre">gelu_fast_approx</span></code></a>(x)</p></td>
<td><p>A fast approximation to Gaussian Error Linear Unit.</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.relu.html#mlx.nn.relu" title="mlx.nn.relu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">relu</span></code></a>(x)</p></td>
<td><p>Applies the Rectified Linear Unit.</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.silu.html#mlx.nn.silu" title="mlx.nn.silu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">silu</span></code></a>(x)</p></td>
<td><p>Applies the Sigmoid Linear Unit.</p></td>
</tr>
</tbody>
</table>
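<p>As a plain-Python sketch of what two of these elementwise functions compute
(using only the standard library, not mlx; mlx applies the same formulas to
arrays):</p>

```python
import math

def relu(x):
    # Rectified Linear Unit: max(v, 0) applied elementwise.
    return [max(v, 0.0) for v in x]

def silu(x):
    # Sigmoid Linear Unit: v * sigmoid(v) applied elementwise.
    return [v / (1.0 + math.exp(-v)) for v in x]

relu([-1.0, 0.0, 2.0])  # -> [0.0, 0.0, 2.0]
```
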
</section>
</section>
</div>
</div>
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
<a href="_autosummary/mlx.core.fft.irfftn.html" class="btn btn-neutral float-left" title="mlx.core.fft.irfftn" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="_autosummary/mlx.nn.value_and_grad.html" class="btn btn-neutral float-right" title="mlx.nn.value_and_grad" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
</div>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, MLX Contributors.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>