mlx/docs/build/html/examples/linear_regression.html
Awni Hannun fbd10a48d4 docs
2025-06-04 01:01:47 +00:00

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Linear Regression &mdash; MLX 0.0.0 documentation</title>
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
<script src="../_static/jquery.js"></script>
<script src="../_static/underscore.js"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Multi-Layer Perceptron" href="mlp.html" />
<link rel="prev" title="Using Streams" href="../using_streams.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../index.html" class="icon icon-home">
MLX
</a>
<div class="version">
0.0.0
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Install</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../install.html">Build and Install</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Usage</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../quick_start.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../using_streams.html">Using Streams</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="current reference internal" href="#">Linear Regression</a></li>
<li class="toctree-l1"><a class="reference internal" href="mlp.html">Multi-Layer Perceptron</a></li>
<li class="toctree-l1"><a class="reference internal" href="llama-inference.html">LLM inference</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Developer Documentation</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../python/array.html">Array</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/ops.html">Operations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/random.html">Random</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/transforms.html">Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/fft.html">FFT</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/nn.html">Neural Networks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/optimizers.html">Optimizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../cpp/ops.html">Operations</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">MLX</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item active">Linear Regression</li>
<li class="wy-breadcrumbs-aside">
<a href="../_sources/examples/linear_regression.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="linear-regression">
<span id="id1"></span><h1>Linear Regression<a class="headerlink" href="#linear-regression" title="Permalink to this heading"></a></h1>
<p>Let's implement a basic linear regression model as a starting point to
learn MLX. First import the core package and set up some problem metadata:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="n">num_features</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">num_examples</span> <span class="o">=</span> <span class="mi">1_000</span>
<span class="n">num_iters</span> <span class="o">=</span> <span class="mi">10_000</span> <span class="c1"># iterations of SGD</span>
<span class="n">lr</span> <span class="o">=</span> <span class="mf">0.01</span> <span class="c1"># learning rate for SGD</span>
</pre></div>
</div>
<p>We'll generate a synthetic dataset by:</p>
<ol class="arabic simple">
<li><p>Sampling the design matrix <code class="docutils literal notranslate"><span class="pre">X</span></code>.</p></li>
<li><p>Sampling a ground truth parameter vector <code class="docutils literal notranslate"><span class="pre">w_star</span></code>.</p></li>
<li><p>Computing the dependent values <code class="docutils literal notranslate"><span class="pre">y</span></code> by adding Gaussian noise to <code class="docutils literal notranslate"><span class="pre">X</span> <span class="pre">&#64;</span> <span class="pre">w_star</span></code>.</p></li>
</ol>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># True parameters</span>
<span class="n">w_star</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">num_features</span><span class="p">,))</span>
<span class="c1"># Input examples (design matrix)</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">num_examples</span><span class="p">,</span> <span class="n">num_features</span><span class="p">))</span>
<span class="c1"># Noisy labels</span>
<span class="n">eps</span> <span class="o">=</span> <span class="mf">1e-2</span> <span class="o">*</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">num_examples</span><span class="p">,))</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">X</span> <span class="o">@</span> <span class="n">w_star</span> <span class="o">+</span> <span class="n">eps</span>
</pre></div>
</div>
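<p>We will use SGD to find the optimal weights. To start, define the squared loss
and use <code class="docutils literal notranslate"><span class="pre">mx.grad</span></code> to get a function that computes the gradient of the loss
with respect to the parameters.</p>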
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">w</span><span class="p">):</span>
    <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">X</span> <span class="o">@</span> <span class="n">w</span> <span class="o">-</span> <span class="n">y</span><span class="p">))</span>
<span class="n">grad_fn</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">)</span>
</pre></div>
</div>
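<p>For this particular loss the gradient also has a closed form,
<code class="docutils literal notranslate"><span class="pre">(1</span> <span class="pre">/</span> <span class="pre">n)</span> <span class="pre">*</span> <span class="pre">X.T</span> <span class="pre">&#64;</span> <span class="pre">(X</span> <span class="pre">&#64;</span> <span class="pre">w</span> <span class="pre">-</span> <span class="pre">y)</span></code>,
which is what the function returned by <code class="docutils literal notranslate"><span class="pre">mx.grad</span></code> should reproduce up to floating-point error.
The following is a minimal pure-Python sketch (it does not use MLX, and the helper names and the tiny
dataset are made up for illustration) that checks the closed form against a central finite-difference estimate:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Pure-Python sketch (no MLX): closed-form gradient of the squared loss,
# checked against central finite differences. All names are hypothetical.

def predict(X, w):
    # Row-by-row dot product, i.e. X @ w
    return [sum(x_ij * w_j for x_ij, w_j in zip(row, w)) for row in X]

def loss(X, y, w):
    # 0.5 * mean of squared residuals, the same form as loss_fn
    residuals = [p - t for p, t in zip(predict(X, w), y)]
    return 0.5 * sum(r * r for r in residuals) / len(y)

def analytic_grad(X, y, w):
    # (1 / n) * X^T (X @ w - y)
    residuals = [p - t for p, t in zip(predict(X, w), y)]
    n = len(y)
    return [sum(row[j] * r for row, r in zip(X, residuals)) / n
            for j in range(len(w))]

def numeric_grad(X, y, w, h=1e-6):
    # Central differences, perturbing one coordinate at a time
    grad = []
    for j in range(len(w)):
        w_plus = list(w)
        w_minus = list(w)
        w_plus[j] += h
        w_minus[j] -= h
        grad.append((loss(X, y, w_plus) - loss(X, y, w_minus)) / (2 * h))
    return grad

# Tiny illustrative dataset: 3 examples, 2 features
X = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.0]]
y = [1.0, 2.0, 0.5]
w = [0.1, -0.2]

max_gap = max(abs(a - b)
              for a, b in zip(analytic_grad(X, y, w), numeric_grad(X, y, w)))
print(max_gap)  # largest disagreement between the two gradients
</pre></div>
</div>
<p>The printed gap should be tiny, since the loss is quadratic in <code class="docutils literal notranslate"><span class="pre">w</span></code> and central differences are then exact up to rounding.</p>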
<p>Start the optimization by initializing the parameters <code class="docutils literal notranslate"><span class="pre">w</span></code> randomly. Then
repeatedly update the parameters for <code class="docutils literal notranslate"><span class="pre">num_iters</span></code> iterations.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">w</span> <span class="o">=</span> <span class="mf">1e-2</span> <span class="o">*</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">num_features</span><span class="p">,))</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_iters</span><span class="p">):</span>
    <span class="n">grad</span> <span class="o">=</span> <span class="n">grad_fn</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
    <span class="n">w</span> <span class="o">=</span> <span class="n">w</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">grad</span>
    <span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
</pre></div>
</div>
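<p>To see the update rule <code class="docutils literal notranslate"><span class="pre">w</span> <span class="pre">=</span> <span class="pre">w</span> <span class="pre">-</span> <span class="pre">lr</span> <span class="pre">*</span> <span class="pre">grad</span></code> in isolation,
here is a rough pure-Python sketch on a tiny noise-free problem. It does not use MLX; the data, the step count,
and the names <code class="docutils literal notranslate"><span class="pre">w_true</span></code> and <code class="docutils literal notranslate"><span class="pre">max_err</span></code> are assumptions chosen for illustration:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Pure-Python sketch (no MLX) of the same update rule on a tiny
# noise-free problem. Names and data are hypothetical.

def grad(X, y, w):
    # Gradient of the squared loss: (1 / n) * X^T (X @ w - y)
    n = len(y)
    residuals = [sum(a * b for a, b in zip(row, w)) - t
                 for row, t in zip(X, y)]
    return [sum(row[j] * r for row, r in zip(X, residuals)) / n
            for j in range(len(w))]

w_true = [2.0, -1.0]
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]
# Labels without noise, so the minimizer is exactly w_true
y = [sum(a * b for a, b in zip(row, w_true)) for row in X]

w = [0.0, 0.0]
lr = 0.1
for _ in range(500):
    g = grad(X, y, w)
    w = [w_j - lr * g_j for w_j, g_j in zip(w, g)]

max_err = max(abs(a - b) for a, b in zip(w, w_true))
print(max_err)  # distance of the learned weights from w_true
</pre></div>
</div>
<p>Because the labels contain no noise here, the iterates should approach <code class="docutils literal notranslate"><span class="pre">w_true</span></code> itself rather than a nearby estimate.</p>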
<p>Finally, compute the loss of the learned parameters and verify that they are
close to the ground truth parameters.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
<span class="n">error_norm</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">w</span> <span class="o">-</span> <span class="n">w_star</span><span class="p">))</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">**</span> <span class="mf">0.5</span>
<span class="nb">print</span><span class="p">(</span>
    <span class="sa">f</span><span class="s2">&quot;Loss </span><span class="si">{</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2">, |w-w*| = </span><span class="si">{</span><span class="n">error_norm</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="c1"># Should print something close to: Loss 0.00005, |w-w*| = 0.00364</span>
</pre></div>
</div>
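<p>As a rough sanity check on the scale of the printed loss: the labels carry Gaussian noise with standard deviation
<code class="docutils literal notranslate"><span class="pre">1e-2</span></code>, so evaluating the loss at <code class="docutils literal notranslate"><span class="pre">w_star</span></code> gives about half the noise variance,
<code class="docutils literal notranslate"><span class="pre">0.5</span> <span class="pre">*</span> <span class="pre">(1e-2)</span> <span class="pre">**</span> <span class="pre">2</span></code>. The fitted weights can land slightly below that value.
The pure-Python sketch below (no MLX; the seed, sample count, and variable names are assumptions) estimates that quantity by simulation:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import random

rng = random.Random(0)  # fixed seed so the sketch is repeatable
sigma = 1e-2            # same noise scale as eps in the example
n = 100_000             # more samples than the example, for a stable estimate

# Simulated label noise and the loss it alone would produce
eps = [sigma * rng.gauss(0.0, 1.0) for _ in range(n)]
noise_floor = 0.5 * sum(e * e for e in eps) / n
print(f"{noise_floor:.5f}")
</pre></div>
</div>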
<p>Complete <a class="reference external" href="https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py">linear regression</a>
and <a class="reference external" href="https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py">logistic regression</a>
examples are available in the MLX GitHub repo.</p>
</section>
</div>
</div>
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
<a href="../using_streams.html" class="btn btn-neutral float-left" title="Using Streams" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="mlp.html" class="btn btn-neutral float-right" title="Multi-Layer Perceptron" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
</div>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, MLX Contributors.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>