mlx/docs/build/html/examples/mlp.html
Awni Hannun fbd10a48d4 docs
2025-06-04 01:01:47 +00:00

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Multi-Layer Perceptron &mdash; MLX 0.0.0 documentation</title>
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
<script src="../_static/jquery.js"></script>
<script src="../_static/underscore.js"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="LLM inference" href="llama-inference.html" />
<link rel="prev" title="Linear Regression" href="linear_regression.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../index.html" class="icon icon-home">
MLX
</a>
<div class="version">
0.0.0
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Install</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../install.html">Build and Install</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Usage</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../quick_start.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../using_streams.html">Using Streams</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="linear_regression.html">Linear Regression</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Multi-Layer Perceptron</a></li>
<li class="toctree-l1"><a class="reference internal" href="llama-inference.html">LLM inference</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Developer Documentation</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../python/array.html">Array</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/ops.html">Operations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/random.html">Random</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/transforms.html">Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/fft.html">FFT</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/nn.html">Neural Networks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/optimizers.html">Optimizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../cpp/ops.html">Operations</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">MLX</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item active">Multi-Layer Perceptron</li>
<li class="wy-breadcrumbs-aside">
<a href="../_sources/examples/mlp.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="multi-layer-perceptron">
<span id="mlp"></span><h1>Multi-Layer Perceptron<a class="headerlink" href="#multi-layer-perceptron" title="Permalink to this heading"></a></h1>
<p>In this example we'll learn to use <code class="docutils literal notranslate"><span class="pre">mlx.nn</span></code> by implementing a simple
multi-layer perceptron to classify MNIST.</p>
<p>As a first step, import the MLX packages we need:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="kn">import</span> <span class="nn">mlx.optimizers</span> <span class="k">as</span> <span class="nn">optim</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
</pre></div>
</div>
<p>The model is defined as the <code class="docutils literal notranslate"><span class="pre">MLP</span></code> class which inherits from
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a>. We follow the standard idiom to make a new module:</p>
<ol class="arabic simple">
<li><p>Define an <code class="docutils literal notranslate"><span class="pre">__init__</span></code> where the parameters and/or submodules are set up. See
the <a class="reference internal" href="../python/nn.html#module-class"><span class="std std-ref">Module class docs</span></a> for more information on how
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> registers parameters.</p></li>
<li><p>Define a <code class="docutils literal notranslate"><span class="pre">__call__</span></code> where the computation is implemented.</p></li>
</ol>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">:</span> <span class="nb">int</span>
    <span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">input_dim</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">hidden_dim</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_layers</span> <span class="o">+</span> <span class="p">[</span><span class="n">output_dim</span><span class="p">]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">idim</span><span class="p">,</span> <span class="n">odim</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">idim</span><span class="p">,</span> <span class="n">odim</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
        <span class="p">]</span>

    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">l</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">](</span><span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
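<p>To see which linear layers the constructor builds, the size bookkeeping can be reproduced in plain Python without MLX. This is only a sketch: the helper name <code class="docutils literal notranslate"><span class="pre">layer_shapes</span></code> is hypothetical, and the input size of 784 assumes flattened 28 by 28 MNIST images.</p>

```python
def layer_shapes(num_layers, input_dim, hidden_dim, output_dim):
    """Return the (input, output) size pair for each linear layer."""
    layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
    # Pair each size with the next one: (s0, s1), (s1, s2), ...
    return list(zip(layer_sizes[:-1], layer_sizes[1:]))


# 784 is an assumed input size (flattened 28 x 28 images).
shapes = layer_shapes(num_layers=2, input_dim=784, hidden_dim=32, output_dim=10)
print(shapes)  # [(784, 32), (32, 32), (32, 10)]
```

<p>Note that <code class="docutils literal notranslate"><span class="pre">num_layers</span></code> counts the hidden layers, so the model holds one more linear layer than that. The ReLU is applied after every layer except the last, which returns raw logits.</p>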
<p>We define the loss function which takes the mean of the per-example cross
entropy loss. The <code class="docutils literal notranslate"><span class="pre">mlx.nn.losses</span></code> sub-package has implementations of some
commonly used loss functions.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">))</span>
</pre></div>
</div>
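<p>For intuition, the quantity being averaged can be written out in plain Python. The sketch below assumes the loss receives unnormalized logits and integer class labels. The function names are hypothetical and this is not the MLX implementation.</p>

```python
import math


def cross_entropy(logits, label):
    """Cross entropy of one example: -log softmax(logits)[label]."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[label]


def mean_cross_entropy(batch_logits, labels):
    """Mean of the per-example losses over a batch."""
    losses = [cross_entropy(row, t) for row, t in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


# Uniform logits over 10 classes give a loss of log(10) for any label.
uniform_loss = mean_cross_entropy([[0.0] * 10, [0.0] * 10], [3, 7])
print(round(uniform_loss, 4))  # 2.3026
```

<p>A loss near log(10), roughly 2.3, is therefore what an untrained 10-class model that guesses uniformly would produce, which makes it a useful sanity check for the first few training steps.</p>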
<p>We also need a function to compute the accuracy of the model on held-out
data:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">eval_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">y</span><span class="p">)</span>
</pre></div>
</div>
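<p>The same computation, spelled out on a toy batch in plain Python, may make the argmax-and-compare step easier to follow. The data and the function name <code class="docutils literal notranslate"><span class="pre">accuracy</span></code> are made up for illustration.</p>

```python
def accuracy(batch_logits, labels):
    """Fraction of examples whose highest-scoring class equals the label."""
    correct = 0
    for row, label in zip(batch_logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax
        correct += int(predicted == label)
    return correct / len(labels)


# Toy data: predictions are classes 1, 0 and 2, so two of three match.
toy_logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9]]
toy_labels = [1, 0, 1]
print(accuracy(toy_logits, toy_labels))
```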
<p>Next, set up the problem parameters and load the data:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">num_layers</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">hidden_dim</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">num_classes</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">256</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">1e-1</span>
<span class="c1"># Load the data</span>
<span class="kn">import</span> <span class="nn">mnist</span>
<span class="n">train_images</span><span class="p">,</span> <span class="n">train_labels</span><span class="p">,</span> <span class="n">test_images</span><span class="p">,</span> <span class="n">test_labels</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span>
<span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">mnist</span><span class="o">.</span><span class="n">mnist</span><span class="p">()</span>
<span class="p">)</span>
</pre></div>
</div>
<p>Since we're using SGD, we need an iterator which shuffles and constructs
minibatches of examples in the training set:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">batch_iterate</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="n">perm</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
    <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">y</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span>
        <span class="n">ids</span> <span class="o">=</span> <span class="n">perm</span><span class="p">[</span><span class="n">s</span> <span class="p">:</span> <span class="n">s</span> <span class="o">+</span> <span class="n">batch_size</span><span class="p">]</span>
        <span class="k">yield</span> <span class="n">X</span><span class="p">[</span><span class="n">ids</span><span class="p">],</span> <span class="n">y</span><span class="p">[</span><span class="n">ids</span><span class="p">]</span>
</pre></div>
</div>
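<p>The slicing pattern used by the iterator can be checked in isolation with the standard library. The sketch below uses a hypothetical helper, <code class="docutils literal notranslate"><span class="pre">batch_indices</span></code>, and a fixed seed purely to keep the demo repeatable. It shows that every index is visited exactly once per pass and that the final batch may be smaller than the rest.</p>

```python
import random


def batch_indices(num_examples, batch_size, seed=0):
    """Yield shuffled index lists that together cover every example once."""
    perm = list(range(num_examples))
    random.Random(seed).shuffle(perm)  # fixed seed only for a repeatable demo
    for start in range(0, num_examples, batch_size):
        # Slicing past the end is safe; the last batch is simply shorter.
        yield perm[start : start + batch_size]


batches = list(batch_indices(num_examples=10, batch_size=4))
print([len(b) for b in batches])  # [4, 4, 2]
```

<p>Because a fresh permutation is drawn on each call, calling the iterator once per epoch gives a different batch composition every epoch.</p>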
<p>Finally, we put it all together by instantiating the model, the
<a class="reference internal" href="../python/_autosummary/mlx.optimizers.SGD.html#mlx.optimizers.SGD" title="mlx.optimizers.SGD"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.optimizers.SGD</span></code></a> optimizer, and running the training loop:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Load the model</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">(</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">train_images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
<span class="c1"># Get a function which gives the loss and gradient of the</span>
<span class="c1"># loss with respect to the model&#39;s trainable parameters</span>
<span class="n">loss_and_grad_fn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">)</span>
<span class="c1"># Instantiate the optimizer</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">)</span>
<span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_epochs</span><span class="p">):</span>
    <span class="k">for</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">batch_iterate</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">train_images</span><span class="p">,</span> <span class="n">train_labels</span><span class="p">):</span>
        <span class="n">loss</span><span class="p">,</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">loss_and_grad_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>

        <span class="c1"># Update the optimizer state and model parameters</span>
        <span class="c1"># in a single call</span>
        <span class="n">optimizer</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span>

        <span class="c1"># Force a graph evaluation</span>
        <span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state</span><span class="p">)</span>

    <span class="n">accuracy</span> <span class="o">=</span> <span class="n">eval_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">test_images</span><span class="p">,</span> <span class="n">test_labels</span><span class="p">)</span>
    <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Epoch </span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">: Test accuracy </span><span class="si">{</span><span class="n">accuracy</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</pre></div>
</div>
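<p>As a rough illustration of what one optimizer step does, here is plain SGD without momentum applied to a one-parameter toy problem in pure Python. The helper <code class="docutils literal notranslate"><span class="pre">sgd_step</span></code> and the quadratic objective are assumptions made for this sketch. The MLX optimizer works on the model's parameter tree and may offer options not shown here.</p>

```python
def sgd_step(params, grads, learning_rate):
    """Plain SGD: move each parameter against its gradient."""
    return [p - learning_rate * g for p, g in zip(params, grads)]


# Toy objective: f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w = [0.0]
for _ in range(100):
    grad = [2.0 * (w[0] - 3.0)]
    w = sgd_step(w, grad, learning_rate=0.1)
print(round(w[0], 4))  # 3.0
```

<p>In the training loop the gradients come from the minibatch loss rather than a closed-form expression, but the update has the same shape: each parameter moves a small step in the direction that reduces the loss.</p>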
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>The <a class="reference internal" href="../python/_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a> function is a convenience function to get
the gradient of a loss with respect to the trainable parameters of a model.
This should not be confused with <a class="reference internal" href="../python/_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a>, which differentiates a function with respect to its arguments rather than the trainable parameters of a model.</p>
</div>
<p>The model should train to a decent accuracy (about 95%) after just a few passes
over the training set. The <a class="reference external" href="https://github.com/ml-explore/mlx-examples/tree/main/mlp">full example</a>
is available in the MLX examples GitHub repo.</p>
</section>
</div>
</div>
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
<a href="linear_regression.html" class="btn btn-neutral float-left" title="Linear Regression" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="llama-inference.html" class="btn btn-neutral float-right" title="LLM inference" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
</div>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, MLX Contributors.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>