<!-- Mirror of https://github.com/ml-explore/mlx.git — synced 2025-06-25 18:11:15 +08:00 -->
<!DOCTYPE html>
<html class="writer-html5" lang="en">
<head>
  <meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />

  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>Neural Networks — MLX 0.0.0 documentation</title>
  <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
  <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
  <!--[if lt IE 9]>
    <script src="../_static/js/html5shiv.min.js"></script>
  <![endif]-->

  <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
  <script src="../_static/jquery.js"></script>
  <script src="../_static/underscore.js"></script>
  <script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
  <script src="../_static/doctools.js"></script>
  <script src="../_static/js/theme.js"></script>
  <link rel="index" title="Index" href="../genindex.html" />
  <link rel="search" title="Search" href="../search.html" />
  <link rel="next" title="mlx.nn.value_and_grad" href="_autosummary/mlx.nn.value_and_grad.html" />
  <link rel="prev" title="mlx.core.fft.irfftn" href="_autosummary/mlx.core.fft.irfftn.html" />
</head>

<body class="wy-body-for-nav">
  <div class="wy-grid-for-nav">
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search">
          <a href="../index.html" class="icon icon-home">
            MLX
          </a>
          <div class="version">
            0.0.0
          </div>
          <div role="search">
            <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
              <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
              <input type="hidden" name="check_keywords" value="yes" />
              <input type="hidden" name="area" value="default" />
            </form>
          </div>
        </div>
        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
          <p class="caption" role="heading"><span class="caption-text">Install</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../install.html">Build and Install</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Usage</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../quick_start.html">Quick Start Guide</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../using_streams.html">Using Streams</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Examples</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../examples/linear_regression.html">Linear Regression</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../examples/mlp.html">Multi-Layer Perceptron</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../examples/llama-inference.html">LLM inference</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Developer Documentation</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
          <ul class="current">
            <li class="toctree-l1"><a class="reference internal" href="array.html">Array</a></li>
            <li class="toctree-l1"><a class="reference internal" href="devices_and_streams.html">Devices and Streams</a></li>
            <li class="toctree-l1"><a class="reference internal" href="ops.html">Operations</a></li>
            <li class="toctree-l1"><a class="reference internal" href="random.html">Random</a></li>
            <li class="toctree-l1"><a class="reference internal" href="transforms.html">Transforms</a></li>
            <li class="toctree-l1"><a class="reference internal" href="fft.html">FFT</a></li>
            <li class="toctree-l1 current"><a class="current reference internal" href="#">Neural Networks</a><ul>
              <li class="toctree-l2"><a class="reference internal" href="#quick-start-with-neural-networks">Quick Start with Neural Networks</a></li>
              <li class="toctree-l2"><a class="reference internal" href="#the-module-class">The Module Class</a><ul>
                <li class="toctree-l3"><a class="reference internal" href="#parameters">Parameters</a></li>
                <li class="toctree-l3"><a class="reference internal" href="#updating-the-parameters">Updating the parameters</a></li>
              </ul>
              </li>
              <li class="toctree-l2"><a class="reference internal" href="#value-and-grad">Value and grad</a><ul>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
              </ul>
              </li>
              <li class="toctree-l2"><a class="reference internal" href="#neural-network-layers">Neural Network Layers</a><ul>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
                <li class="toctree-l3"><a class="reference internal" href="_autosummary_functions/mlx.nn.silu.html">mlx.nn.silu</a></li>
              </ul>
              </li>
            </ul>
            </li>
            <li class="toctree-l1"><a class="reference internal" href="optimizers.html">Optimizers</a></li>
            <li class="toctree-l1"><a class="reference internal" href="tree_utils.html">Tree Utils</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../cpp/ops.html">Operations</a></li>
          </ul>
        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu">
      <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
      <a href="../index.html">MLX</a>
    </nav>

    <div class="wy-nav-content">
      <div class="rst-content">
        <div role="navigation" aria-label="Page navigation">
          <ul class="wy-breadcrumbs">
            <li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
            <li class="breadcrumb-item active">Neural Networks</li>
            <li class="wy-breadcrumbs-aside">
              <a href="../_sources/python/nn.rst.txt" rel="nofollow"> View page source</a>
            </li>
          </ul>
          <hr/>
        </div>
