docs/build/html/usage/distributed.html
@@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Distributed Communication — MLX 0.23.1 documentation</title>
<title>Distributed Communication — MLX 0.23.2 documentation</title>
@@ -874,11 +889,26 @@
</div>
<nav aria-label="Page">
<ul class="visible nav section-nav flex-column">
<li><a href="#getting-started">Getting Started</a></li>
<li><a href="#installing-mpi">Installing MPI</a></li>
<li><a href="#setting-up-remote-hosts">Setting up Remote Hosts</a></li>
<li><a href="#training-example">Training Example</a></li>
<li><a href="#tuning-all-reduce">Tuning All Reduce</a></li>
<li><a href="#getting-started">Getting Started</a><ul class="visible nav section-nav flex-column">
<li><a href="#running-distributed-programs">Running Distributed Programs</a></li>
<li><a href="#selecting-backend">Selecting Backend</a></li>
</ul>
</li>
<li><a href="#training-example">Training Example</a><ul class="visible nav section-nav flex-column">
<li><a href="#utilizing-nn-average-gradients">Utilizing <code>nn.average_gradients</code></a></li>
</ul>
</li>
<li><a href="#getting-started-with-mpi">Getting Started with MPI</a><ul class="visible nav section-nav flex-column">
<li><a href="#installing-mpi">Installing MPI</a></li>
<li><a href="#setting-up-remote-hosts">Setting up Remote Hosts</a></li>
<li><a href="#tuning-mpi-all-reduce">Tuning MPI All Reduce</a></li>
</ul>
</li>
<li><a href="#getting-started-with-ring">Getting Started with Ring</a><ul class="visible nav section-nav flex-column">
<li><a href="#defining-a-ring">Defining a Ring</a></li>
<li><a href="#thunderbolt-ring">Thunderbolt Ring</a></li>
</ul>
</li>
</ul>
</nav>
</div>
@@ -892,20 +922,26 @@
<section id="distributed-communication">
<h1>Distributed Communication</h1>
<p>MLX utilizes <a href="https://en.wikipedia.org/wiki/Message_Passing_Interface">MPI</a> to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the <a href="../python/distributed.html#distributed">API docs</a>.</p>
<p>MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:</p>
<ul class="simple">
<li><p><a href="https://en.wikipedia.org/wiki/Message_Passing_Interface">MPI</a>, a
full-featured and mature distributed communications library</p></li>
<li><p>A <strong>ring</strong> backend of our own that uses native TCP sockets and should be
faster for Thunderbolt connections.</p></li>
</ul>
<p>The list of all currently supported operations and their documentation can be
seen in the <a href="../python/distributed.html#distributed">API docs</a>.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>A lot of operations may not be supported or not as fast as they should be.
<p>Some operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.</p>
</div>
<section id="getting-started">
<h2>Getting Started</h2>
<p>MLX already comes with the ability to “talk” to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:</p>
<p>A distributed program in MLX is as simple as:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.core as mx

world = mx.distributed.init()
@@ -914,65 +950,71 @@ machine. The minimal distributed program in MLX is as simple as:</p>
</pre></div>
</div>
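<p>For reference, a complete, runnable sketch of this minimal program (its last lines fall outside the hunk above; the body below assumes only the behavior described in the surrounding text and is consistent with the outputs shown later) is:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.core as mx

# Initialize the distributed group (a singleton group when launched normally)
world = mx.distributed.init()

# Element-wise sum of mx.ones(10) contributed by every process in the group
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
</pre></div>
</div>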
<p>The program above sums the array <code>mx.ones(10)</code> across all
distributed processes. If simply run with <code>python</code>, however, only one
process is launched and no distributed communication takes place.</p>
<p>To launch the program in distributed mode we need to use <code>mpirun</code> or
<code>mpiexec</code> depending on the MPI installation. The simplest possible way is the
following:</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre>$ mpirun -np 2 python test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
distributed processes. However, when this script is run with <code>python</code> only
one process is launched and no distributed communication takes place. Namely,
all operations in <code>mx.distributed</code> are noops when the distributed group has a
size of one. This property allows us to avoid code that checks if we are in a
distributed setting similar to the one below:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.core as mx

x = ...
world = mx.distributed.init()
# No need for the check; we can simply do x = mx.distributed.all_sum(x)
if world.size() > 1:
    x = mx.distributed.all_sum(x)
</pre></div>
</div>
<p>The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with <code>mpirun -np 4 ...</code> would
print 4 etc.</p>
<section id="running-distributed-programs">
<h3>Running Distributed Programs</h3>
<p>MLX provides <code>mlx.launch</code>, a helper script to launch distributed programs.
Continuing with our initial example we can run it on localhost with 4 processes using</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre>$ mlx.launch -n 4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
</pre></div>
</div>
<p>We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh)</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre>$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
</pre></div>
</div>
<p>Consult the dedicated <a href="launching_distributed.html">usage guide</a> for more
information on using <code>mlx.launch</code>.</p>
</section>
<section id="installing-mpi">
<h2>Installing MPI</h2>
<p>MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using <code>openmpi</code> installed
with the Anaconda package manager as follows:</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre>$ conda install conda-forge::openmpi
</pre></div>
</div>
<p>Installing with Homebrew may require specifying the location of <code>libmpi.dyld</code>
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the <code>DYLD_LIBRARY_PATH</code> environment variable to <code>mpirun</code>.</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre>$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
</pre></div>
</div>
</section>
<section id="setting-up-remote-hosts">
<h2>Setting up Remote Hosts</h2>
<p>MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:</p>
<ul class="simple">
<li><p><code>ssh hostname</code> works from all machines to all machines without asking for
password or host confirmation</p></li>
<li><p><code>mpirun</code> is accessible on all machines. You can call <code>mpirun</code> using its
full path to force all machines to use a specific path.</p></li>
<li><p>Ensure that the <code>hostname</code> used by MPI is the one that you have configured
in the <code>.ssh/config</code> files on all machines.</p></li>
</ul>
<section id="selecting-backend">
|
||||
<h3>Selecting Backend<a class="headerlink" href="#selecting-backend" title="Link to this heading">#</a></h3>
|
||||
<p>You can select the backend you want to use when calling <a class="reference internal" href="../python/_autosummary/mlx.core.distributed.init.html#mlx.core.distributed.init" title="mlx.core.distributed.init"><code class="xref py py-func docutils literal notranslate"><span class="pre">init()</span></code></a> by passing
|
||||
one of <code class="docutils literal notranslate"><span class="pre">{'any',</span> <span class="pre">'ring',</span> <span class="pre">'mpi'}</span></code>. When passing <code class="docutils literal notranslate"><span class="pre">any</span></code>, MLX will try to
|
||||
initialize the <code class="docutils literal notranslate"><span class="pre">ring</span></code> backend and if it fails the <code class="docutils literal notranslate"><span class="pre">mpi</span></code> backend. If they
|
||||
both fail then a singleton group is created.</p>
|
||||
<div class="admonition note">
|
||||
<p class="admonition-title">Note</p>
|
||||
<p>For an example hostname <code class="docutils literal notranslate"><span class="pre">foo.bar.com</span></code> MPI can use only <code class="docutils literal notranslate"><span class="pre">foo</span></code> as
|
||||
the hostname passed to ssh if the current hostname matches <code class="docutils literal notranslate"><span class="pre">*.bar.com</span></code>.</p>
|
||||
<p>After a distributed backend is successfully initialized <a class="reference internal" href="../python/_autosummary/mlx.core.distributed.init.html#mlx.core.distributed.init" title="mlx.core.distributed.init"><code class="xref py py-func docutils literal notranslate"><span class="pre">init()</span></code></a> will
|
||||
return <strong>the same backend</strong> if called without arguments or with backend set to
|
||||
<code class="docutils literal notranslate"><span class="pre">any</span></code>.</p>
|
||||
</div>
|
||||
<p>An easy way to pass the host names to MPI is using a host file. A host file
|
||||
looks like the following, where <code class="docutils literal notranslate"><span class="pre">host1</span></code> and <code class="docutils literal notranslate"><span class="pre">host2</span></code> should be the fully
|
||||
qualified domain names or IPs for these hosts.</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">host1</span> <span class="n">slots</span><span class="o">=</span><span class="mi">1</span>
|
||||
<span class="n">host2</span> <span class="n">slots</span><span class="o">=</span><span class="mi">1</span>
|
||||
<p>The following examples aim to clarify the backend initialization logic in MLX:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Case 1: Initialize MPI regardless if it was possible to initialize the ring backend</span>
|
||||
<span class="n">world</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">backend</span><span class="o">=</span><span class="s2">"mpi"</span><span class="p">)</span>
|
||||
<span class="n">world2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">()</span> <span class="c1"># subsequent calls return the MPI backend!</span>
|
||||
|
||||
<span class="c1"># Case 2: Initialize any backend</span>
|
||||
<span class="n">world</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">backend</span><span class="o">=</span><span class="s2">"any"</span><span class="p">)</span> <span class="c1"># equivalent to no arguments</span>
|
||||
<span class="n">world2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">()</span> <span class="c1"># same as above</span>
|
||||
|
||||
<span class="c1"># Case 3: Initialize both backends at the same time</span>
|
||||
<span class="n">world_mpi</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">backend</span><span class="o">=</span><span class="s2">"mpi"</span><span class="p">)</span>
|
||||
<span class="n">world_ring</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">backend</span><span class="o">=</span><span class="s2">"ring"</span><span class="p">)</span>
|
||||
<span class="n">world_any</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init</span><span class="p">()</span> <span class="c1"># same as MPI because it was initialized first!</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>When using MLX, it is very likely that you want to use 1 slot per host, ie one
|
||||
process per host. The hostfile also needs to contain the current
|
||||
host if you want to run on the local host. Passing the host file to
|
||||
<code class="docutils literal notranslate"><span class="pre">mpirun</span></code> is simply done using the <code class="docutils literal notranslate"><span class="pre">--hostfile</span></code> command line argument.</p>
|
||||
</section>
|
||||
</section>
<section id="training-example">
<h2>Training Example</h2>
@@ -1022,17 +1064,158 @@ everything else remaining the same.</p>
    return loss
</pre></div>
</div>
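<p>For context, a minimal sketch of such a training step (assuming a <code>loss_grad_fn</code>, <code>model</code>, and <code>optimizer</code> as in this example), which averages every gradient across workers with one <code>all_sum</code> per tensor, could look like:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.core as mx
from mlx.utils import tree_map

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    world = mx.distributed.init()
    # One communication per gradient tensor: sum across ranks, then divide
    grads = tree_map(lambda g: mx.distributed.all_sum(g) / world.size(), grads)
    optimizer.update(model, grads)
    return loss
</pre></div>
</div>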
<section id="utilizing-nn-average-gradients">
<h3>Utilizing <code>nn.average_gradients</code></h3>
<p>Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.</p>
<p>This is the purpose of <a href="../python/_autosummary/mlx.nn.average_gradients.html#mlx.nn.average_gradients"><code>mlx.nn.average_gradients()</code></a>. The final code looks
almost identical to the example above:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = mlx.nn.average_gradients(grads)  # <---- This line was added
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())
</pre></div>
</div>
</section>
</section>
<section id="getting-started-with-mpi">
<h2>Getting Started with MPI</h2>
<p>MLX already comes with the ability to “talk” to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
<code>mpirun</code> as expected. However, in the following examples we will be using
<code>mlx.launch --backend mpi</code> which takes care of some nuisances such as setting
absolute paths for the <code>mpirun</code> executable and the <code>libmpi.dyld</code> shared
library.</p>
<p>The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre>$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
</pre></div>
</div>
<p>The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with <code>mlx.launch -n 4 ...</code> would
print 4 etc.</p>
<section id="installing-mpi">
<h3>Installing MPI</h3>
<p>MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using <code>openmpi</code> installed
with the Anaconda package manager as follows:</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre>$ conda install conda-forge::openmpi
</pre></div>
</div>
<p>Installing with Homebrew may require specifying the location of <code>libmpi.dyld</code>
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the <code>DYLD_LIBRARY_PATH</code> environment variable to <code>mpirun</code> and it is
done automatically by <code>mlx.launch</code>.</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre>$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ # or simply
$ mlx.launch -n 2 test.py
</pre></div>
</div>
</section>
<section id="setting-up-remote-hosts">
<h3>Setting up Remote Hosts</h3>
<p>MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:</p>
<ul class="simple">
<li><p><code>ssh hostname</code> works from all machines to all machines without asking for
password or host confirmation (see the connectivity-check sketch after this list)</p></li>
<li><p><code>mpirun</code> is accessible on all machines.</p></li>
<li><p>Ensure that the <code>hostname</code> used by MPI is the one that you have configured
in the <code>.ssh/config</code> files on all machines.</p></li>
</ul>
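<p>A small, illustrative way to check the first item from the current machine (a sketch for illustration, not part of MLX; the hostnames are hypothetical) is:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import subprocess

# BatchMode=yes makes ssh fail instead of prompting for a password,
# which is exactly the non-interactive behavior MPI needs.
hosts = ["host1", "host2", "host3", "host4"]  # hypothetical hostnames
for h in hosts:
    r = subprocess.run(
        ["ssh", "-o", "BatchMode=yes", h, "hostname"],
        capture_output=True, text=True, timeout=10,
    )
    print(h, "ok" if r.returncode == 0 else "failed: " + r.stderr.strip())
</pre></div>
</div>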
</section>
<section id="tuning-mpi-all-reduce">
<h3>Tuning MPI All Reduce</h3>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>For faster all reduce consider using the ring backend either with Thunderbolt
connections or over Ethernet.</p>
</div>
<p>Configure MPI to use N tcp connections between each host to improve bandwidth
by passing <code>--mca btl_tcp_links N</code>.</p>
<p>Force MPI to use the most performant network interface by setting <code>--mca
btl_tcp_if_include <iface></code> where <code><iface></code> should be the interface you want
to use.</p>
</section>
</section>
<section id="getting-started-with-ring">
<h2>Getting Started with Ring</h2>
<p>The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
and so on and so forth. As a result <a href="../python/_autosummary/mlx.core.distributed.send.html#mlx.core.distributed.send"><code>send()</code></a> and <a href="../python/_autosummary/mlx.core.distributed.recv.html#mlx.core.distributed.recv"><code>recv()</code></a> with
arbitrary sender and receiver are not supported in the ring backend.</p>
<section id="defining-a-ring">
<h3>Defining a Ring</h3>
<p>The easiest way to define and use a ring is via a JSON hostfile and the
<code>mlx.launch</code> <a href="launching_distributed.html">helper script</a>. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.</p>
<p>For example the hostfile below defines a 4 node ring. <code>hostname1</code> will be
rank 0, <code>hostname2</code> rank 1 etc.</p>
<div class="highlight-json notranslate"><div class="highlight"><pre>[
    {"ssh": "hostname1", "ips": ["123.123.123.1"]},
    {"ssh": "hostname2", "ips": ["123.123.123.2"]},
    {"ssh": "hostname3", "ips": ["123.123.123.3"]},
    {"ssh": "hostname4", "ips": ["123.123.123.4"]}
]
</pre></div>
</div>
<p>Running <code>mlx.launch --hostfile ring-4.json my_script.py</code> will ssh into each
node and run the script, which will listen for connections on each of the provided
IPs. Specifically, <code>hostname1</code> will connect to <code>123.123.123.2</code> and accept a
connection from <code>123.123.123.4</code> and so on and so forth.</p>
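<p>A hedged sketch of neighbor-only communication in such a ring (not an example from this page; it assumes the 4 node hostfile above and passes a token one hop around the ring) is:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.core as mx

world = mx.distributed.init(backend="ring")
r, n = world.rank(), world.size()

# send()/recv() are only valid between immediate ring neighbors
token = mx.array([float(r)])
sent = mx.distributed.send(token, (r + 1) % n)
received = mx.distributed.recv_like(token, (r - 1) % n)
mx.eval(sent, received)  # evaluate both so the exchange can proceed
print(r, "received token from rank", (r - 1) % n, ":", received)
</pre></div>
</div>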
</section>
<section id="thunderbolt-ring">
<h3>Thunderbolt Ring</h3>
<p>Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such Thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility <code>mlx.distributed_config</code>.</p>
<p>To use <code>mlx.distributed_config</code> your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via Thunderbolt cables and then call the
utility as follows:</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre>mlx.distributed_config --verbose --hosts host1,host2,host3,host4
</pre></div>
</div>
<p>By default the script will attempt to discover the Thunderbolt ring and provide
you with the commands to configure each node as well as the <code>hostfile.json</code>
to use with <code>mlx.launch</code>. If password-less <code>sudo</code> is available on the nodes
then <code>--auto-setup</code> can be used to configure them automatically.</p>
<p>To validate your connection without configuring anything
<code>mlx.distributed_config</code> can also plot the ring using DOT format.</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre>mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png
</pre></div>
</div>
<p>If you want to go through the process manually, the steps are as follows:</p>
<ul class="simple">
<li><p>Disable the Thunderbolt bridge interface</p></li>
<li><p>For the cable connecting rank <code>i</code> to rank <code>i + 1</code> find the interfaces
corresponding to that cable in nodes <code>i</code> and <code>i + 1</code>.</p></li>
<li><p>Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to <code>en2</code> on node <code>i</code>
and <code>en2</code> also on node <code>i + 1</code> then we may assign IPs <code>192.168.0.1</code> and
<code>192.168.0.2</code> respectively to the two nodes. For more details you can see
the commands prepared by the utility script; a small sketch of one such
addressing scheme follows this list.</p></li>
</ul>
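<p>The following hypothetical sketch (not part of this page) prints one such per-cable addressing plan, giving each cable in a 4 node ring its own small subnetwork as described above:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre># Each cable connects rank i to rank (i + 1) % n; give it its own subnet
# and assign one address to each end, generalizing the 192.168.0.x example.
n = 4  # number of nodes in the ring
for i in range(n):
    j = (i + 1) % n
    print(f"cable {i}-{j}: node {i} -> 192.168.{i}.1, node {j} -> 192.168.{i}.2")
</pre></div>
</div>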
</section>
<section id="tuning-all-reduce">
<h2>Tuning All Reduce</h2>
<p>We are working on improving the performance of all reduce on MLX but for now
the two main things one can do to extract the most out of distributed training with MLX are:</p>
<ol class="arabic simple">
<li><p>Perform a few large reductions instead of many small ones to improve
bandwidth and latency</p></li>
<li><p>Pass <code>--mca btl_tcp_links 4</code> to <code>mpirun</code> to configure it to use 4 tcp
connections between each host to improve bandwidth</p></li>
</ol>
</section>
</section>