<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Linear Regression &mdash; MLX 0.0.0 documentation</title>
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
<script src="../_static/jquery.js"></script>
<script src="../_static/underscore.js"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Multi-Layer Perceptron" href="mlp.html" />
<link rel="prev" title="Using Streams" href="../using_streams.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../index.html" class="icon icon-home">
MLX
</a>
<div class="version">
0.0.0
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Install</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../install.html">Build and Install</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Usage</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../quick_start.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../using_streams.html">Using Streams</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="current reference internal" href="#">Linear Regression</a></li>
<li class="toctree-l1"><a class="reference internal" href="mlp.html">Multi-Layer Perceptron</a></li>
<li class="toctree-l1"><a class="reference internal" href="llama-inference.html">LLM inference</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Developer Documentation</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../python/array.html">Array</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/ops.html">Operations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/random.html">Random</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/transforms.html">Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/fft.html">FFT</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/nn.html">Neural Networks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/optimizers.html">Optimizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../cpp/ops.html">Operations</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">MLX</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="linear-regression">
<span id="id1"></span><h1>Linear Regression<a class="headerlink" href="#linear-regression" title="Permalink to this heading"></a></h1>
<p>Let's implement a basic linear regression model as a starting point for
learning MLX. First, import the core package and set up some problem metadata:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="n">num_features</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">num_examples</span> <span class="o">=</span> <span class="mi">1_000</span>
<span class="n">num_iters</span> <span class="o">=</span> <span class="mi">10_000</span>  <span class="c1"># iterations of gradient descent</span>
<span class="n">lr</span> <span class="o">=</span> <span class="mf">0.01</span>  <span class="c1"># learning rate (step size)</span>
</pre></div>
</div>
<p>We'll generate a synthetic dataset by:</p>
<ol class="arabic simple">
<li><p>Sampling the design matrix <code class="docutils literal notranslate"><span class="pre">X</span></code>.</p></li>
<li><p>Sampling a ground truth parameter vector <code class="docutils literal notranslate"><span class="pre">w_star</span></code>.</p></li>
<li><p>Computing the dependent values <code class="docutils literal notranslate"><span class="pre">y</span></code> by adding Gaussian noise to <code class="docutils literal notranslate"><span class="pre">X</span> <span class="pre">&#64;</span> <span class="pre">w_star</span></code>.</p></li>
</ol>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># True parameters</span>
<span class="n">w_star</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">num_features</span><span class="p">,))</span>
<span class="c1"># Input examples (design matrix)</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">num_examples</span><span class="p">,</span> <span class="n">num_features</span><span class="p">))</span>
<span class="c1"># Noisy labels</span>
<span class="n">eps</span> <span class="o">=</span> <span class="mf">1e-2</span> <span class="o">*</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">num_examples</span><span class="p">,))</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">X</span> <span class="o">@</span> <span class="n">w_star</span> <span class="o">+</span> <span class="n">eps</span>
</pre></div>
</div>
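<p>To make the shapes explicit, the sketch below mirrors the three steps above in plain Python, using the standard <code class="docutils literal notranslate"><span class="pre">random</span></code> module in place of <code class="docutils literal notranslate"><span class="pre">mx.random</span></code>. It is an illustrative, dependency-free sketch rather than part of the MLX example: the helper <code class="docutils literal notranslate"><span class="pre">make_dataset</span></code>, its defaults, and the smaller problem size are hypothetical.</p>

```python
import random


def make_dataset(num_examples, num_features, noise_scale=1e-2, seed=0):
    """Hypothetical helper: a plain-Python version of the synthetic data above."""
    rng = random.Random(seed)
    # Ground-truth parameter vector w_star, shape (num_features,)
    w_star = [rng.gauss(0.0, 1.0) for _ in range(num_features)]
    # Design matrix X, shape (num_examples, num_features)
    X = [[rng.gauss(0.0, 1.0) for _ in range(num_features)]
         for _ in range(num_examples)]
    # Labels y = X @ w_star + eps, with eps = noise_scale * N(0, 1)
    y = [
        sum(x_j * w_j for x_j, w_j in zip(row, w_star))
        + noise_scale * rng.gauss(0.0, 1.0)
        for row in X
    ]
    return X, y, w_star


X, y, w_star = make_dataset(num_examples=50, num_features=5)
print(len(X), len(X[0]), len(y))  # 50 5 50
```

<p>Each label differs from the noise-free prediction <code class="docutils literal notranslate"><span class="pre">X</span> <span class="pre">&#64;</span> <span class="pre">w_star</span></code> only by the small Gaussian term, which is what makes recovering <code class="docutils literal notranslate"><span class="pre">w_star</span></code> possible.</p>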
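<p>We will use gradient descent to find the optimal weights. Since every update
uses all <code class="docutils literal notranslate"><span class="pre">num_examples</span></code> examples, this is full-batch gradient descent rather than
true stochastic gradient descent (SGD). To start, define the squared loss and use
<code class="docutils literal notranslate"><span class="pre">mx.grad</span></code> to get a function that computes the gradient of the loss with
respect to the parameters.</p>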
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">w</span><span class="p">):</span>
    <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">X</span> <span class="o">@</span> <span class="n">w</span> <span class="o">-</span> <span class="n">y</span><span class="p">))</span>
<span class="n">grad_fn</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">)</span>
</pre></div>
</div>
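<p>For this loss the gradient has a closed form: differentiating 0.5 &middot; mean((Xw &minus; y)<sup>2</sup>) gives (1/n) X<sup>T</sup>(Xw &minus; y), where n is the number of examples. That is the quantity <code class="docutils literal notranslate"><span class="pre">grad_fn(w)</span></code> should return. The plain-Python sketch below (not MLX; the tiny data set and the helper names <code class="docutils literal notranslate"><span class="pre">analytic_grad</span></code> and <code class="docutils literal notranslate"><span class="pre">numeric_grad</span></code> are hypothetical) checks that formula against central finite differences, the same kind of sanity check you could run on <code class="docutils literal notranslate"><span class="pre">grad_fn</span></code>.</p>

```python
def loss(w, X, y):
    # 0.5 * mean of squared residuals, the same form as loss_fn above
    residuals = [sum(xj * wj for xj, wj in zip(row, w)) - yi
                 for row, yi in zip(X, y)]
    return 0.5 * sum(r * r for r in residuals) / len(X)


def analytic_grad(w, X, y):
    # Closed-form gradient: (1/n) * X^T (X w - y)
    n, d = len(X), len(w)
    residuals = [sum(xj * wj for xj, wj in zip(row, w)) - yi
                 for row, yi in zip(X, y)]
    return [sum(X[i][j] * residuals[i] for i in range(n)) / n for j in range(d)]


def numeric_grad(w, X, y, h=1e-6):
    # Central finite differences, one coordinate at a time
    grad = []
    for j in range(len(w)):
        w_plus, w_minus = list(w), list(w)
        w_plus[j] += h
        w_minus[j] -= h
        grad.append((loss(w_plus, X, y) - loss(w_minus, X, y)) / (2 * h))
    return grad


X = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
y = [1.0, 0.0, 2.0]
w = [0.1, -0.3]
print(analytic_grad(w, X, y))
print(numeric_grad(w, X, y))
```

<p>The two printed vectors should agree to several decimal places; a large mismatch would point to a bug in the loss or in the gradient.</p>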
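<p>Start the optimization by initializing the parameters <code class="docutils literal notranslate"><span class="pre">w</span></code> randomly, then
repeatedly update them for <code class="docutils literal notranslate"><span class="pre">num_iters</span></code> iterations. MLX evaluates lazily, so calling
<code class="docutils literal notranslate"><span class="pre">mx.eval(w)</span></code> at the end of each step forces the update to be computed instead of
letting an ever-larger computation graph build up across iterations.</p>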
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">w</span> <span class="o">=</span> <span class="mf">1e-2</span> <span class="o">*</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">num_features</span><span class="p">,))</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_iters</span><span class="p">):</span>
    <span class="n">grad</span> <span class="o">=</span> <span class="n">grad_fn</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
    <span class="n">w</span> <span class="o">=</span> <span class="n">w</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">grad</span>
    <span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
</pre></div>
</div>
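<p>The loop applies the update w &larr; w &minus; lr &middot; &nabla;L(w) once per iteration. As a plain-Python sketch of the same loop (not MLX; the small noise-free problem and the values of <code class="docutils literal notranslate"><span class="pre">lr</span></code> and <code class="docutils literal notranslate"><span class="pre">num_iters</span></code> here are hypothetical, chosen for this tiny example), the update below uses the closed-form gradient (1/n) X<sup>T</sup>(Xw &minus; y) and drives <code class="docutils literal notranslate"><span class="pre">w</span></code> to the true parameters. Plain Python evaluates eagerly, so there is no counterpart to <code class="docutils literal notranslate"><span class="pre">mx.eval(w)</span></code>.</p>

```python
import random

rng = random.Random(0)
num_examples, num_features = 200, 3
lr, num_iters = 0.1, 500  # hypothetical values for this tiny problem

w_star = [rng.gauss(0.0, 1.0) for _ in range(num_features)]
X = [[rng.gauss(0.0, 1.0) for _ in range(num_features)] for _ in range(num_examples)]
y = [sum(xj * wj for xj, wj in zip(row, w_star)) for row in X]  # noise-free labels


def grad(w):
    # Closed-form gradient of 0.5 * mean((X w - y)^2): (1/n) * X^T (X w - y)
    residuals = [sum(xj * wj for xj, wj in zip(row, w)) - yi for row, yi in zip(X, y)]
    return [sum(row[j] * r for row, r in zip(X, residuals)) / num_examples
            for j in range(num_features)]


# Small random initialization, as in the MLX example
w = [1e-2 * rng.gauss(0.0, 1.0) for _ in range(num_features)]
for _ in range(num_iters):
    g = grad(w)
    w = [wj - lr * gj for wj, gj in zip(w, g)]

error_norm = sum((wj - sj) ** 2 for wj, sj in zip(w, w_star)) ** 0.5
print(f"|w-w*| = {error_norm:.2e}")
```

<p>Because the labels here are noise-free, the error shrinks to roughly floating-point precision; with noisy labels, as in the MLX example, it instead levels off at a small value set by the noise.</p>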
<p>Finally, compute the loss of the learned parameters and verify that they are
close to the ground truth parameters.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
<span class="n">error_norm</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">w</span> <span class="o">-</span> <span class="n">w_star</span><span class="p">))</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">**</span> <span class="mf">0.5</span>
<span class="nb">print</span><span class="p">(</span>
    <span class="sa">f</span><span class="s2">&quot;Loss </span><span class="si">{</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2">, |w-w*| = </span><span class="si">{</span><span class="n">error_norm</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="c1"># Should print something close to: Loss 0.00005, |w-w*| = 0.00364</span>
</pre></div>
</div>
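<p>Where do the numbers in that comment come from? A rough check, assuming gradient descent has fully converged to the least-squares solution: the residual y &minus; Xw is then the noise projected onto an (n &minus; d)-dimensional subspace, so the expected loss is &frac12;&sigma;<sup>2</sup>(n &minus; d)/n; and because X has independent standard-normal entries, the expected squared parameter error is &sigma;<sup>2</sup>d/(n &minus; d &minus; 1). Here &sigma; = 10<sup>-2</sup> is the noise scale, n = <code class="docutils literal notranslate"><span class="pre">num_examples</span></code>, and d = <code class="docutils literal notranslate"><span class="pre">num_features</span></code>. The short calculation below gives about 4.5 &times; 10<sup>-5</sup> for the loss and about 0.0033 for |w &minus; w*|, consistent with the printed values; individual runs differ because the data are random.</p>

```python
sigma = 1e-2          # noise scale used for eps above
num_examples = 1_000  # n
num_features = 100    # d

# Expected training loss at the least-squares solution:
# E[0.5 * mean(residual^2)] = 0.5 * sigma^2 * (n - d) / n
expected_loss = 0.5 * sigma**2 * (num_examples - num_features) / num_examples

# Expected squared parameter error for standard-normal X:
# E[|w - w*|^2] = sigma^2 * d / (n - d - 1)
expected_sq_error = sigma**2 * num_features / (num_examples - num_features - 1)

print(f"expected loss ~ {expected_loss:.2e}")
# Square root of the expected squared error: a rough scale, not an exact mean
print(f"rough |w-w*| ~ {expected_sq_error ** 0.5:.5f}")
```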
<p>Complete <a class="reference external" href="https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py">linear regression</a>
and <a class="reference external" href="https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py">logistic regression</a>
examples are available in the MLX GitHub repo.</p>
</section>
</div>
</div>
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
<a href="../using_streams.html" class="btn btn-neutral float-left" title="Using Streams" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="mlp.html" class="btn btn-neutral float-right" title="Multi-Layer Perceptron" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
</div>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, MLX Contributors.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>