mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
252 lines
22 KiB
HTML
252 lines
22 KiB
HTML
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" >
|
||
<head>
|
||
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
|
||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
<title>Multi-Layer Perceptron — MLX 0.0.0 documentation</title>
|
||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||
<!--[if lt IE 9]>
|
||
<script src="../_static/js/html5shiv.min.js"></script>
|
||
<![endif]-->
|
||
|
||
<script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
|
||
<script src="../_static/jquery.js"></script>
|
||
<script src="../_static/underscore.js"></script>
|
||
<script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
|
||
<script src="../_static/doctools.js"></script>
|
||
<script src="../_static/js/theme.js"></script>
|
||
<link rel="index" title="Index" href="../genindex.html" />
|
||
<link rel="search" title="Search" href="../search.html" />
|
||
<link rel="next" title="LLM inference" href="llama-inference.html" />
|
||
<link rel="prev" title="Linear Regression" href="linear_regression.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
<div class="wy-grid-for-nav">
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../index.html" class="icon icon-home">
|
||
MLX
|
||
</a>
|
||
<div class="version">
|
||
0.0.0
|
||
</div>
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
|
||
<p class="caption" role="heading"><span class="caption-text">Install</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../install.html">Build and Install</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Usage</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../quick_start.html">Quick Start Guide</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../using_streams.html">Using Streams</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
||
<ul class="current">
|
||
<li class="toctree-l1"><a class="reference internal" href="linear_regression.html">Linear Regression</a></li>
|
||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Multi-Layer Perceptron</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="llama-inference.html">LLM inference</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Developer Documentation</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python/array.html">Array</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python/ops.html">Operations</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python/random.html">Random</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python/transforms.html">Transforms</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python/fft.html">FFT</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python/nn.html">Neural Networks</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python/optimizers.html">Optimizers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../cpp/ops.html">Operations</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../index.html">MLX</a>
|
||
</nav>
|
||
|
||
<div class="wy-nav-content">
|
||
<div class="rst-content">
|
||
<div role="navigation" aria-label="Page navigation">
|
||
<ul class="wy-breadcrumbs">
|
||
<li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
|
||
<li class="breadcrumb-item active">Multi-Layer Perceptron</li>
|
||
<li class="wy-breadcrumbs-aside">
|
||
<a href="../_sources/examples/mlp.rst.txt" rel="nofollow"> View page source</a>
|
||
</li>
|
||
</ul>
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<section id="multi-layer-perceptron">
|
||
<span id="mlp"></span><h1>Multi-Layer Perceptron<a class="headerlink" href="#multi-layer-perceptron" title="Permalink to this heading"></a></h1>
|
||
<p>In this example we’ll learn to use <code class="docutils literal notranslate"><span class="pre">mlx.nn</span></code> by implementing a simple
|
||
multi-layer perceptron to classify MNIST.</p>
|
||
<p>As a first step, import the MLX packages we need:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||
<span class="kn">import</span> <span class="nn">mlx.optimizers</span> <span class="k">as</span> <span class="nn">optim</span>
|
||
|
||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>The model is defined as the <code class="docutils literal notranslate"><span class="pre">MLP</span></code> class which inherits from
|
||
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a>. We follow the standard idiom to make a new module:</p>
|
||
<ol class="arabic simple">
|
||
<li><p>Define an <code class="docutils literal notranslate"><span class="pre">__init__</span></code> where the parameters and/or submodules are set up. See
|
||
the <a class="reference internal" href="../python/nn.html#module-class"><span class="std std-ref">Module class docs</span></a> for more information on how
|
||
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> registers parameters.</p></li>
|
||
<li><p>Define a <code class="docutils literal notranslate"><span class="pre">__call__</span></code> where the computation is implemented.</p></li>
|
||
</ol>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">:</span> <span class="nb">int</span>
|
||
<span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||
<span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">input_dim</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">hidden_dim</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_layers</span> <span class="o">+</span> <span class="p">[</span><span class="n">output_dim</span><span class="p">]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">idim</span><span class="p">,</span> <span class="n">odim</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">idim</span><span class="p">,</span> <span class="n">odim</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||
<span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">l</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">](</span><span class="n">x</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We define the loss function, which takes the mean of the per-example cross
|
||
entropy loss. The <code class="docutils literal notranslate"><span class="pre">mlx.nn.losses</span></code> sub-package has implementations of some
|
||
commonly used loss functions.</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">))</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We also need a function to compute the accuracy of the model on the test
|
||
set:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">eval_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">y</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Next, set up the problem parameters and load the data:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">num_layers</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">hidden_dim</span> <span class="o">=</span> <span class="mi">32</span>
|
||
<span class="n">num_classes</span> <span class="o">=</span> <span class="mi">10</span>
|
||
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">256</span>
|
||
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">10</span>
|
||
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">1e-1</span>
|
||
|
||
<span class="c1"># Load the data</span>
|
||
<span class="kn">import</span> <span class="nn">mnist</span>
|
||
<span class="n">train_images</span><span class="p">,</span> <span class="n">train_labels</span><span class="p">,</span> <span class="n">test_images</span><span class="p">,</span> <span class="n">test_labels</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span>
|
||
<span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">mnist</span><span class="o">.</span><span class="n">mnist</span><span class="p">()</span>
|
||
<span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Since we’re using SGD, we need an iterator which shuffles and constructs
|
||
minibatches of examples in the training set:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">batch_iterate</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||
<span class="n">perm</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
|
||
<span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">y</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span>
|
||
<span class="n">ids</span> <span class="o">=</span> <span class="n">perm</span><span class="p">[</span><span class="n">s</span> <span class="p">:</span> <span class="n">s</span> <span class="o">+</span> <span class="n">batch_size</span><span class="p">]</span>
|
||
<span class="k">yield</span> <span class="n">X</span><span class="p">[</span><span class="n">ids</span><span class="p">],</span> <span class="n">y</span><span class="p">[</span><span class="n">ids</span><span class="p">]</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Finally, we put it all together by instantiating the model and the
|
||
<a class="reference internal" href="../python/_autosummary/mlx.optimizers.SGD.html#mlx.optimizers.SGD" title="mlx.optimizers.SGD"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.optimizers.SGD</span></code></a> optimizer, and running the training loop:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Load the model</span>
|
||
<span class="n">model</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">(</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">train_images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
|
||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
|
||
|
||
<span class="c1"># Get a function which gives the loss and gradient of the</span>
|
||
<span class="c1"># loss with respect to the model's trainable parameters</span>
|
||
<span class="n">loss_and_grad_fn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Instantiate the optimizer</span>
|
||
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">)</span>
|
||
|
||
<span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_epochs</span><span class="p">):</span>
|
||
<span class="k">for</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">batch_iterate</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">train_images</span><span class="p">,</span> <span class="n">train_labels</span><span class="p">):</span>
|
||
<span class="n">loss</span><span class="p">,</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">loss_and_grad_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Update the optimizer state and model parameters</span>
|
||
<span class="c1"># in a single call</span>
|
||
<span class="n">optimizer</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Force a graph evaluation</span>
|
||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state</span><span class="p">)</span>
|
||
|
||
<span class="n">accuracy</span> <span class="o">=</span> <span class="n">eval_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">test_images</span><span class="p">,</span> <span class="n">test_labels</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Epoch </span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">: Test accuracy </span><span class="si">{</span><span class="n">accuracy</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<div class="admonition note">
|
||
<p class="admonition-title">Note</p>
|
||
<p>The <a class="reference internal" href="../python/_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a> function is a convenience function to get
|
||
the value and gradient of a loss with respect to the trainable parameters of a model.
|
||
This should not be confused with <a class="reference internal" href="../python/_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a>.</p>
|
||
</div>
|
||
<p>The model should train to a decent accuracy (about 95%) after just a few passes
|
||
over the training set. The <a class="reference external" href="https://github.com/ml-explore/mlx-examples/tree/main/mlp">full example</a>
|
||
is available in the MLX GitHub repo.</p>
|
||
</section>
|
||
|
||
|
||
</div>
|
||
</div>
|
||
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
|
||
<a href="linear_regression.html" class="btn btn-neutral float-left" title="Linear Regression" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||
<a href="llama-inference.html" class="btn btn-neutral float-right" title="LLM inference" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||
</div>
|
||
|
||
<hr/>
|
||
|
||
<div role="contentinfo">
|
||
<p>© Copyright 2023, MLX Contributors.</p>
|
||
</div>
|
||
|
||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||
|
||
|
||
</footer>
|
||
</div>
|
||
</div>
|
||
</section>
|
||
</div>
|
||
<script>
|
||
jQuery(function () {
|
||
SphinxRtdTheme.Navigation.enable(true);
|
||
});
|
||
</script>
|
||
|
||
</body>
|
||
</html> |