Awni Hannun
2023-12-06 08:13:20 -08:00
committed by CircleCI Docs
parent 901a4ba68a
commit e4d33acace
209 changed files with 664 additions and 621 deletions

View File

@@ -9,7 +9,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
-<title>Linear Regression &#8212; MLX 0.0.0 documentation</title>
+<title>Linear Regression &#8212; MLX 0.0.3 documentation</title>
@@ -134,8 +134,8 @@
-<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.0 documentation - Home"/>
-<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.0 documentation - Home"/>`);</script>
+<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.3 documentation - Home"/>
+<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.3 documentation - Home"/>`);</script>
</a></div>

View File

@@ -9,7 +9,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
-<title>LLM inference &#8212; MLX 0.0.0 documentation</title>
+<title>LLM inference &#8212; MLX 0.0.3 documentation</title>
@@ -134,8 +134,8 @@
-<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.0 documentation - Home"/>
-<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.0 documentation - Home"/>`);</script>
+<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.3 documentation - Home"/>
+<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.3 documentation - Home"/>`);</script>
</a></div>
@@ -571,7 +571,7 @@ module to concisely define the model architecture.</p>
<section id="attention-layer">
<h3>Attention layer<a class="headerlink" href="#attention-layer" title="Permalink to this heading">#</a></h3>
<p>We will start with the Llama attention layer which notably uses the RoPE
-positional encoding. <a class="footnote-reference brackets" href="#id5" id="id1" role="doc-noteref"><span class="fn-bracket">[</span>1<span class="fn-bracket">]</span></a> In addition, our attention layer will optionally use a
+positional encoding. <a class="footnote-reference brackets" href="#id4" id="id1" role="doc-noteref"><span class="fn-bracket">[</span>1<span class="fn-bracket">]</span></a> In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.</p>
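As a rough illustration of the caching pattern described above, here is a minimal sketch (not the example's exact code): a single head, hypothetical `*_proj` layer names, and no RoPE or masking for brevity, assuming current `mlx.core`/`mlx.nn` APIs.

```python
import mlx.core as mx
import mlx.nn as nn

class CachedAttention(nn.Module):
    """Single-head attention with an optional key/value cache (sketch)."""

    def __init__(self, dims: int):
        super().__init__()
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def __call__(self, x, cache=None):
        queries = self.query_proj(x)
        keys = self.key_proj(x)
        values = self.value_proj(x)

        # Prepend the cached keys/values so earlier tokens are not
        # re-processed on every decoding step.
        if cache is not None:
            key_cache, value_cache = cache
            keys = mx.concatenate([key_cache, keys], axis=1)
            values = mx.concatenate([value_cache, values], axis=1)

        scale = queries.shape[-1] ** -0.5
        scores = mx.softmax((queries * scale) @ keys.transpose(0, 2, 1), axis=-1)

        # Return the updated cache so the caller can pass it back in.
        return self.out_proj(scores @ values), (keys, values)
```

On each decoding step the caller feeds the returned `(keys, values)` pair back in, so only the new token's projections are computed.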
<p>Our implementation uses <a class="reference internal" href="../python/_autosummary/mlx.nn.Linear.html#mlx.nn.Linear" title="mlx.nn.Linear"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Linear</span></code></a> for all the projections and
@@ -632,7 +632,7 @@ support efficient inference.</p>
<section id="encoder-layer">
<h3>Encoder layer<a class="headerlink" href="#encoder-layer" title="Permalink to this heading">#</a></h3>
<p>The other component of the Llama model is the encoder layer which uses RMS
-normalization <a class="footnote-reference brackets" href="#id6" id="id2" role="doc-noteref"><span class="fn-bracket">[</span>2<span class="fn-bracket">]</span></a> and SwiGLU. <a class="footnote-reference brackets" href="#id7" id="id3" role="doc-noteref"><span class="fn-bracket">[</span>3<span class="fn-bracket">]</span></a> For RMS normalization we will use
+normalization <a class="footnote-reference brackets" href="#id5" id="id2" role="doc-noteref"><span class="fn-bracket">[</span>2<span class="fn-bracket">]</span></a> and SwiGLU. <a class="footnote-reference brackets" href="#id6" id="id3" role="doc-noteref"><span class="fn-bracket">[</span>3<span class="fn-bracket">]</span></a> For RMS normalization we will use
<a class="reference internal" href="../python/_autosummary/mlx.nn.RMSNorm.html#mlx.nn.RMSNorm" title="mlx.nn.RMSNorm"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.RMSNorm</span></code></a> that is already provided in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.nn</span></code>.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">LlamaEncoderLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
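The hunk above cuts off the class body. For the SwiGLU feed-forward mentioned in the text, here is a hedged sketch of the standard formulation (the example itself may factor this differently; `nn.silu` and the `*_proj` names are assumptions):

```python
import mlx.nn as nn

class SwiGLU(nn.Module):
    """Gated feed-forward block: silu(x W_g) * (x W_u), projected back down."""

    def __init__(self, dims: int, mlp_dims: int):
        super().__init__()
        self.gate_proj = nn.Linear(dims, mlp_dims, bias=False)
        self.up_proj = nn.Linear(dims, mlp_dims, bias=False)
        self.down_proj = nn.Linear(mlp_dims, dims, bias=False)

    def __call__(self, x):
        # The SiLU-activated "gate" multiplies the linear "up" branch elementwise.
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
```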
@@ -856,7 +856,7 @@ like <code class="docutils literal notranslate"><span class="pre">layers.2.atten
<p>which can then be used to update the model. Note that the method above incurs
several unnecessary copies from disk to numpy and then from numpy to MLX. It
will be replaced in the future with direct loading to MLX.</p>
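Concretely, that round trip might look like the following sketch, assuming a flat PyTorch state dict on disk, a `model` already constructed as above, and `tree_unflatten` from `mlx.utils` (the key name in the comment is illustrative):

```python
import torch
import mlx.core as mx
from mlx.utils import tree_unflatten

# Load the PyTorch checkpoint and convert each tensor: disk -> torch -> numpy -> MLX.
state = torch.load("weights.pth", map_location="cpu")
weights = {k: mx.array(v.numpy()) for k, v in state.items()}

# Dotted keys like "layers.2.attention.query_proj.weight" are unflattened into
# the nested structure that Module.update() expects. `model` is the Llama
# module built earlier in the example (assumed here).
model.update(tree_unflatten(list(weights.items())))
```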
-<p>You can download the full example code in <a class="reference external" href="code">mlx-examples</a>. Assuming the
+<p>You can download the full example code in <a class="reference external" href="https://github.com/ml-explore/mlx-examples/tree/main/llama">mlx-examples</a>. Assuming the
existence of <code class="docutils literal notranslate"><span class="pre">weights.pth</span></code> and <code class="docutils literal notranslate"><span class="pre">tokenizer.model</span></code> in the current working
directory, we can play around with our inference script as follows (the timings
are representative of an M1 Ultra and the 7B parameter Llama model):</p>
@@ -899,20 +899,20 @@ take<span class="w"> </span>his<span class="w"> </span>eyes<span class="w"> </sp
<h2>Scripts<a class="headerlink" href="#scripts" title="Permalink to this heading">#</a></h2>
<div class="admonition-download-the-code admonition">
<p class="admonition-title">Download the code</p>
-<p>The full example code is available in <a class="reference external" href="code">mlx-examples</a>.</p>
+<p>The full example code is available in <a class="reference external" href="https://github.com/ml-explore/mlx-examples/tree/main/llama">mlx-examples</a>.</p>
</div>
-<aside class="footnote brackets" id="id5" role="note">
+<aside class="footnote brackets" id="id4" role="note">
<span class="label"><span class="fn-bracket">[</span><a role="doc-backlink" href="#id1">1</a><span class="fn-bracket">]</span></span>
<p>Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
Roformer: Enhanced transformer with rotary position embedding. arXiv
preprint arXiv:2104.09864.</p>
</aside>
-<aside class="footnote brackets" id="id6" role="note">
+<aside class="footnote brackets" id="id5" role="note">
<span class="label"><span class="fn-bracket">[</span><a role="doc-backlink" href="#id2">2</a><span class="fn-bracket">]</span></span>
<p>Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization.
Advances in Neural Information Processing Systems, 32.</p>
</aside>
-<aside class="footnote brackets" id="id7" role="note">
+<aside class="footnote brackets" id="id6" role="note">
<span class="label"><span class="fn-bracket">[</span><a role="doc-backlink" href="#id3">3</a><span class="fn-bracket">]</span></span>
<p>Shazeer, N., 2020. Glu variants improve transformer. arXiv preprint
arXiv:2002.05202.</p>

View File

@@ -9,7 +9,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
-<title>Multi-Layer Perceptron &#8212; MLX 0.0.0 documentation</title>
+<title>Multi-Layer Perceptron &#8212; MLX 0.0.3 documentation</title>
@@ -134,8 +134,8 @@
-<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.0 documentation - Home"/>
-<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.0 documentation - Home"/>`);</script>
+<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.3 documentation - Home"/>
+<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.3 documentation - Home"/>`);</script>
</a></div>
@@ -647,7 +647,7 @@ the gradient of a loss with respect to the trainable parameters of a model.
This should not be confused with <a class="reference internal" href="../python/_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a>.</p>
</div>
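To make the distinction concrete, here is a minimal sketch, assuming a toy linear model and `mlx.nn.losses.mse_loss`: `mlx.nn.value_and_grad` differentiates with respect to the model's trainable parameters, while `mlx.core.value_and_grad` differentiates with respect to the function's array arguments.

```python
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 1)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))

# mlx.nn.value_and_grad: gradients w.r.t. the model's trainable parameters.
loss_and_grad = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad(model, x, y)  # grads mirrors model.parameters()

# mlx.core.value_and_grad: gradients w.r.t. the function's array arguments.
value, dfdx = mx.value_and_grad(lambda x: mx.sum(x ** 2))(x)
```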
<p>The model should train to a decent accuracy (about 95%) after just a few passes
-over the training set. The <a class="reference external" href="https://github.com/ml-explore/mlx-examples/tree/main/mlp">full example</a>
+over the training set. The <a class="reference external" href="https://github.com/ml-explore/mlx-examples/tree/main/mnist">full example</a>
is available in the MLX GitHub repo.</p>
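For reference, the core of such a training loop in MLX is only a few lines. A hedged sketch, assuming `model` is the MLP from the example and `batches` yields `(X, y)` pairs of input arrays and integer labels:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

def loss_fn(model, X, y):
    # cross_entropy returns per-example losses; reduce to a scalar.
    return mx.mean(nn.losses.cross_entropy(model(X), y))

optimizer = optim.SGD(learning_rate=1e-1)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

for X, y in batches:
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)                # apply SGD to the parameters
    mx.eval(model.parameters(), optimizer.state)  # force the lazy computation
```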
</section>