mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:39:36 +08:00
rebase
This commit is contained in:
63
docs/build/html/examples/linear_regression.html
vendored
63
docs/build/html/examples/linear_regression.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
|
||||
<title>Linear Regression — MLX 0.21.1 documentation</title>
|
||||
<title>Linear Regression — MLX 0.22.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
|
||||
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
|
||||
</script>
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
-->
|
||||
<noscript>
|
||||
<style>
|
||||
@@ -27,19 +27,19 @@
|
||||
</noscript>
|
||||
|
||||
<!-- Loaded before other Sphinx assets -->
|
||||
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
|
||||
|
||||
<!-- So that users can add custom icons -->
|
||||
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
|
||||
<!-- Pre-loaded scripts that we'll load fully later -->
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
|
||||
|
||||
<script src="../_static/documentation_options.js?v=acb17c73"></script>
|
||||
<script src="../_static/documentation_options.js?v=c952a61b"></script>
|
||||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||||
@@ -48,10 +48,10 @@
|
||||
<link rel="index" title="Index" href="../genindex.html" />
|
||||
<link rel="search" title="Search" href="../search.html" />
|
||||
<link rel="next" title="Multi-Layer Perceptron" href="mlp.html" />
|
||||
<link rel="prev" title="Using Streams" href="../usage/using_streams.html" />
|
||||
<link rel="prev" title="Exporting Functions" href="../usage/export.html" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<meta name="docsearch:language" content="en"/>
|
||||
<meta name="docsearch:version" content="0.21.1" />
|
||||
<meta name="docsearch:version" content="0.22.0" />
|
||||
</head>
|
||||
|
||||
|
||||
@@ -130,8 +130,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -160,6 +160,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/distributed.html">Distributed Communication</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/export.html">Exporting Functions</a></li>
|
||||
</ul>
|
||||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
||||
<ul class="current nav bd-sidenav">
|
||||
@@ -228,6 +229,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
@@ -242,6 +244,13 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
|
||||
@@ -324,6 +333,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
|
||||
@@ -379,6 +389,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
|
||||
@@ -403,6 +415,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
|
||||
@@ -695,6 +708,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
@@ -703,9 +717,14 @@
|
||||
|
||||
|
||||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||||
<div class="sidebar-primary-item">
|
||||
<div id="ethical-ad-placement"
|
||||
class="flat"
|
||||
data-ea-publisher="readthedocs"
|
||||
data-ea-type="readthedocs-sidebar"
|
||||
data-ea-manual="true">
|
||||
</div></div>
|
||||
</div>
|
||||
|
||||
<div id="rtd-footer-container"></div>
|
||||
|
||||
|
||||
</div>
|
||||
@@ -856,7 +875,7 @@
|
||||
<span id="id1"></span><h1>Linear Regression<a class="headerlink" href="#linear-regression" title="Link to this heading">#</a></h1>
|
||||
<p>Let’s implement a basic linear regression model as a starting point to
|
||||
learn MLX. First import the core package and setup some problem metadata:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||||
|
||||
<span class="n">num_features</span> <span class="o">=</span> <span class="mi">100</span>
|
||||
<span class="n">num_examples</span> <span class="o">=</span> <span class="mi">1_000</span>
|
||||
@@ -883,7 +902,7 @@ learn MLX. First import the core package and setup some problem metadata:</p>
|
||||
</div>
|
||||
<p>We will use SGD to find the optimal weights. To start, define the squared loss
|
||||
and get the gradient function of the loss with respect to the parameters.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">w</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">loss_fn</span><span class="p">(</span><span class="n">w</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">X</span> <span class="o">@</span> <span class="n">w</span> <span class="o">-</span> <span class="n">y</span><span class="p">))</span>
|
||||
|
||||
<span class="n">grad_fn</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">)</span>
|
||||
@@ -927,12 +946,12 @@ examples are available in the MLX GitHub repo.</p>
|
||||
|
||||
<div class="prev-next-area">
|
||||
<a class="left-prev"
|
||||
href="../usage/using_streams.html"
|
||||
href="../usage/export.html"
|
||||
title="previous page">
|
||||
<i class="fa-solid fa-angle-left"></i>
|
||||
<div class="prev-next-info">
|
||||
<p class="prev-next-subtitle">previous</p>
|
||||
<p class="prev-next-title">Using Streams</p>
|
||||
<p class="prev-next-title">Exporting Functions</p>
|
||||
</div>
|
||||
</a>
|
||||
<a class="right-next"
|
||||
@@ -994,8 +1013,8 @@ By MLX Contributors
|
||||
</div>
|
||||
|
||||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
|
||||
|
||||
<footer class="bd-footer">
|
||||
</footer>
|
||||
|
91
docs/build/html/examples/llama-inference.html
vendored
91
docs/build/html/examples/llama-inference.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
|
||||
<title>LLM inference — MLX 0.21.1 documentation</title>
|
||||
<title>LLM inference — MLX 0.22.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
|
||||
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
|
||||
</script>
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
-->
|
||||
<noscript>
|
||||
<style>
|
||||
@@ -27,19 +27,19 @@
|
||||
</noscript>
|
||||
|
||||
<!-- Loaded before other Sphinx assets -->
|
||||
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
|
||||
|
||||
<!-- So that users can add custom icons -->
|
||||
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
|
||||
<!-- Pre-loaded scripts that we'll load fully later -->
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
|
||||
|
||||
<script src="../_static/documentation_options.js?v=acb17c73"></script>
|
||||
<script src="../_static/documentation_options.js?v=c952a61b"></script>
|
||||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||||
@@ -51,7 +51,7 @@
|
||||
<link rel="prev" title="Multi-Layer Perceptron" href="mlp.html" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<meta name="docsearch:language" content="en"/>
|
||||
<meta name="docsearch:version" content="0.21.1" />
|
||||
<meta name="docsearch:version" content="0.22.0" />
|
||||
</head>
|
||||
|
||||
|
||||
@@ -130,8 +130,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -160,6 +160,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/distributed.html">Distributed Communication</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/export.html">Exporting Functions</a></li>
|
||||
</ul>
|
||||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
||||
<ul class="current nav bd-sidenav">
|
||||
@@ -228,6 +229,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
@@ -242,6 +244,13 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
|
||||
@@ -324,6 +333,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
|
||||
@@ -379,6 +389,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
|
||||
@@ -403,6 +415,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
|
||||
@@ -695,6 +708,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
@@ -703,9 +717,14 @@
|
||||
|
||||
|
||||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||||
<div class="sidebar-primary-item">
|
||||
<div id="ethical-ad-placement"
|
||||
class="flat"
|
||||
data-ea-publisher="readthedocs"
|
||||
data-ea-type="readthedocs-sidebar"
|
||||
data-ea-manual="true">
|
||||
</div></div>
|
||||
</div>
|
||||
|
||||
<div id="rtd-footer-container"></div>
|
||||
|
||||
|
||||
</div>
|
||||
@@ -890,11 +909,11 @@ key/value cache that will be concatenated with the provided keys and values to
|
||||
support efficient inference.</p>
|
||||
<p>Our implementation uses <a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Linear.html#mlx.nn.Linear" title="mlx.nn.Linear"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Linear</span></code></a> for all the projections and
|
||||
<a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RoPE.html#mlx.nn.RoPE" title="mlx.nn.RoPE"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.RoPE</span></code></a> for the positional encoding.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
|
||||
|
||||
<span class="k">class</span> <span class="nc">LlamaAttention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||
<span class="k">class</span><span class="w"> </span><span class="nc">LlamaAttention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
|
||||
@@ -905,7 +924,7 @@ support efficient inference.</p>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">value_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">queries</span><span class="p">,</span> <span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">queries</span><span class="p">,</span> <span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="n">queries</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_proj</span><span class="p">(</span><span class="n">queries</span><span class="p">)</span>
|
||||
<span class="n">keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_proj</span><span class="p">(</span><span class="n">keys</span><span class="p">)</span>
|
||||
<span class="n">values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_proj</span><span class="p">(</span><span class="n">values</span><span class="p">)</span>
|
||||
@@ -948,8 +967,8 @@ support efficient inference.</p>
|
||||
<p>The other component of the Llama model is the encoder layer which uses RMS
|
||||
normalization <a class="footnote-reference brackets" href="#id5" id="id2" role="doc-noteref"><span class="fn-bracket">[</span>2<span class="fn-bracket">]</span></a> and SwiGLU. <a class="footnote-reference brackets" href="#id6" id="id3" role="doc-noteref"><span class="fn-bracket">[</span>3<span class="fn-bracket">]</span></a> For RMS normalization we will use
|
||||
<a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RMSNorm.html#mlx.nn.RMSNorm" title="mlx.nn.RMSNorm"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.RMSNorm</span></code></a> that is already provided in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.nn</span></code>.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">LlamaEncoderLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">LlamaEncoderLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">attention</span> <span class="o">=</span> <span class="n">LlamaAttention</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">)</span>
|
||||
@@ -961,7 +980,7 @@ normalization <a class="footnote-reference brackets" href="#id5" id="id2" role="
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">linear2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">linear3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">mlp_dims</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="n">y</span><span class="p">,</span> <span class="n">cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="p">,</span> <span class="n">cache</span><span class="p">)</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||||
@@ -981,8 +1000,8 @@ normalization <a class="footnote-reference brackets" href="#id5" id="id2" role="
|
||||
<h3>Full model<a class="headerlink" href="#full-model" title="Link to this heading">#</a></h3>
|
||||
<p>To implement any Llama model we simply have to combine <code class="docutils literal notranslate"><span class="pre">LlamaEncoderLayer</span></code>
|
||||
instances with an <a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Embedding.html#mlx.nn.Embedding" title="mlx.nn.Embedding"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Embedding</span></code></a> to embed the input tokens.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Llama</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Llama</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span>
|
||||
<span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
@@ -994,7 +1013,7 @@ instances with an <a class="reference internal" href="../python/nn/_autosummary/
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RMSNorm</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="o">.</span><span class="n">create_additive_causal_mask</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">embedding</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||||
|
||||
@@ -1015,10 +1034,10 @@ layers but using <code class="docutils literal notranslate"><span class="pre">mo
|
||||
performs no sampling whatsoever. In the rest of this subsection, we will
|
||||
implement the inference function as a python generator that processes the
|
||||
prompt and then autoregressively yields tokens one at a time.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Llama</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Llama</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
<span class="o">...</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">generate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">temp</span><span class="o">=</span><span class="mf">1.0</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">generate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">temp</span><span class="o">=</span><span class="mf">1.0</span><span class="p">):</span>
|
||||
<span class="n">cache</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
|
||||
<span class="c1"># Make an additive causal mask. We will need that to process the prompt.</span>
|
||||
@@ -1102,13 +1121,13 @@ it. In the following code, we randomly initialize a small Llama model, process
|
||||
SentencePiece model that comes with them. We will write a small script to
|
||||
convert the PyTorch weights to MLX compatible ones and write them in a NPZ file
|
||||
that can be loaded directly by MLX.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">argparse</span>
|
||||
<span class="kn">from</span> <span class="nn">itertools</span> <span class="kn">import</span> <span class="n">starmap</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">argparse</span>
|
||||
<span class="kn">from</span><span class="w"> </span><span class="nn">itertools</span><span class="w"> </span><span class="kn">import</span> <span class="n">starmap</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
<span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">map_torch_to_mlx</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">map_torch_to_mlx</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="s2">"tok_embedding"</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
|
||||
<span class="n">key</span> <span class="o">=</span> <span class="s2">"embedding.weight"</span>
|
||||
|
||||
@@ -1157,7 +1176,7 @@ left is to load them from disk and we can finally use the LLM to generate text.
|
||||
We can load numpy format files using the <a class="reference internal" href="../python/_autosummary/mlx.core.load.html#mlx.core.load" title="mlx.core.load"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.core.load()</span></code></a> operation.</p>
|
||||
<p>To create a parameter dictionary from the key/value representation of NPZ files
|
||||
we will use the <a class="reference internal" href="../python/_autosummary/mlx.utils.tree_unflatten.html#mlx.utils.tree_unflatten" title="mlx.utils.tree_unflatten"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.utils.tree_unflatten()</span></code></a> helper method as follows:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">mlx.utils</span> <span class="kn">import</span> <span class="n">tree_unflatten</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">mlx.utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">tree_unflatten</span>
|
||||
|
||||
<span class="n">model</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">tree_unflatten</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">weight_file</span><span class="p">)</span><span class="o">.</span><span class="n">items</span><span class="p">())))</span>
|
||||
</pre></div>
|
||||
@@ -1341,8 +1360,8 @@ By MLX Contributors
|
||||
</div>
|
||||
|
||||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
|
||||
|
||||
<footer class="bd-footer">
|
||||
</footer>
|
||||
|
75
docs/build/html/examples/mlp.html
vendored
75
docs/build/html/examples/mlp.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
|
||||
<title>Multi-Layer Perceptron — MLX 0.21.1 documentation</title>
|
||||
<title>Multi-Layer Perceptron — MLX 0.22.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
|
||||
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
|
||||
</script>
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
<!--
|
||||
this give us a css class that will be invisible only if js is disabled
|
||||
-->
|
||||
<noscript>
|
||||
<style>
|
||||
@@ -27,19 +27,19 @@
|
||||
</noscript>
|
||||
|
||||
<!-- Loaded before other Sphinx assets -->
|
||||
<link href="../_static/styles/theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=26a4bc78f4c0ddb94549" rel="stylesheet" />
|
||||
<link href="../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||||
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=fa44fd50" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
|
||||
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=a3416100" />
|
||||
|
||||
<!-- So that users can add custom icons -->
|
||||
<script src="../_static/scripts/fontawesome.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script src="../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
|
||||
<!-- Pre-loaded scripts that we'll load fully later -->
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
|
||||
|
||||
<script src="../_static/documentation_options.js?v=acb17c73"></script>
|
||||
<script src="../_static/documentation_options.js?v=c952a61b"></script>
|
||||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||||
@@ -51,7 +51,7 @@
|
||||
<link rel="prev" title="Linear Regression" href="linear_regression.html" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<meta name="docsearch:language" content="en"/>
|
||||
<meta name="docsearch:version" content="0.21.1" />
|
||||
<meta name="docsearch:version" content="0.22.0" />
|
||||
</head>
|
||||
|
||||
|
||||
@@ -130,8 +130,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.21.1 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.22.0 documentation - Home"/>
|
||||
<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark pst-js-only" alt="MLX 0.22.0 documentation - Home"/>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -160,6 +160,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/distributed.html">Distributed Communication</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../usage/export.html">Exporting Functions</a></li>
|
||||
</ul>
|
||||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
||||
<ul class="current nav bd-sidenav">
|
||||
@@ -228,6 +229,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
@@ -242,6 +244,13 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
|
||||
</ul>
|
||||
</details></li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
|
||||
@@ -324,6 +333,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
|
||||
@@ -379,6 +389,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
|
||||
@@ -403,6 +415,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
|
||||
@@ -695,6 +708,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
@@ -703,9 +717,14 @@
|
||||
|
||||
|
||||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||||
<div class="sidebar-primary-item">
|
||||
<div id="ethical-ad-placement"
|
||||
class="flat"
|
||||
data-ea-publisher="readthedocs"
|
||||
data-ea-type="readthedocs-sidebar"
|
||||
data-ea-manual="true">
|
||||
</div></div>
|
||||
</div>
|
||||
|
||||
<div id="rtd-footer-container"></div>
|
||||
|
||||
|
||||
</div>
|
||||
@@ -857,11 +876,11 @@
|
||||
<p>In this example we’ll learn to use <code class="docutils literal notranslate"><span class="pre">mlx.nn</span></code> by implementing a simple
|
||||
multi-layer perceptron to classify MNIST.</p>
|
||||
<p>As a first step import the MLX packages we need:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||||
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
<span class="kn">import</span> <span class="nn">mlx.optimizers</span> <span class="k">as</span> <span class="nn">optim</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.optimizers</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">optim</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>The model is defined as the <code class="docutils literal notranslate"><span class="pre">MLP</span></code> class which inherits from
|
||||
@@ -872,8 +891,8 @@ the <a class="reference internal" href="../python/nn.html#module-class"><span cl
|
||||
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> registers parameters.</p></li>
|
||||
<li><p>Define a <code class="docutils literal notranslate"><span class="pre">__call__</span></code> where the computation is implemented.</p></li>
|
||||
</ol>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">:</span> <span class="nb">int</span>
|
||||
<span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
@@ -883,7 +902,7 @@ the <a class="reference internal" href="../python/nn.html#module-class"><span cl
|
||||
<span class="k">for</span> <span class="n">idim</span><span class="p">,</span> <span class="n">odim</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
|
||||
<span class="p">]</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">l</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">](</span><span class="n">x</span><span class="p">)</span>
|
||||
@@ -892,13 +911,13 @@ the <a class="reference internal" href="../python/nn.html#module-class"><span cl
|
||||
<p>We define the loss function which takes the mean of the per-example cross
|
||||
entropy loss. The <code class="docutils literal notranslate"><span class="pre">mlx.nn.losses</span></code> sub-package has implementations of some
|
||||
commonly used loss functions.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">loss_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">))</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We also need a function to compute the accuracy of the model on the validation
|
||||
set:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">eval_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">eval_fn</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">y</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
@@ -913,7 +932,7 @@ we will import as <code class="docutils literal notranslate"><span class="pre">m
|
||||
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">1e-1</span>
|
||||
|
||||
<span class="c1"># Load the data</span>
|
||||
<span class="kn">import</span> <span class="nn">mnist</span>
|
||||
<span class="kn">import</span><span class="w"> </span><span class="nn">mnist</span>
|
||||
<span class="n">train_images</span><span class="p">,</span> <span class="n">train_labels</span><span class="p">,</span> <span class="n">test_images</span><span class="p">,</span> <span class="n">test_labels</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span>
|
||||
<span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">mnist</span><span class="o">.</span><span class="n">mnist</span><span class="p">()</span>
|
||||
<span class="p">)</span>
|
||||
@@ -921,7 +940,7 @@ we will import as <code class="docutils literal notranslate"><span class="pre">m
|
||||
</div>
|
||||
<p>Since we’re using SGD, we need an iterator which shuffles and constructs
|
||||
minibatches of examples in the training set:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">batch_iterate</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">batch_iterate</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="n">perm</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
|
||||
<span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">y</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span>
|
||||
<span class="n">ids</span> <span class="o">=</span> <span class="n">perm</span><span class="p">[</span><span class="n">s</span> <span class="p">:</span> <span class="n">s</span> <span class="o">+</span> <span class="n">batch_size</span><span class="p">]</span>
|
||||
@@ -1046,8 +1065,8 @@ By MLX Contributors
|
||||
</div>
|
||||
|
||||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=26a4bc78f4c0ddb94549"></script>
|
||||
<script defer src="../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
|
||||
<script defer src="../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
|
||||
|
||||
<footer class="bd-footer">
|
||||
</footer>
|
||||
|
Reference in New Issue
Block a user