|
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
|||
|
<div itemprop="articleBody">
|
|||
|
|
|||
|
<section id="neural-networks">
|
|||
|
<span id="nn"></span><h1>Neural Networks<a class="headerlink" href="#neural-networks" title="Permalink to this heading"></a></h1>
|
|||
|
<p>Writing arbitrarily complex neural networks in MLX can be done using only
|
|||
|
<a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> and <a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a>. However, this requires the
|
|||
|
user to write again and again the same simple neural network operations as well
|
|||
|
as handle all the parameter state and initialization manually and explicitly.</p>
|
|||
|
<p>The module <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.nn</span></code> solves this problem by providing an intuitive way of
|
|||
|
composing neural network layers, initializing their parameters, freezing them
|
|||
|
for finetuning and more.</p>
|
|||
|
<section id="quick-start-with-neural-networks">
|
|||
|
<h2>Quick Start with Neural Networks<a class="headerlink" href="#quick-start-with-neural-networks" title="Permalink to this heading"></a></h2>
|
|||
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
|||
|
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
|||
|
|
|||
|
<span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
|||
|
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|||
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|||
|
|
|||
|
<span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
|
|||
|
<span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_dims</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span>
|
|||
|
<span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span>
|
|||
|
<span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">out_dims</span><span class="p">),</span>
|
|||
|
<span class="p">]</span>
|
|||
|
|
|||
|
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
|||
|
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">):</span>
|
|||
|
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="n">x</span>
|
|||
|
<span class="n">x</span> <span class="o">=</span> <span class="n">l</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
|||
|
<span class="k">return</span> <span class="n">x</span>
|
|||
|
|
|||
|
<span class="c1"># The model is created with all its parameters but nothing is initialized</span>
|
|||
|
<span class="c1"># yet because MLX is lazily evaluated</span>
|
|||
|
<span class="n">mlp</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
|
|||
|
|
|||
|
<span class="c1"># We can access its parameters by calling mlp.parameters()</span>
|
|||
|
<span class="n">params</span> <span class="o">=</span> <span class="n">mlp</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
|
|||
|
<span class="nb">print</span><span class="p">(</span><span class="n">params</span><span class="p">[</span><span class="s2">"layers"</span><span class="p">][</span><span class="mi">0</span><span class="p">][</span><span class="s2">"weight"</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|||
|
|
|||
|
<span class="c1"># Printing a parameter will cause it to be evaluated and thus initialized</span>
|
|||
|
<span class="nb">print</span><span class="p">(</span><span class="n">params</span><span class="p">[</span><span class="s2">"layers"</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
|
|||
|
|
|||
|
<span class="c1"># We can also force evaluate all parameters to initialize the model</span>
|
|||
|
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">mlp</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
|
|||
|
|
|||
|
<span class="c1"># A simple loss function.</span>
|
|||
|
<span class="c1"># NOTE: It doesn't matter how it uses the mlp model. It currently captures</span>
|
|||
|
<span class="c1"># it from the local scope. It could be a positional argument or a</span>
|
|||
|
<span class="c1"># keyword argument.</span>
|
|||
|
<span class="k">def</span> <span class="nf">l2_loss</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
|||
|
<span class="n">y_hat</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
|||
|
<span class="k">return</span> <span class="p">(</span><span class="n">y_hat</span> <span class="o">-</span> <span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">square</span><span class="p">()</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
|
|||
|
|
|||
|
<span class="c1"># Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the</span>
|
|||
|
<span class="c1"># gradient with respect to `mlp.trainable_parameters()`</span>
|
|||
|
<span class="n">loss_and_grad</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">mlp</span><span class="p">,</span> <span class="n">l2_loss</span><span class="p">)</span>
|
|||
|
</pre></div>
|
|||
|
</div>
|
|||
|
</section>
|
|||
|
<section id="the-module-class">
|
|||
|
<span id="module-class"></span><h2>The Module Class<a class="headerlink" href="#the-module-class" title="Permalink to this heading"></a></h2>
|
|||
|
<p>The workhorse of any neural network library is the <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class. In
|
|||
|
MLX the <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class is a container of <a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> or
|
|||
|
<a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances. Its main function is to provide a way to
|
|||
|
recursively <strong>access</strong> and <strong>update</strong> its parameters and those of its
|
|||
|
submodules.</p>
|
|||
|
<section id="parameters">
|
|||
|
<h3>Parameters<a class="headerlink" href="#parameters" title="Permalink to this heading"></a></h3>
|
|||
|
<p>A parameter of a module is any public member of type <a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> (its
|
|||
|
name should not start with <code class="docutils literal notranslate"><span class="pre">_</span></code>). It can be arbitrarily nested in other
|
|||
|
<a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances or lists and dictionaries.</p>
|
|||
|
<p><a class="reference internal" href="nn/module.html#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.parameters()</span></code></a> can be used to extract a nested dictionary with all
|
|||
|
the parameters of a module and its submodules.</p>
|
|||
|
<p>A <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> can also keep track of “frozen” parameters.
|
|||
|
<a class="reference internal" href="nn/module.html#mlx.nn.Module.trainable_parameters" title="mlx.nn.Module.trainable_parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.trainable_parameters()</span></code></a> returns only the subset of
|
|||
|
<a class="reference internal" href="nn/module.html#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.parameters()</span></code></a> that is not frozen. When using
|
|||
|
<a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a> the gradients returned will be with respect to these
|
|||
|
trainable parameters.</p>
|
|||
|
</section>
|
|||
|
<section id="updating-the-parameters">
|
|||
|
<h3>Updating the parameters<a class="headerlink" href="#updating-the-parameters" title="Permalink to this heading"></a></h3>
|
|||
|
<p>MLX modules allow accessing and updating individual parameters. However, most
|
|||
|
times we need to update large subsets of a module’s parameters. This action is
|
|||
|
performed by <a class="reference internal" href="nn/module.html#mlx.nn.Module.update" title="mlx.nn.Module.update"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code></a>.</p>
|
|||
|
</section>
|
|||
|
</section>
|
|||
|
<section id="value-and-grad">
|
|||
|
<h2>Value and grad<a class="headerlink" href="#value-and-grad" title="Permalink to this heading"></a></h2>
|
|||
|
<p>Using a <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> does not preclude using MLX’s high order function
|
|||
|
transformations (<a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a>, <a class="reference internal" href="_autosummary/mlx.core.grad.html#mlx.core.grad" title="mlx.core.grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.grad()</span></code></a>, etc.). However,
|
|||
|
these function transformations assume pure functions, namely the parameters
|
|||
|
should be passed as an argument to the function being transformed.</p>
|
|||
|
<p>There is an easy pattern to achieve that with MLX modules</p>
|
|||
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">model</span> <span class="o">=</span> <span class="o">...</span>
|
|||
|
|
|||
|
<span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">other_inputs</span><span class="p">):</span>
|
|||
|
<span class="n">model</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">params</span><span class="p">)</span> <span class="c1"># <---- Necessary to make the model use the passed parameters</span>
|
|||
|
<span class="k">return</span> <span class="n">model</span><span class="p">(</span><span class="n">other_inputs</span><span class="p">)</span>
|
|||
|
|
|||
|
<span class="n">f</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_parameters</span><span class="p">(),</span> <span class="n">mx</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">10</span><span class="p">,)))</span>
|
|||
|
</pre></div>
|
|||
|
</div>
|
|||
|
<p>However, <a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a> provides precisely this pattern and only
|
|||
|
computes the gradients with respect to the trainable parameters of the model.</p>
|
|||
|
<p>In detail:</p>
|
|||
|
<ul class="simple">
|
|||
|
<li><p>it wraps the passed function with a function that calls <a class="reference internal" href="nn/module.html#mlx.nn.Module.update" title="mlx.nn.Module.update"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code></a>
|
|||
|
to make sure the model is using the provided parameters.</p></li>
|
|||
|
<li><p>it calls <a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a> to transform the function into a function
|
|||
|
that also computes the gradients with respect to the passed parameters.</p></li>
|
|||
|
<li><p>it wraps the returned function with a function that passes the trainable
|
|||
|
parameters as the first argument to the function returned by
|
|||
|
<a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a></p></li>
|
|||
|
</ul>
|
|||
|
<table class="autosummary longtable docutils align-default">
|
|||
|
<tbody>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-obj docutils literal notranslate"><span class="pre">value_and_grad</span></code></a>(model, fn)</p></td>
|
|||
|
<td><p>Transform the passed function <code class="docutils literal notranslate"><span class="pre">fn</span></code> to a function that computes the gradients of <code class="docutils literal notranslate"><span class="pre">fn</span></code> wrt the model's trainable parameters and also its value.</p></td>
|
|||
|
</tr>
|
|||
|
</tbody>
|
|||
|
</table>
|
|||
|
</section>
|
|||
|
<section id="neural-network-layers">
|
|||
|
<h2>Neural Network Layers<a class="headerlink" href="#neural-network-layers" title="Permalink to this heading"></a></h2>
|
|||
|
<table class="autosummary longtable docutils align-default">
|
|||
|
<tbody>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Embedding.html#mlx.nn.Embedding" title="mlx.nn.Embedding"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Embedding</span></code></a>(num_embeddings, dims)</p></td>
|
|||
|
<td><p>Implements a simple lookup table that maps each input integer to a high-dimensional vector.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.ReLU.html#mlx.nn.ReLU" title="mlx.nn.ReLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">ReLU</span></code></a>()</p></td>
|
|||
|
<td><p>Applies the Rectified Linear Unit.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.GELU.html#mlx.nn.GELU" title="mlx.nn.GELU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">GELU</span></code></a>([approx])</p></td>
|
|||
|
<td><p>Applies the Gaussian Error Linear Units.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.SiLU.html#mlx.nn.SiLU" title="mlx.nn.SiLU"><code class="xref py py-obj docutils literal notranslate"><span class="pre">SiLU</span></code></a>()</p></td>
|
|||
|
<td><p>Applies the Sigmoid Linear Unit.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Linear.html#mlx.nn.Linear" title="mlx.nn.Linear"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Linear</span></code></a>(input_dims, output_dims[, bias])</p></td>
|
|||
|
<td><p>Applies an affine transformation to the input.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Conv1d.html#mlx.nn.Conv1d" title="mlx.nn.Conv1d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Conv1d</span></code></a>(in_channels, out_channels, kernel_size)</p></td>
|
|||
|
<td><p>Applies a 1-dimensional convolution over the multi-channel input sequence.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Conv2d.html#mlx.nn.Conv2d" title="mlx.nn.Conv2d"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Conv2d</span></code></a>(in_channels, out_channels, kernel_size)</p></td>
|
|||
|
<td><p>Applies a 2-dimensional convolution over the multi-channel input image.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.LayerNorm.html#mlx.nn.LayerNorm" title="mlx.nn.LayerNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">LayerNorm</span></code></a>(dims[, eps, affine])</p></td>
|
|||
|
<td><p>Applies layer normalization [1] on the inputs.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.RMSNorm.html#mlx.nn.RMSNorm" title="mlx.nn.RMSNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">RMSNorm</span></code></a>(dims[, eps])</p></td>
|
|||
|
<td><p>Applies Root Mean Square normalization [1] to the inputs.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.GroupNorm.html#mlx.nn.GroupNorm" title="mlx.nn.GroupNorm"><code class="xref py py-obj docutils literal notranslate"><span class="pre">GroupNorm</span></code></a>(num_groups, dims[, eps, affine, ...])</p></td>
|
|||
|
<td><p>Applies Group Normalization [1] to the inputs.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.RoPE.html#mlx.nn.RoPE" title="mlx.nn.RoPE"><code class="xref py py-obj docutils literal notranslate"><span class="pre">RoPE</span></code></a>(dims[, traditional])</p></td>
|
|||
|
<td><p>Implements the rotary positional encoding [1].</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.MultiHeadAttention.html#mlx.nn.MultiHeadAttention" title="mlx.nn.MultiHeadAttention"><code class="xref py py-obj docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code></a>(dims, num_heads[, ...])</p></td>
|
|||
|
<td><p>Implements the scaled dot product attention with multiple heads.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Sequential.html#mlx.nn.Sequential" title="mlx.nn.Sequential"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Sequential</span></code></a>(*modules)</p></td>
|
|||
|
<td><p>A layer that calls the passed callables in order.</p></td>
|
|||
|
</tr>
|
|||
|
</tbody>
|
|||
|
</table>
|
|||
|
<p>Layers without parameters (e.g. activation functions) are also provided as
|
|||
|
simple functions.</p>
|
|||
|
<table class="autosummary longtable docutils align-default">
|
|||
|
<tbody>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu.html#mlx.nn.gelu" title="mlx.nn.gelu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">gelu</span></code></a>(x)</p></td>
|
|||
|
<td><p>Applies the Gaussian Error Linear Units function.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_approx.html#mlx.nn.gelu_approx" title="mlx.nn.gelu_approx"><code class="xref py py-obj docutils literal notranslate"><span class="pre">gelu_approx</span></code></a>(x)</p></td>
|
|||
|
<td><p>An approximation to Gaussian Error Linear Unit.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.gelu_fast_approx.html#mlx.nn.gelu_fast_approx" title="mlx.nn.gelu_fast_approx"><code class="xref py py-obj docutils literal notranslate"><span class="pre">gelu_fast_approx</span></code></a>(x)</p></td>
|
|||
|
<td><p>A fast approximation to Gaussian Error Linear Unit.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.relu.html#mlx.nn.relu" title="mlx.nn.relu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">relu</span></code></a>(x)</p></td>
|
|||
|
<td><p>Applies the Rectified Linear Unit.</p></td>
|
|||
|
</tr>
|
|||
|
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary_functions/mlx.nn.silu.html#mlx.nn.silu" title="mlx.nn.silu"><code class="xref py py-obj docutils literal notranslate"><span class="pre">silu</span></code></a>(x)</p></td>
|
|||
|
<td><p>Applies the Sigmoid Linear Unit.</p></td>
|
|||
|
</tr>
|
|||
|
</tbody>
|
|||
|
</table>
|
|||
|
</section>
|
|||
|
</section>
|
|||
|
|
|||
|
|
|||
|
</div>
|
|||
|
</div>

        <footer>
          <div class="rst-footer-buttons" role="navigation" aria-label="Footer">
            <a href="_autosummary/mlx.core.fft.irfftn.html" class="btn btn-neutral float-left" title="mlx.core.fft.irfftn" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
            <a href="_autosummary/mlx.nn.value_and_grad.html" class="btn btn-neutral float-right" title="mlx.nn.value_and_grad" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
          </div>

          <hr/>

          <div role="contentinfo">
            <p>© Copyright 2023, MLX Contributors.</p>
          </div>

          Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
          <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
          provided by <a href="https://readthedocs.org">Read the Docs</a>.

        </footer>
      </div>
    </div>
    </section>
  </div>
<script>
  <!-- inline theme bootstrap: must be free of stray text, or the JS fails to parse -->
  // Enable the Read the Docs theme's sticky/collapsible sidebar navigation
  // once the DOM is ready (jQuery and SphinxRtdTheme are loaded in <head>).
  jQuery(function () {
    SphinxRtdTheme.Navigation.enable(true);
  });
</script>

</body>
</html>