mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 01:19:21 +08:00
Compare commits
34 Commits
ibv-backen
...
ibv-backen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d2bc340df4 | ||
|
|
fabc947df4 | ||
|
|
5523087cfb | ||
|
|
2f939acefa | ||
|
|
3b416a2e36 | ||
|
|
753c6a4d0f | ||
|
|
d3a754c8aa | ||
|
|
595a4ad206 | ||
|
|
a4dc1fac6c | ||
|
|
ebda161a86 | ||
|
|
fa31a4b295 | ||
|
|
9d707ba3b5 | ||
|
|
405d30b6e5 | ||
|
|
cd4b12ce1b | ||
|
|
425043ccca | ||
|
|
95d92af8a0 | ||
|
|
bfdddd644b | ||
|
|
1216afdc91 | ||
|
|
04e94d78bb | ||
|
|
60d4e8b2a8 | ||
|
|
c5745fddd2 | ||
|
|
e937a8033f | ||
|
|
4dfe02d7c6 | ||
|
|
5c2cff9329 | ||
|
|
325dab9559 | ||
|
|
67e454ab0a | ||
|
|
27232db1ba | ||
|
|
a4b3bc969b | ||
|
|
667c0f3bb9 | ||
|
|
6245824d42 | ||
|
|
39289ef025 | ||
|
|
aefc9bd3f6 | ||
|
|
997cfc7699 | ||
|
|
1fa8dc5797 |
11
.github/actions/build-cuda-release/action.yml
vendored
11
.github/actions/build-cuda-release/action.yml
vendored
@@ -1,6 +1,15 @@
|
|||||||
name: 'Build CUDA wheel'
|
name: 'Build CUDA wheel'
|
||||||
description: 'Build CUDA wheel'
|
description: 'Build CUDA wheel'
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
arch:
|
||||||
|
description: 'Platform architecture tag'
|
||||||
|
required: true
|
||||||
|
type: choice
|
||||||
|
options:
|
||||||
|
- x86_64
|
||||||
|
- aarch64
|
||||||
|
|
||||||
runs:
|
runs:
|
||||||
using: "composite"
|
using: "composite"
|
||||||
steps:
|
steps:
|
||||||
@@ -12,4 +21,4 @@ runs:
|
|||||||
pip install auditwheel build patchelf setuptools
|
pip install auditwheel build patchelf setuptools
|
||||||
python setup.py clean --all
|
python setup.py clean --all
|
||||||
MLX_BUILD_STAGE=2 python -m build -w
|
MLX_BUILD_STAGE=2 python -m build -w
|
||||||
bash python/scripts/repair_cuda.sh
|
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}
|
||||||
|
|||||||
1
.github/actions/setup-linux/action.yml
vendored
1
.github/actions/setup-linux/action.yml
vendored
@@ -15,6 +15,7 @@ runs:
|
|||||||
using: "composite"
|
using: "composite"
|
||||||
steps:
|
steps:
|
||||||
- name: Use ccache
|
- name: Use ccache
|
||||||
|
if: ${{ runner.arch == 'x86_64' }}
|
||||||
uses: hendrikmuhs/ccache-action@v1.2
|
uses: hendrikmuhs/ccache-action@v1.2
|
||||||
with:
|
with:
|
||||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
|
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
|
||||||
|
|||||||
10
.github/workflows/release.yml
vendored
10
.github/workflows/release.yml
vendored
@@ -128,7 +128,11 @@ jobs:
|
|||||||
|
|
||||||
build_cuda_release:
|
build_cuda_release:
|
||||||
if: github.repository == 'ml-explore/mlx'
|
if: github.repository == 'ml-explore/mlx'
|
||||||
runs-on: ubuntu-22-large
|
strategy:
|
||||||
|
matrix:
|
||||||
|
arch: ['x86_64', 'aarch64']
|
||||||
|
toolkit: ['cuda-12.9', 'cuda-13.0']
|
||||||
|
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
|
||||||
env:
|
env:
|
||||||
PYPI_RELEASE: 1
|
PYPI_RELEASE: 1
|
||||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||||
@@ -136,9 +140,11 @@ jobs:
|
|||||||
- uses: actions/checkout@v6
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/setup-linux
|
- uses: ./.github/actions/setup-linux
|
||||||
with:
|
with:
|
||||||
toolkit: 'cuda-12.9'
|
toolkit: ${{ matrix.toolkit }}
|
||||||
- name: Build Python package
|
- name: Build Python package
|
||||||
uses: ./.github/actions/build-cuda-release
|
uses: ./.github/actions/build-cuda-release
|
||||||
|
with:
|
||||||
|
arch: ${{ matrix.arch }}
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
|
|||||||
BIN
docs/src/_static/distributed/m3-ultra-mesh-broken.png
Normal file
BIN
docs/src/_static/distributed/m3-ultra-mesh-broken.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
BIN
docs/src/_static/distributed/m3-ultra-mesh.png
Normal file
BIN
docs/src/_static/distributed/m3-ultra-mesh.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 22 KiB |
@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
|
|||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
pip install mlx[cuda]
|
pip install mlx[cuda12]
|
||||||
|
|
||||||
|
|
||||||
To install the CUDA package from PyPi your system must meet the following
|
To install the CUDA package from PyPi your system must meet the following
|
||||||
requirements:
|
requirements:
|
||||||
|
|
||||||
- Nvidia architecture >= SM 7.0 (Volta)
|
- Nvidia architecture >= SM 7.5
|
||||||
- Nvidia driver >= 550.54.14
|
- Nvidia driver >= 550.54.14
|
||||||
- CUDA toolkit >= 12.0
|
- CUDA toolkit >= 12.0
|
||||||
- Linux distribution with glibc >= 2.35
|
- Linux distribution with glibc >= 2.35
|
||||||
- Python >= 3.10
|
- Python >= 3.10
|
||||||
|
|
||||||
|
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
|
||||||
|
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
|
||||||
|
|
||||||
CPU-only (Linux)
|
CPU-only (Linux)
|
||||||
^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^
|
||||||
|
|||||||
@@ -7,22 +7,29 @@ Distributed Communication
|
|||||||
|
|
||||||
MLX supports distributed communication operations that allow the computational cost
|
MLX supports distributed communication operations that allow the computational cost
|
||||||
of training or inference to be shared across many physical machines. At the
|
of training or inference to be shared across many physical machines. At the
|
||||||
moment we support three different communication backends:
|
moment we support several different communication backends introduced below.
|
||||||
|
|
||||||
|
.. list-table::
|
||||||
|
:widths: 20 80
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - Backend
|
||||||
|
- Description
|
||||||
|
* - :ref:`MPI <mpi_section>`
|
||||||
|
- A full featured and mature distributed communications library.
|
||||||
|
* - :ref:`RING <ring_section>`
|
||||||
|
- Ring all reduce and all gather over TCP sockets. Always available and
|
||||||
|
usually faster than MPI.
|
||||||
|
* - :ref:`JACCL <jaccl_section>`
|
||||||
|
- Low latency communication with RDMA over thunderbolt. Necessary for
|
||||||
|
things like tensor parallelism.
|
||||||
|
* - :ref:`NCCL <nccl_section>`
|
||||||
|
- The backend of choice for CUDA environments.
|
||||||
|
|
||||||
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
|
|
||||||
full-featured and mature distributed communications library
|
|
||||||
* A **ring** backend of our own that uses native TCP sockets. It should be
|
|
||||||
faster for thunderbolt connections, but it also works over Ethernet.
|
|
||||||
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
|
|
||||||
|
|
||||||
The list of all currently supported operations and their documentation can be
|
The list of all currently supported operations and their documentation can be
|
||||||
seen in the :ref:`API docs<distributed>`.
|
seen in the :ref:`API docs<distributed>`.
|
||||||
|
|
||||||
.. note::
|
|
||||||
Some operations may not be supported or not as fast as they should be.
|
|
||||||
We are adding more and tuning the ones we have as we are figuring out the
|
|
||||||
best way to do distributed computing on Macs using MLX.
|
|
||||||
|
|
||||||
Getting Started
|
Getting Started
|
||||||
---------------
|
---------------
|
||||||
|
|
||||||
@@ -85,7 +92,7 @@ Selecting Backend
|
|||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
You can select the backend you want to use when calling :func:`init` by passing
|
You can select the backend you want to use when calling :func:`init` by passing
|
||||||
one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
|
one of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
|
||||||
available backends. If they all fail then a singleton group is created.
|
available backends. If they all fail then a singleton group is created.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
@@ -110,6 +117,8 @@ The following examples aim to clarify the backend initialization logic in MLX:
|
|||||||
world_ring = mx.distributed.init(backend="ring")
|
world_ring = mx.distributed.init(backend="ring")
|
||||||
world_any = mx.distributed.init() # same as MPI because it was initialized first!
|
world_any = mx.distributed.init() # same as MPI because it was initialized first!
|
||||||
|
|
||||||
|
.. _training_example:
|
||||||
|
|
||||||
Training Example
|
Training Example
|
||||||
----------------
|
----------------
|
||||||
|
|
||||||
@@ -192,16 +201,273 @@ almost identical to the example above:
|
|||||||
loss = step(model, x, y)
|
loss = step(model, x, y)
|
||||||
mx.eval(loss, model.parameters())
|
mx.eval(loss, model.parameters())
|
||||||
|
|
||||||
|
.. _ring_section:
|
||||||
|
|
||||||
|
Getting Started with Ring
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
The ring backend does not depend on any third party library so it is always
|
||||||
|
available. It uses TCP sockets so the nodes need to be reachable via a network.
|
||||||
|
As the name suggests the nodes are connected in a ring which means that rank 1
|
||||||
|
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
|
||||||
|
and so on and so forth. As a result :func:`send` and :func:`recv` with
|
||||||
|
arbitrary sender and receiver are not supported in the ring backend.
|
||||||
|
|
||||||
|
Defining a Ring
|
||||||
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
The easiest way to define and use a ring is via a JSON hostfile and the
|
||||||
|
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
|
||||||
|
defines a hostname to ssh into to run commands on this node and one or more IPs
|
||||||
|
that this node will listen to for connections.
|
||||||
|
|
||||||
|
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
|
||||||
|
rank 0, ``hostname2`` rank 1 etc.
|
||||||
|
|
||||||
|
.. code:: json
|
||||||
|
|
||||||
|
[
|
||||||
|
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
|
||||||
|
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
|
||||||
|
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
|
||||||
|
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
|
||||||
|
]
|
||||||
|
|
||||||
|
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
|
||||||
|
node, run the script which will listen for connections in each of the provided
|
||||||
|
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
|
||||||
|
connection from ``123.123.123.4`` and so on and so forth.
|
||||||
|
|
||||||
|
Thunderbolt Ring
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Although the ring backend can have benefits over MPI even for Ethernet, its
|
||||||
|
main purpose is to use Thunderbolt rings for higher bandwidth communication.
|
||||||
|
Setting up such thunderbolt rings can be done manually, but is a relatively
|
||||||
|
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
|
||||||
|
|
||||||
|
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
|
||||||
|
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
|
||||||
|
utility as follows:
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring
|
||||||
|
|
||||||
|
By default the script will attempt to discover the thunderbolt ring and provide
|
||||||
|
you with the commands to configure each node as well as the ``hostfile.json``
|
||||||
|
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
|
||||||
|
then ``--auto-setup`` can be used to configure them automatically.
|
||||||
|
|
||||||
|
If you want to go through the process manually, the steps are as follows:
|
||||||
|
|
||||||
|
* Disable the thunderbolt bridge interface
|
||||||
|
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
|
||||||
|
corresponding to that cable in nodes ``i`` and ``i + 1``.
|
||||||
|
* Set up a unique subnetwork connecting the two nodes for the corresponding
|
||||||
|
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
|
||||||
|
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
|
||||||
|
``192.168.0.2`` respectively to the two nodes. For more details you can see
|
||||||
|
the commands prepared by the utility script.
|
||||||
|
|
||||||
|
.. _jaccl_section:
|
||||||
|
|
||||||
|
Getting Started with RDMA over Thunderbolt
|
||||||
|
------------------------------------------
|
||||||
|
|
||||||
|
Starting from version 26.2, RDMA over Thunderbolt is available in macOS and
|
||||||
|
enables low-latency communication between Macs with thunderbolt 5. MLX provides
|
||||||
|
the JACCL backend that uses this functionality to achieve communication latency
|
||||||
|
an order of magnitude lower than the ring backend.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective
|
||||||
|
Communication Library* and it is an obvious pun on Nvidia's NCCL but also a
|
||||||
|
tribute to *Jack Beasley* who led the development of RDMA over Thunderbolt
|
||||||
|
at Apple.
|
||||||
|
|
||||||
|
Enabling RDMA
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Until the feature matures, enabling RDMA over thunderbolt is slightly more
|
||||||
|
involved and **cannot** be done remotely even with sudo. In fact, it has to be
|
||||||
|
done in macOS recovery:
|
||||||
|
|
||||||
|
1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.
|
||||||
|
2. Open the Terminal by going to Utilities -> Terminal.
|
||||||
|
3. Run ``rdma_ctl enable``.
|
||||||
|
4. Reboot.
|
||||||
|
|
||||||
|
To verify that you have successfully enabled Thunderbolt RDMA you can run
|
||||||
|
``ibv_devices`` which should produce something like the following for an M3 Ultra.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
~ % ibv_devices
|
||||||
|
device node GUID
|
||||||
|
------ ----------------
|
||||||
|
rdma_en2 8096a9d9edbaac05
|
||||||
|
rdma_en3 8196a9d9edbaac05
|
||||||
|
rdma_en5 8396a9d9edbaac05
|
||||||
|
rdma_en4 8296a9d9edbaac05
|
||||||
|
rdma_en6 8496a9d9edbaac05
|
||||||
|
rdma_en7 8596a9d9edbaac05
|
||||||
|
|
||||||
|
Defining a Mesh
|
||||||
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
The JACCL backend supports only fully connected topologies. Namely, there needs
|
||||||
|
to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
|
||||||
|
the following topology visualizations, the left one is valid because there is a
|
||||||
|
connection from any node to any other node, while for the one on the right M3
|
||||||
|
Ultra 1 is not connected to M3 Ultra 2.
|
||||||
|
|
||||||
|
.. raw:: html
|
||||||
|
|
||||||
|
<div style="display: flex; text-align: center; align-items: end; font-size: 80%;">
|
||||||
|
<div>
|
||||||
|
<img src="/_static/distributed/m3-ultra-mesh.png" alt="M3 Ultra thunderbolt mesh" style="width: 55%">
|
||||||
|
<p>Fully connected mesh of four M3 Ultra.</p>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<img src="/_static/distributed/m3-ultra-mesh-broken.png" alt="M3 Ultra broken thunderbolt mesh" style="width: 55%">
|
||||||
|
<p>Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
Similar to the ring backend, the easiest way to use JACCL with MLX is to write
|
||||||
|
a JSON hostfile that will be used by ``mlx.launch``. The hostfile needs to contain
|
||||||
|
|
||||||
|
- Hostnames to use for launching scripts via ssh
|
||||||
|
- An IP for rank 0 that is reachable by all nodes
|
||||||
|
- A list of rdma devices that connect each node to each other node
|
||||||
|
|
||||||
|
The following JSON defines the valid 4-node mesh from the image above.
|
||||||
|
|
||||||
|
.. code-block:: json
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"ssh": "m3-ultra-1",
|
||||||
|
"ips": ["123.123.123.1"],
|
||||||
|
"rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ssh": "m3-ultra-2",
|
||||||
|
"ips": [],
|
||||||
|
"rdma": ["rdma_en5", null, "rdma_en3", "rdma_en4"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ssh": "m3-ultra-3",
|
||||||
|
"ips": [],
|
||||||
|
"rdma": ["rdma_en4", "rdma_en3", null, "rdma_en5"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ssh": "m3-ultra-4",
|
||||||
|
"ips": [],
|
||||||
|
"rdma": ["rdma_en3", "rdma_en4", "rdma_en5", null]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
|
||||||
|
disabling the thunderbolt bridge is still required as well as setting up
|
||||||
|
isolated local networks for each thunderbolt connection.
|
||||||
|
|
||||||
|
All of the above can be done instead via ``mlx.distributed_config``. This helper
|
||||||
|
script will
|
||||||
|
|
||||||
|
- ssh into each node
|
||||||
|
- extract the thunderbolt connectivity
|
||||||
|
- check for a valid mesh
|
||||||
|
- provide the commands to configure each node (or run them if sudo is available)
|
||||||
|
- generate the hostfile to be used with ``mlx.launch``
|
||||||
|
|
||||||
|
Putting it All Together
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
For example launching a distributed MLX script that uses JACCL is fairly simple
|
||||||
|
if the nodes are reachable via ssh and have password-less sudo.
|
||||||
|
|
||||||
|
First, connect all the thunderbolt cables. Then we can verify the connections
|
||||||
|
by using the ``mlx.distributed_config`` script to visualize them.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
mlx.distributed_config --verbose \
|
||||||
|
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
|
||||||
|
--over thunderbolt --dot | dot -Tpng | open -f -a Preview
|
||||||
|
|
||||||
|
After making sure that everything looks right we can auto-configure the nodes
|
||||||
|
and save the hostfile to ``m3-ultra-jaccl.json`` by running:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
mlx.distributed_config --verbose \
|
||||||
|
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
|
||||||
|
--over thunderbolt --backend jaccl \
|
||||||
|
--auto-setup --output m3-ultra-jaccl.json
|
||||||
|
|
||||||
|
And now we are ready to run a distributed MLX script such as distributed inference
|
||||||
|
of a gigantic model using MLX-LM.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
|
||||||
|
--env MLX_METAL_FAST_SYNCH=1 -- \ # <--- important
|
||||||
|
/path/to/remote/python -m mlx_lm chat --model mlx-community/DeepSeek-V3.2-8bit --shard
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Defining the environment variable ``MLX_METAL_FAST_SYNCH=1`` enables a
|
||||||
|
different, faster way of synchronizing between the GPU and the CPU. It is
|
||||||
|
not specific to the JACCL backend and can be used in all cases where the CPU
|
||||||
|
and GPU need to collaborate for some computation and is pretty critical for
|
||||||
|
low-latency communication since the communication is done by the CPU.
|
||||||
|
|
||||||
|
.. _nccl_section:
|
||||||
|
|
||||||
|
Getting Started with NCCL
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
MLX on CUDA environments ships with the ability to talk to `NCCL
|
||||||
|
<https://developer.nvidia.com/nccl>`_ which is a high-performance collective
|
||||||
|
communication library that supports both multi-gpu and multi-node setups.
|
||||||
|
|
||||||
|
For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all
|
||||||
|
it takes to run a distributed job is
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
mlx.launch -n 8 test.py
|
||||||
|
|
||||||
|
# perfect for interactive scripts
|
||||||
|
mlx.launch -n 8 python -m mlx_lm chat --model my-model --shard
|
||||||
|
|
||||||
|
You can also use ``mlx.launch`` to ssh to a remote node and launch a script
|
||||||
|
with the same ease
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
mlx.launch --hosts my-cuda-node -n 8 test.py
|
||||||
|
|
||||||
|
In many cases you may not want to use ``mlx.launch`` with the NCCL backend
|
||||||
|
because the cluster scheduler will be the one launching the processes. You can
|
||||||
|
:ref:`see which environment variables need to be defined <no_mlx_launch>` in
|
||||||
|
order for the MLX NCCL backend to be initialized correctly.
|
||||||
|
|
||||||
|
.. _mpi_section:
|
||||||
|
|
||||||
Getting Started with MPI
|
Getting Started with MPI
|
||||||
------------------------
|
------------------------
|
||||||
|
|
||||||
MLX already comes with the ability to "talk" to MPI if it is installed on the
|
MLX already comes with the ability to "talk" to `MPI
|
||||||
machine. Launching distributed MLX programs that use MPI can be done with
|
<https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ if it is installed
|
||||||
``mpirun`` as expected. However, in the following examples we will be using
|
on the machine. Launching distributed MLX programs that use MPI can be done
|
||||||
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
|
with ``mpirun`` as expected. However, in the following examples we will be
|
||||||
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
|
using ``mlx.launch --backend mpi`` which takes care of some nuisances such as
|
||||||
library.
|
setting absolute paths for the ``mpirun`` executable and the ``libmpi.dylib``
|
||||||
|
shared library.
|
||||||
|
|
||||||
The simplest possible usage is the following which, assuming the minimal
|
The simplest possible usage is the following which, assuming the minimal
|
||||||
example in the beginning of this page, should result in:
|
example in the beginning of this page, should result in:
|
||||||
@@ -269,78 +535,116 @@ Force MPI to use the most performant network interface by setting ``--mca
|
|||||||
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
|
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
|
||||||
to use.
|
to use.
|
||||||
|
|
||||||
Getting Started with Ring
|
.. _no_mlx_launch:
|
||||||
-------------------------
|
|
||||||
|
|
||||||
The ring backend does not depend on any third party library so it is always
|
Distributed Without ``mlx.launch``
|
||||||
available. It uses TCP sockets so the nodes need to be reachable via a network.
|
----------------------------------
|
||||||
As the name suggests the nodes are connected in a ring which means that rank 1
|
|
||||||
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
|
|
||||||
and so on and so forth. As a result :func:`send` and :func:`recv` with
|
|
||||||
arbitrary sender and receiver is not supported in the ring backend.
|
|
||||||
|
|
||||||
Defining a Ring
|
None of the implementations of the distributed backends require launching with
|
||||||
^^^^^^^^^^^^^^^
|
``mlx.launch``. The script simply connects to each host, starts a process per
|
||||||
|
rank and sets up the necessary environment variables before delegating to your
|
||||||
|
MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
|
||||||
|
for more details.
|
||||||
|
|
||||||
The easiest way to define and use a ring is via a JSON hostfile and the
|
For many use-cases this will be the easiest way to perform distributed
|
||||||
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
|
computations in MLX. However, there may be reasons that you cannot or should
|
||||||
defines a hostname to ssh into to run commands on this node and one or more IPs
|
not use ``mlx.launch``. A common such case is the use of a scheduler that
|
||||||
that this node will listen to for connections.
|
starts all the processes for you on machines undetermined at the time of
|
||||||
|
scheduling the job.
|
||||||
|
|
||||||
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
|
Below we list the environment variables required to use each backend.
|
||||||
rank 0, ``hostname2`` rank 1 etc.
|
|
||||||
|
|
||||||
.. code:: json
|
Ring
|
||||||
|
^^^^^^
|
||||||
|
|
||||||
[
|
**MLX_RANK** should contain a single 0-based integer that defines the rank of
|
||||||
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
|
the process.
|
||||||
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
|
|
||||||
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
|
|
||||||
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
|
|
||||||
]
|
|
||||||
|
|
||||||
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
|
**MLX_HOSTFILE** should contain the path to a json file that contains IPs and
|
||||||
node, run the script which will listen for connections in each of the provided
|
ports for each rank to listen to, something like the following:
|
||||||
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
|
|
||||||
connection from ``123.123.123.4`` and so on and so forth.
|
|
||||||
|
|
||||||
Thunderbolt Ring
|
.. code-block:: json
|
||||||
^^^^^^^^^^^^^^^^
|
|
||||||
|
|
||||||
Although the ring backend can have benefits over MPI even for Ethernet, its
|
[
|
||||||
main purpose is to use Thunderbolt rings for higher bandwidth communication.
|
["123.123.1.1:5000", "123.123.1.2:5000"],
|
||||||
Setting up such thunderbolt rings can be done manually, but is a relatively
|
["123.123.2.1:5000", "123.123.2.2:5000"],
|
||||||
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
|
["123.123.3.1:5000", "123.123.3.2:5000"],
|
||||||
|
["123.123.4.1:5000", "123.123.4.2:5000"]
|
||||||
|
]
|
||||||
|
|
||||||
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
|
**MLX_RING_VERBOSE** is optional and if set to 1 it enables some more logging
|
||||||
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
|
from the distributed backend.
|
||||||
utility as follows:
|
|
||||||
|
|
||||||
.. code:: shell
|
JACCL
|
||||||
|
^^^^^
|
||||||
|
|
||||||
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
|
**MLX_RANK** should contain a single 0-based integer that defines the rank of
|
||||||
|
the process.
|
||||||
|
|
||||||
By default the script will attempt to discover the thunderbolt ring and provide
|
**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 can listen
|
||||||
you with the commands to configure each node as well as the ``hostfile.json``
|
to and that all the other ranks connect to in order to establish the RDMA connections.
|
||||||
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
|
|
||||||
then ``--auto-setup`` can be used to configure them automatically.
|
|
||||||
|
|
||||||
To validate your connection without configuring anything
|
**MLX_IBV_DEVICES** should contain the path to a json file that contains the
|
||||||
``mlx.distributed_config`` can also plot the ring using DOT format.
|
ibverbs device names that connect each node to each other node, something like
|
||||||
|
the following:
|
||||||
|
|
||||||
.. code:: shell
|
.. code-block:: json
|
||||||
|
|
||||||
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
|
[
|
||||||
dot -Tpng ring.dot >ring.png
|
[null, "rdma_en5", "rdma_en4", "rdma_en3"],
|
||||||
open ring.png
|
["rdma_en5", null, "rdma_en3", "rdma_en4"],
|
||||||
|
["rdma_en4", "rdma_en3", null, "rdma_en5"],
|
||||||
|
["rdma_en3", "rdma_en4", "rdma_en5", null]
|
||||||
|
]
|
||||||
|
|
||||||
If you want to go through the process manually, the steps are as follows:
|
|
||||||
|
|
||||||
* Disable the thunderbolt bridge interface
|
NCCL
|
||||||
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
|
^^^^^
|
||||||
corresponding to that cable in nodes ``i`` and ``i + 1``.
|
|
||||||
* Set up a unique subnetwork connecting the two nodes for the corresponding
|
**MLX_RANK** should contain a single 0-based integer that defines the rank of
|
||||||
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
|
the process.
|
||||||
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
|
|
||||||
``192.168.0.2`` respectively to the two nodes. For more details you can see
|
**MLX_WORLD_SIZE** should contain the total number of processes that will be
|
||||||
the commands prepared by the utility script.
|
launched.
|
||||||
|
|
||||||
|
**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
|
||||||
|
hosts can connect to in order to establish the NCCL communication.
|
||||||
|
|
||||||
|
**CUDA_VISIBLE_DEVICES** should contain the local index of the gpu that
|
||||||
|
corresponds to this process.
|
||||||
|
|
||||||
|
Of course any `other environment variable
|
||||||
|
<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is
|
||||||
|
used by NCCL can be set.
|
||||||
|
|
||||||
|
.. _tips_and_tricks:
|
||||||
|
|
||||||
|
Tips and Tricks
|
||||||
|
----------------
|
||||||
|
|
||||||
|
This is a small collection of tips to help you better utilize the distributed
|
||||||
|
communication capabilities of MLX.
|
||||||
|
|
||||||
|
- *Test locally first.*
|
||||||
|
|
||||||
|
You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a small
|
||||||
|
scale test on a single node first.
|
||||||
|
|
||||||
|
- *Batch your communication.*
|
||||||
|
|
||||||
|
As described in the :ref:`training example <training_example>`, performing a
|
||||||
|
lot of small communication can hurt performance. Copy the approach of
|
||||||
|
:func:`mlx.nn.average_gradients` to gather many small communications in a
|
||||||
|
single large one.
|
||||||
|
|
||||||
|
- *Visualize the connectivity.*
|
||||||
|
|
||||||
|
Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
|
||||||
|
visualize the connections and make sure that the cables are connected
|
||||||
|
correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
|
||||||
|
|
||||||
|
- *Use the debugger.*
|
||||||
|
|
||||||
|
``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
|
||||||
|
processes and gathers stdout from all processes. This makes using ``pdb`` a
|
||||||
|
breeze.
|
||||||
|
|||||||
@@ -7,13 +7,106 @@ Launching Distributed Programs
|
|||||||
|
|
||||||
.. currentmodule:: mlx.core.distributed
|
.. currentmodule:: mlx.core.distributed
|
||||||
|
|
||||||
Installing the MLX python package provides a helper script ``mlx.launch`` that
|
Installing the MLX python package provides two utilities to help you configure
|
||||||
can be used to run python scripts distributed on several nodes. It allows
|
your Macs for distributed computation and also launch distributed programs on
|
||||||
launching using either the MPI backend or the ring backend. See the
|
multiple nodes or with many processes in a single node. These utilities are aptly named:
|
||||||
:doc:`distributed docs <distributed>` for the different backends.
|
|
||||||
|
|
||||||
Usage
|
- ``mlx.launch``
|
||||||
-----
|
- ``mlx.distributed_config``
|
||||||
|
|
||||||
|
See the :doc:`distributed docs <distributed>` for an introduction and
|
||||||
|
getting-started guides to the various backends.
|
||||||
|
|
||||||
|
``mlx.distributed_config``
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
Unless you are launching distributed jobs locally for development or multi-gpu
|
||||||
|
CUDA environments, you have several Macs that you need to configure for
|
||||||
|
distributed communication with MLX.
|
||||||
|
|
||||||
|
``mlx.distributed_config`` aims to automate the process of configuring the
|
||||||
|
network interfaces (especially for communication over thunderbolt) and also
|
||||||
|
creating the hostfile to be used with ``mlx.launch``.
|
||||||
|
|
||||||
|
We will analyse 3 cases of using ``mlx.distributed_config``:
|
||||||
|
|
||||||
|
1. RDMA over thunderbolt using JACCL
|
||||||
|
2. TCP/IP over thunderbolt using the ring backend
|
||||||
|
3. TCP/IP over ethernet using the ring backend
|
||||||
|
|
||||||
|
JACCL
|
||||||
|
^^^^^^^
|
||||||
|
|
||||||
|
After following :ref:`the steps to enable RDMA <jaccl_section>` you can run the
|
||||||
|
following command to configure the nodes and create the hostfile.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
mlx.distributed_config --verbose --backend jaccl \
|
||||||
|
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 --over thunderbolt \
|
||||||
|
--auto-setup --output m3-ultra-jaccl.json
|
||||||
|
|
||||||
|
Let's walk through the steps that the script takes to configure the nodes.
|
||||||
|
|
||||||
|
1. Ssh to all nodes to verify that they are reachable
|
||||||
|
2. Extract the thunderbolt connectivity. Namely run commands on each node to
|
||||||
|
calculate which node is connected to which other node.
|
||||||
|
3. Verify that we have a valid fully connected mesh
|
||||||
|
4. Check that RDMA is enabled
|
||||||
|
5. Extract the ethernet IP from interface en0
|
||||||
|
6. Disable the thunderbolt bridge and set up peer to peer networks for each
|
||||||
|
thunderbolt cable
|
||||||
|
7. Write the hostfile
|
||||||
|
|
||||||
|
Knowing the above steps allows you to manually configure the nodes but also
|
||||||
|
debug any configuration issue. For instance changing the Ethernet IP to a
|
||||||
|
different interface directly in the config is possible (as long as it is
|
||||||
|
reachable from all nodes).
|
||||||
|
|
||||||
|
The ``--auto-setup`` argument requires password-less sudo on each node. If it
|
||||||
|
isn't available then the configuration script will print commands to be run on
|
||||||
|
each node.
|
||||||
|
|
||||||
|
Ring over thunderbolt
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Setting up a ring backend over thunderbolt only requires changing the
|
||||||
|
``--backend`` from ``jaccl`` to ``ring``.
|
||||||
|
|
||||||
|
The steps are very similar with the main difference being that instead of
|
||||||
|
verifying that the nodes are fully connected, the script attempts to identify a
|
||||||
|
ring topology (or multiple rings).
|
||||||
|
|
||||||
|
Ring over Ethernet
|
||||||
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Configuring the ring backend over ethernet doesn't require setting up network
|
||||||
|
interfaces and as such it simply extracts the ``en0`` IP from each node and
|
||||||
|
writes the hostfile.
|
||||||
|
|
||||||
|
Debugging cable connections
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
``mlx.distributed_config`` can help you debug the connectivity of your nodes
|
||||||
|
over thunderbolt by exporting a graph of the connections.
|
||||||
|
|
||||||
|
Running
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
mlx.distributed_config --verbose \
|
||||||
|
--hosts host1,host2,host3,host4 \
|
||||||
|
--over thunderbolt --dot
|
||||||
|
|
||||||
|
will export a `GraphViz <https://graphviz.org>`_ representation of the
|
||||||
|
connections between the nodes which makes it very easy to figure out which
|
||||||
|
cable is not connected correctly.
|
||||||
|
|
||||||
|
See :ref:`the JACCL section <jaccl_section>` for an example.
|
||||||
|
|
||||||
|
|
||||||
|
``mlx.launch``
|
||||||
|
--------------
|
||||||
|
|
||||||
The minimal usage example of ``mlx.launch`` is simply
|
The minimal usage example of ``mlx.launch`` is simply
|
||||||
|
|
||||||
@@ -33,6 +126,10 @@ the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
|
|||||||
It also takes care of forwarding the output of each remote process to stdout
|
It also takes care of forwarding the output of each remote process to stdout
|
||||||
and stderr respectively.
|
and stderr respectively.
|
||||||
|
|
||||||
|
Importantly, it also broadcasts stdin to each process which enables interactive
|
||||||
|
programs to work in distributed mode as well as debugging using the interactive
|
||||||
|
debugger.
|
||||||
|
|
||||||
Providing Hosts
|
Providing Hosts
|
||||||
^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
@@ -63,10 +160,62 @@ host and on the same path. A good checklist to debug errors is the following:
|
|||||||
``mlx.launch --print-python`` to see what that path is.
|
``mlx.launch --print-python`` to see what that path is.
|
||||||
* the script you want to run is available on all hosts at the same path
|
* the script you want to run is available on all hosts at the same path
|
||||||
|
|
||||||
|
If you are launching from a node with a completely different setup than the
|
||||||
|
nodes that the program will run on, you can specify ``--no-verify-script`` so
|
||||||
|
that ``mlx.launch`` does not attempt to verify that the executable and script
|
||||||
|
exist locally before launching the distributed job.
|
||||||
|
|
||||||
|
.. _ring_specifics:
|
||||||
|
|
||||||
|
Ring Specifics
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
The :ref:`ring <ring_section>` backend, which is also the default
|
||||||
|
backend, can be explicitly selected with the argument ``--backend ring``. The
|
||||||
|
ring backend has some specific requirements and arguments that are different to
|
||||||
|
other backends:
|
||||||
|
|
||||||
|
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
|
||||||
|
ssh to a hostname that does not correspond to the IP we want to bind to we
|
||||||
|
have to provide a hostfile.
|
||||||
|
* ``--starting-port`` defines the port to bind to on the remote hosts.
|
||||||
|
Specifically rank 0 for the first IP will use this port and each subsequent
|
||||||
|
IP or rank will add 1 to this port.
|
||||||
|
* ``--connections-per-ip`` allows us to increase the number of connections
|
||||||
|
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
|
||||||
|
``mpirun``.
|
||||||
|
|
||||||
|
.. _jaccl_specifics:
|
||||||
|
|
||||||
|
JACCL Specifics
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
The :ref:`JACCL <jaccl_section>` backend can be selected with the argument
|
||||||
|
``--backend jaccl``. A hostfile is necessary to launch with this backend
|
||||||
|
because it needs to contain the RDMA devices connecting each node to each other
|
||||||
|
node.
|
||||||
|
|
||||||
|
NCCL Specifics
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
The :ref:`NCCL <nccl_section>` backend is the default backend for CUDA
|
||||||
|
environments. When launching from a Mac to a Linux machine with CUDA, the
|
||||||
|
backend should be selected using ``--backend nccl``.
|
||||||
|
|
||||||
|
The ``--repeat-hosts, -n`` argument should be used to launch multi-node and
|
||||||
|
multi-gpu jobs. For instance
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
mlx.launch --backend nccl --hosts linux-1,linux-2 -n 8 --no-verify-script -- ./my-job.sh
|
||||||
|
|
||||||
|
will attempt to launch 16 processes, 8 on each node that will all run
|
||||||
|
``my-job.sh``.
|
||||||
|
|
||||||
.. _mpi_specifics:
|
.. _mpi_specifics:
|
||||||
|
|
||||||
MPI Specifics
|
MPI Specifics
|
||||||
-------------
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
|
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
|
||||||
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
|
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
|
||||||
@@ -83,23 +232,3 @@ to choose a specific interface for the byte-transfer-layer of MPI we can call
|
|||||||
.. code:: shell
|
.. code:: shell
|
||||||
|
|
||||||
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
|
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
|
||||||
|
|
||||||
|
|
||||||
.. _ring_specifics:
|
|
||||||
|
|
||||||
Ring Specifics
|
|
||||||
--------------
|
|
||||||
|
|
||||||
The ring backend, which is also the default backend, can be explicitly selected
|
|
||||||
with the argument ``--backend ring``. The ring backend has some specific
|
|
||||||
requirements and arguments that are different to MPI:
|
|
||||||
|
|
||||||
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
|
|
||||||
ssh to a hostname that does not correspond to the IP we want to bind to we
|
|
||||||
have to provide a hostfile.
|
|
||||||
* ``--starting-port`` defines the port to bind to on the remote hosts.
|
|
||||||
Specifically rank 0 for the first IP will use this port and each subsequent
|
|
||||||
IP or rank will add 1 to this port.
|
|
||||||
* ``--connections-per-ip`` allows us to increase the number of connections
|
|
||||||
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
|
|
||||||
``mpirun``.
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
|
||||||
|
|
||||||
namespace mlx::core::allocator {
|
|
||||||
|
|
||||||
Buffer malloc(size_t size) {
|
|
||||||
auto buffer = allocator().malloc(size);
|
|
||||||
if (size && !buffer.ptr()) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
void free(Buffer buffer) {
|
|
||||||
allocator().free(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::allocator
|
|
||||||
@@ -28,16 +28,16 @@ class Buffer {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
Buffer malloc(size_t size);
|
|
||||||
|
|
||||||
void free(Buffer buffer);
|
|
||||||
|
|
||||||
class Allocator {
|
class Allocator {
|
||||||
/** Abstract base class for a memory allocator. */
|
/** Abstract base class for a memory allocator. */
|
||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size) = 0;
|
virtual Buffer malloc(size_t size) = 0;
|
||||||
virtual void free(Buffer buffer) = 0;
|
virtual void free(Buffer buffer) = 0;
|
||||||
virtual size_t size(Buffer buffer) const = 0;
|
virtual size_t size(Buffer buffer) const = 0;
|
||||||
|
virtual Buffer make_buffer(void* ptr, size_t size) {
|
||||||
|
return Buffer{nullptr};
|
||||||
|
};
|
||||||
|
virtual void release(Buffer buffer) {}
|
||||||
|
|
||||||
Allocator() = default;
|
Allocator() = default;
|
||||||
Allocator(const Allocator& other) = delete;
|
Allocator(const Allocator& other) = delete;
|
||||||
@@ -49,4 +49,25 @@ class Allocator {
|
|||||||
|
|
||||||
Allocator& allocator();
|
Allocator& allocator();
|
||||||
|
|
||||||
|
inline Buffer malloc(size_t size) {
|
||||||
|
return allocator().malloc(size);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void free(Buffer buffer) {
|
||||||
|
allocator().free(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a Buffer from a raw pointer of the given size without a copy. If a
|
||||||
|
// no-copy conversion is not possible then the returned buffer.ptr() will be
|
||||||
|
// nullptr. Any buffer created with this function must be released with
|
||||||
|
// release(buffer)
|
||||||
|
inline Buffer make_buffer(void* ptr, size_t size) {
|
||||||
|
return allocator().make_buffer(ptr, size);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Release a buffer from the allocator made with make_buffer
|
||||||
|
inline void release(Buffer buffer) {
|
||||||
|
allocator().release(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::allocator
|
} // namespace mlx::core::allocator
|
||||||
|
|||||||
@@ -82,6 +82,28 @@ array::array(std::initializer_list<int> data, Dtype dtype)
|
|||||||
init(data.begin());
|
init(data.begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array::array(
|
||||||
|
void* data,
|
||||||
|
Shape shape,
|
||||||
|
Dtype dtype,
|
||||||
|
const std::function<void(void*)>& deleter)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||||
|
auto buffer = allocator::make_buffer(data, nbytes());
|
||||||
|
if (buffer.ptr() == nullptr) {
|
||||||
|
set_data(allocator::malloc(nbytes()));
|
||||||
|
auto ptr = static_cast<char*>(data);
|
||||||
|
std::copy(ptr, ptr + nbytes(), this->data<char>());
|
||||||
|
deleter(data);
|
||||||
|
} else {
|
||||||
|
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
|
||||||
|
auto ptr = buffer.ptr();
|
||||||
|
allocator::release(buffer);
|
||||||
|
return deleter(ptr);
|
||||||
|
};
|
||||||
|
set_data(buffer, std::move(wrapped_deleter));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* Build an array from a shared buffer */
|
/* Build an array from a shared buffer */
|
||||||
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
||||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||||
|
|||||||
10
mlx/array.h
10
mlx/array.h
@@ -57,6 +57,16 @@ class array {
|
|||||||
Shape shape,
|
Shape shape,
|
||||||
Dtype dtype = TypeToDtype<T>());
|
Dtype dtype = TypeToDtype<T>());
|
||||||
|
|
||||||
|
/* Build an array from a raw pointer. The constructor will attempt to use the
|
||||||
|
* input data without a copy. The deleter will be called when the array no
|
||||||
|
* longer needs the underlying memory - after the array is destroyed in the
|
||||||
|
* no-copy case and after the copy otherwise. */
|
||||||
|
explicit array(
|
||||||
|
void* data,
|
||||||
|
Shape shape,
|
||||||
|
Dtype dtype,
|
||||||
|
const std::function<void(void*)>& deleter);
|
||||||
|
|
||||||
/* Build an array from a buffer */
|
/* Build an array from a buffer */
|
||||||
explicit array(
|
explicit array(
|
||||||
allocator::Buffer data,
|
allocator::Buffer data,
|
||||||
|
|||||||
@@ -20,6 +20,19 @@ constexpr int page_size = 16384;
|
|||||||
// Any allocations smaller than this will try to use the small pool
|
// Any allocations smaller than this will try to use the small pool
|
||||||
constexpr int small_block_size = 8;
|
constexpr int small_block_size = 8;
|
||||||
|
|
||||||
|
#if CUDART_VERSION >= 13000
|
||||||
|
inline cudaMemLocation cuda_mem_loc(int i) {
|
||||||
|
cudaMemLocation loc;
|
||||||
|
loc.type = cudaMemLocationTypeDevice;
|
||||||
|
loc.id = i;
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
inline int cuda_mem_loc(int i) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
#endif // CUDART_VERSION >= 13000
|
||||||
|
|
||||||
// The small pool size in bytes. This should be a multiple of the host page
|
// The small pool size in bytes. This should be a multiple of the host page
|
||||||
// size and small_block_size.
|
// size and small_block_size.
|
||||||
constexpr int small_pool_size = 4 * page_size;
|
constexpr int small_pool_size = 4 * page_size;
|
||||||
@@ -35,13 +48,7 @@ SmallSizePool::SmallSizePool() {
|
|||||||
int device_count = 0;
|
int device_count = 0;
|
||||||
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
|
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
|
||||||
for (int i = 0; i < device_count; ++i) {
|
for (int i = 0; i < device_count; ++i) {
|
||||||
#if CUDART_VERSION >= 13000
|
auto loc = cuda_mem_loc(i);
|
||||||
cudaMemLocation loc;
|
|
||||||
loc.type = cudaMemLocationTypeDevice;
|
|
||||||
loc.id = i;
|
|
||||||
#else
|
|
||||||
int loc = i;
|
|
||||||
#endif // CUDART_VERSION >= 13000
|
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
|
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
|
||||||
}
|
}
|
||||||
@@ -90,9 +97,10 @@ CudaAllocator::CudaAllocator()
|
|||||||
page_size,
|
page_size,
|
||||||
[](CudaBuffer* buf) { return buf->size; },
|
[](CudaBuffer* buf) { return buf->size; },
|
||||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||||
size_t free, total;
|
size_t free;
|
||||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
|
||||||
memory_limit_ = total * 0.9;
|
memory_limit_ = total_memory_ * 0.95;
|
||||||
|
free_limit_ = total_memory_ - memory_limit_;
|
||||||
max_pool_size_ = memory_limit_;
|
max_pool_size_ = memory_limit_;
|
||||||
|
|
||||||
int device_count = 0;
|
int device_count = 0;
|
||||||
@@ -104,6 +112,10 @@ CudaAllocator::CudaAllocator()
|
|||||||
cudaStream_t s;
|
cudaStream_t s;
|
||||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
|
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
|
||||||
free_streams_.push_back(s);
|
free_streams_.push_back(s);
|
||||||
|
|
||||||
|
cudaMemPool_t mem_pool;
|
||||||
|
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
|
||||||
|
mem_pools_.push_back(mem_pool);
|
||||||
}
|
}
|
||||||
CHECK_CUDA_ERROR(cudaSetDevice(curr));
|
CHECK_CUDA_ERROR(cudaSetDevice(curr));
|
||||||
}
|
}
|
||||||
@@ -154,23 +166,35 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
|
|||||||
}
|
}
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
cudaError_t err;
|
|
||||||
void* data = nullptr;
|
void* data = nullptr;
|
||||||
if (device == -1) {
|
if (device == -1) {
|
||||||
err = cudaMallocManaged(&data, size);
|
CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
|
||||||
} else {
|
} else {
|
||||||
err = cudaMallocAsync(&data, size, stream);
|
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
|
||||||
}
|
|
||||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
|
||||||
throw std::runtime_error(fmt::format(
|
|
||||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
|
||||||
}
|
}
|
||||||
if (!data) {
|
if (!data) {
|
||||||
return Buffer{nullptr};
|
std::ostringstream msg;
|
||||||
|
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
buf = new CudaBuffer{data, size, device};
|
buf = new CudaBuffer{data, size, device};
|
||||||
}
|
}
|
||||||
lock.lock();
|
lock.lock();
|
||||||
|
|
||||||
|
// If any cuda memory pool has too much reserved memory, clear some
|
||||||
|
// memory from the cache. This prevents graph / kernel execution failing
|
||||||
|
// from OOM
|
||||||
|
if (get_cache_memory() > 0) {
|
||||||
|
for (auto p : mem_pools_) {
|
||||||
|
size_t used = 0;
|
||||||
|
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
|
||||||
|
p, cudaMemPoolAttrReservedMemCurrent, &used));
|
||||||
|
if (used > (total_memory_ - free_limit_)) {
|
||||||
|
buffer_cache_.release_cached_buffers(free_limit_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
active_memory_ += buf->size;
|
active_memory_ += buf->size;
|
||||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||||
|
|||||||
@@ -71,11 +71,14 @@ class CudaAllocator : public allocator::Allocator {
|
|||||||
|
|
||||||
std::mutex mutex_;
|
std::mutex mutex_;
|
||||||
size_t memory_limit_;
|
size_t memory_limit_;
|
||||||
|
size_t free_limit_;
|
||||||
|
size_t total_memory_;
|
||||||
size_t max_pool_size_;
|
size_t max_pool_size_;
|
||||||
BufferCache<CudaBuffer> buffer_cache_;
|
BufferCache<CudaBuffer> buffer_cache_;
|
||||||
size_t active_memory_{0};
|
size_t active_memory_{0};
|
||||||
size_t peak_memory_{0};
|
size_t peak_memory_{0};
|
||||||
std::vector<cudaStream_t> free_streams_;
|
std::vector<cudaStream_t> free_streams_;
|
||||||
|
std::vector<cudaMemPool_t> mem_pools_;
|
||||||
SmallSizePool scalar_pool_;
|
SmallSizePool scalar_pool_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -95,11 +95,14 @@ void copy_general_input(
|
|||||||
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
||||||
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
int work_per_thread = 1;
|
|
||||||
|
int work_per_thread = 8;
|
||||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||||
auto rest = out.size() / dim0;
|
auto rest = out.size() / dim0;
|
||||||
if (dim0 >= 4) {
|
if (dim0 >= 4 && dim0 < 8) {
|
||||||
work_per_thread = 4;
|
work_per_thread = 4;
|
||||||
|
} else if (dim0 < 4) {
|
||||||
|
work_per_thread = 1;
|
||||||
}
|
}
|
||||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||||
@@ -110,7 +113,10 @@ void copy_general_input(
|
|||||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||||
auto kernel =
|
auto kernel =
|
||||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
||||||
if (work_per_thread == 4) {
|
if (work_per_thread == 8) {
|
||||||
|
kernel =
|
||||||
|
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
|
||||||
|
} else if (work_per_thread == 4) {
|
||||||
kernel =
|
kernel =
|
||||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
||||||
}
|
}
|
||||||
@@ -127,7 +133,9 @@ void copy_general_input(
|
|||||||
});
|
});
|
||||||
} else { // ndim >= 4
|
} else { // ndim >= 4
|
||||||
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
||||||
if (work_per_thread == 4) {
|
if (work_per_thread == 8) {
|
||||||
|
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
|
||||||
|
} else if (work_per_thread == 4) {
|
||||||
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
||||||
}
|
}
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
|
|||||||
@@ -318,46 +318,52 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
|||||||
insert_graph_dependencies(GraphNode{node, "K"});
|
insert_graph_dependencies(GraphNode{node, "K"});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
|
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
|
||||||
// CUDA graphs do not get updated correctly if a kernel node getting updated
|
// Constructs a key representing the nodes of a sub-graph.
|
||||||
// has a different cluster shape than the node it's being updated with.
|
// Also checks if the sub-graph is updatable as CUDA graphs do not get
|
||||||
|
// updated correctly if a kernel node getting updated has a different cluster
|
||||||
|
// shape than the node it's being updated with.
|
||||||
|
std::string key = "(";
|
||||||
size_t num_nodes = 0;
|
size_t num_nodes = 0;
|
||||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
|
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
|
||||||
if (num_nodes == 0) {
|
if (num_nodes == 0) {
|
||||||
return true;
|
return {key + ")", true};
|
||||||
}
|
}
|
||||||
|
bool is_updatable = true;
|
||||||
std::vector<cudaGraphNode_t> nodes(num_nodes);
|
std::vector<cudaGraphNode_t> nodes(num_nodes);
|
||||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
|
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
|
||||||
for (const auto& node : nodes) {
|
for (const auto& node : nodes) {
|
||||||
|
if (!is_updatable) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
cudaGraphNodeType type;
|
cudaGraphNodeType type;
|
||||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||||
if (type == cudaGraphNodeTypeGraph) {
|
if (type == cudaGraphNodeTypeGraph) {
|
||||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||||
if (num_nodes > 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
cudaGraph_t child;
|
cudaGraph_t child;
|
||||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||||
return is_graph_updatable(child, cluster_dim_x);
|
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
|
||||||
|
is_updatable &= sub_is_updatable;
|
||||||
|
key += subkey;
|
||||||
|
} else if (type == cudaGraphNodeTypeMemset) {
|
||||||
|
key += "M";
|
||||||
} else if (type != cudaGraphNodeTypeKernel) {
|
} else if (type != cudaGraphNodeTypeKernel) {
|
||||||
return false;
|
is_updatable = false;
|
||||||
} else {
|
} else {
|
||||||
cudaLaunchAttributeValue cluster_dim;
|
cudaLaunchAttributeValue cluster_dim;
|
||||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||||
// Only dim.x can be greater than 1
|
// Only allow dim.x to be greater than 1
|
||||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||||
return false;
|
is_updatable = false;
|
||||||
|
} else {
|
||||||
|
key += "K";
|
||||||
|
key += std::to_string(cluster_dim.clusterDim.x);
|
||||||
}
|
}
|
||||||
// Only one child node allowed when subgraph uses clusters
|
|
||||||
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
cluster_dim_x = cluster_dim.clusterDim.x;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
key += ")";
|
||||||
|
return {key, is_updatable};
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||||
@@ -370,11 +376,10 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
int cluster_dim_x = 0;
|
auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
|
||||||
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
|
is_graph_updatable_ &= is_updatable;
|
||||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||||
insert_graph_dependencies(
|
insert_graph_dependencies(GraphNode{node, sub_graph_key});
|
||||||
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CommandEncoder::needs_commit() {
|
bool CommandEncoder::needs_commit() {
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ class CommandEncoder {
|
|||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
// K = kernel
|
// K = kernel
|
||||||
// E = empty
|
// E = empty
|
||||||
// G* = subgraph (with metadata)
|
// () = subgraph (with metadata)
|
||||||
// Symbols ':', '-' are reserved as separators
|
// Symbols ':', '-' are reserved as separators
|
||||||
std::string node_type;
|
std::string node_type;
|
||||||
std::string id;
|
std::string id;
|
||||||
|
|||||||
@@ -89,9 +89,13 @@ template <
|
|||||||
int NDIM,
|
int NDIM,
|
||||||
int BM,
|
int BM,
|
||||||
int BN,
|
int BN,
|
||||||
int N_READS = 4>
|
int N_READS = 4,
|
||||||
__global__ void
|
int BLOCKS = 1>
|
||||||
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
__global__ void col_reduce_looped(
|
||||||
|
T* in,
|
||||||
|
U* out,
|
||||||
|
const __grid_constant__ ColReduceArgs args,
|
||||||
|
int64_t out_size) {
|
||||||
auto grid = cg::this_grid();
|
auto grid = cg::this_grid();
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
size_t tile_idx = grid.block_rank();
|
size_t tile_idx = grid.block_rank();
|
||||||
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
|
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
|
||||||
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
|
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
|
||||||
|
size_t tile_out = tile_y / out_size;
|
||||||
|
tile_y = tile_y % out_size;
|
||||||
|
|
||||||
// Compute the indices for the thread within the tile
|
// Compute the indices for the thread within the tile
|
||||||
short thread_x = block.thread_rank() % threads_per_row;
|
short thread_x = block.thread_rank() % threads_per_row;
|
||||||
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
totals[i] = ReduceInit<Op, T>::value();
|
totals[i] = ReduceInit<Op, T>::value();
|
||||||
}
|
}
|
||||||
|
|
||||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
|
||||||
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
|
|
||||||
size_t total = args.non_col_reductions * args.reduction_size;
|
size_t total = args.non_col_reductions * args.reduction_size;
|
||||||
|
size_t per_block, start, end;
|
||||||
|
if constexpr (BLOCKS > 1) {
|
||||||
|
per_block = (total + BLOCKS - 1) / BLOCKS;
|
||||||
|
start = tile_out * per_block + thread_y;
|
||||||
|
end = min((tile_out + 1) * per_block, total);
|
||||||
|
} else {
|
||||||
|
per_block = total;
|
||||||
|
start = thread_y;
|
||||||
|
end = total;
|
||||||
|
}
|
||||||
|
|
||||||
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
|
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
if (tile_x * BN + BN <= args.reduction_stride) {
|
if (tile_x * BN + BN <= args.reduction_stride) {
|
||||||
if (args.reduction_stride % N_READS == 0) {
|
if (args.reduction_stride % N_READS == 0) {
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
for (size_t r = start; r < end; r += BM) {
|
||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
for (size_t r = start; r < end; r += BM) {
|
||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
for (size_t r = start; r < end; r += BM) {
|
||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
cub::LoadDirectBlocked(
|
cub::LoadDirectBlocked(
|
||||||
thread_x,
|
thread_x,
|
||||||
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
|
|
||||||
// Write result.
|
// Write result.
|
||||||
if (warp.thread_rank() == 0) {
|
if (warp.thread_rank() == 0) {
|
||||||
|
if (BLOCKS > 1) {
|
||||||
|
out += tile_out * out_size * args.reduction_stride;
|
||||||
|
}
|
||||||
cub::StoreDirectBlocked(
|
cub::StoreDirectBlocked(
|
||||||
warp.meta_group_rank(),
|
warp.meta_group_rank(),
|
||||||
out + tile_y * args.reduction_stride + tile_x * BN,
|
out + tile_y * args.reduction_stride + tile_x * BN,
|
||||||
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
|
|||||||
inline auto output_grid_for_col_reduce(
|
inline auto output_grid_for_col_reduce(
|
||||||
const array& out,
|
const array& out,
|
||||||
const cu::ColReduceArgs& args,
|
const cu::ColReduceArgs& args,
|
||||||
int bn) {
|
int bn,
|
||||||
|
int outer = 1) {
|
||||||
int gx, gy = 1;
|
int gx, gy = 1;
|
||||||
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
|
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
|
||||||
size_t n_outer_blocks = out.size() / args.reduction_stride;
|
size_t n_outer_blocks = out.size() / args.reduction_stride;
|
||||||
size_t n_blocks = n_outer_blocks * n_inner_blocks;
|
size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
|
||||||
while (n_blocks / gy > INT32_MAX) {
|
while (n_blocks / gy > INT32_MAX) {
|
||||||
gy *= 2;
|
gy *= 2;
|
||||||
}
|
}
|
||||||
@@ -277,7 +298,8 @@ void col_reduce_looped(
|
|||||||
0,
|
0,
|
||||||
indata,
|
indata,
|
||||||
gpu_ptr<U>(out),
|
gpu_ptr<U>(out),
|
||||||
static_cast<cu::ColReduceArgs>(args));
|
static_cast<cu::ColReduceArgs>(args),
|
||||||
|
out.size() / args.reduction_stride);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -320,6 +342,117 @@ void col_reduce_small(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Column reduce for long columns, performed with two kernel launches so that
// several threadblocks can work on different parts of the reduced axis:
//
//   1. col_reduce_looped<..., n_splits> reduces `in` into a temporary array
//      whose leading axis has `n_splits` entries (one partial result per
//      split for every output element).
//   2. col_reduce_looped<...> with the default BLOCKS folds that leading axis
//      into `out`.
//
// `plan` is accepted for signature parity with the other col_reduce_* helpers
// and is not read here.
void col_reduce_two_pass(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan,
    const cu::ColReduceArgs& args) {
  // Give the output the same layout as the input so that accesses stay as
  // contiguous as possible.
  allocate_same_layout(out, in, axes, encoder);

  // Number of partial results kept per output element after the 1st pass.
  constexpr int n_splits = 32;

  // The temporary has shape [n_splits, *out.shape] ...
  Shape partial_shape;
  partial_shape.push_back(n_splits);
  partial_shape.insert(
      partial_shape.end(), out.shape().begin(), out.shape().end());

  // ... and reuses out's strides, with the new leading axis stepping over one
  // whole output.
  Strides partial_strides;
  partial_strides.push_back(out.size());
  partial_strides.insert(
      partial_strides.end(), out.strides().begin(), out.strides().end());

  array partial(partial_shape, out.dtype(), nullptr, {});
  auto [data_size, row_contig, col_contig] =
      check_contiguity(partial_shape, partial_strides);
  auto flags = out.flags();
  flags.row_contiguous = row_contig;
  flags.col_contiguous = col_contig;
  flags.contiguous = true;
  partial.set_data(
      cu::malloc_async(partial.nbytes(), encoder),
      data_size,
      partial_strides,
      flags,
      allocator::free);

  // ---- 1st pass: in -> partial ----
  encoder.add_temporary(partial);
  encoder.set_input_array(in);
  encoder.set_output_array(partial);
  dispatch_all_types(in.dtype(), [&](auto type_tag) {
    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
      dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
        using OP = MLX_GET_TYPE(reduce_type_tag);
        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        using U = typename cu::ReduceResult<OP, T>::type;
        // Cub doesn't like const pointers for vectorized loads. (sigh)
        T* src = const_cast<T*>(gpu_ptr<T>(in));

        constexpr int N_READS = 4;
        constexpr int BM = 32;
        constexpr int BN = 32;
        dim3 grid_dims = output_grid_for_col_reduce(out, args, BN, n_splits);
        int threads = BM * BN / N_READS;
        auto first_pass = cu::col_reduce_looped<
            T,
            U,
            OP,
            reduce_ndim(),
            BM,
            BN,
            N_READS,
            n_splits>;
        encoder.add_kernel_node(
            first_pass,
            grid_dims,
            threads,
            0,
            src,
            gpu_ptr<U>(partial),
            static_cast<cu::ColReduceArgs>(args),
            out.size() / args.reduction_stride);
      });
    });
  });

  // ---- 2nd pass: partial -> out ----
  // Describe the temporary as a single reduced axis of length n_splits whose
  // stride is one whole output.
  cu::ColReduceArgs final_args = args;
  final_args.reduction_size = n_splits;
  final_args.reduction_stride = out.size();
  final_args.ndim = 0;
  final_args.reduce_shape[0] = n_splits;
  final_args.reduce_strides[0] = out.size();
  final_args.reduce_ndim = 1;
  final_args.non_col_reductions = 1;

  encoder.set_input_array(partial);
  encoder.set_output_array(out);
  dispatch_all_types(partial.dtype(), [&](auto type_tag) {
    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
      dispatch_reduce_ndim(final_args.reduce_ndim, [&](auto reduce_ndim) {
        using OP = MLX_GET_TYPE(reduce_type_tag);
        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        using U = typename cu::ReduceResult<OP, T>::type;

        constexpr int N_READS = 4;
        constexpr int BM = 32;
        constexpr int BN = 32;
        dim3 grid_dims = output_grid_for_col_reduce(out, final_args, BN);
        int threads = BM * BN / N_READS;
        auto second_pass =
            cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
        encoder.add_kernel_node(
            second_pass,
            grid_dims,
            threads,
            0,
            gpu_ptr<T>(partial),
            gpu_ptr<U>(out),
            final_args,
            final_args.reduction_stride);
      });
    });
  });
}
|
||||||
|
|
||||||
void col_reduce(
|
void col_reduce(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
const array& in,
|
const array& in,
|
||||||
@@ -334,6 +467,18 @@ void col_reduce(
|
|||||||
// It is a general strided reduce. Each threadblock computes the output for
|
// It is a general strided reduce. Each threadblock computes the output for
|
||||||
// a subrow of the fast moving axis. For instance 32 elements.
|
// a subrow of the fast moving axis. For instance 32 elements.
|
||||||
//
|
//
|
||||||
|
// - col_reduce_small
|
||||||
|
//
|
||||||
|
// It is a column reduce for small columns. Each thread loops over the whole
|
||||||
|
// column without communicating with any other thread.
|
||||||
|
//
|
||||||
|
// - col_reduce_two_pass
|
||||||
|
//
|
||||||
|
// It is a reduce for long columns. To increase parallelism, we split the
|
||||||
|
// reduction in two passes. First we do a column reduce where many
|
||||||
|
// threadblocks operate on different parts of the reduced axis. Then we
|
||||||
|
// perform a final column reduce.
|
||||||
|
//
|
||||||
// Notes: As in row reduce we opt to read as much in order as possible and
|
// Notes: As in row reduce we opt to read as much in order as possible and
|
||||||
// leave transpositions as they are (contrary to our Metal backend).
|
// leave transpositions as they are (contrary to our Metal backend).
|
||||||
//
|
//
|
||||||
@@ -349,6 +494,14 @@ void col_reduce(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Long column with smallish row
|
||||||
|
size_t total_sums = args.non_col_reductions * args.reduction_size;
|
||||||
|
size_t approx_threads = out.size();
|
||||||
|
if (total_sums / approx_threads > 32) {
|
||||||
|
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Fallback col reduce
|
// Fallback col reduce
|
||||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CudaGraph::end_capture(cudaStream_t stream) {
|
void CudaGraph::end_capture(cudaStream_t stream) {
|
||||||
assert(handle_ == nullptr);
|
|
||||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
|
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,6 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
|
|
||||||
|
|
||||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -149,7 +149,9 @@ Buffer MetalAllocator::malloc(size_t size) {
|
|||||||
buf = device_->newBuffer(size, resource_options);
|
buf = device_->newBuffer(size, resource_options);
|
||||||
}
|
}
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
return Buffer{nullptr};
|
std::ostringstream msg;
|
||||||
|
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
lk.lock();
|
lk.lock();
|
||||||
num_resources_++;
|
num_resources_++;
|
||||||
@@ -201,6 +203,32 @@ size_t MetalAllocator::size(Buffer buffer) const {
|
|||||||
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
|
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wraps caller-provided memory [ptr, ptr + size) in a Metal buffer created
// with the allocator's resource options and no deallocator callback.
// Returns an empty Buffer when the device cannot create the buffer.
// On success the buffer is added to the residency set and counted in the
// active/peak memory and resource statistics.
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
  auto wrapped = device_->newBuffer(ptr, size, resource_options, nullptr);
  if (wrapped == nullptr) {
    return Buffer{nullptr};
  }

  // Bookkeeping is shared with malloc/free, so it happens under the mutex.
  std::unique_lock guard(mutex_);
  residency_set_.insert(wrapped);
  active_memory_ += wrapped->length();
  peak_memory_ = std::max(peak_memory_, active_memory_);
  ++num_resources_;
  return Buffer{static_cast<void*>(wrapped)};
}
|
||||||
|
|
||||||
|
// Releases a Metal buffer and removes it from the active-memory and resource
// counters. A null buffer is ignored.
//
// NOTE(review): make_buffer() inserts the buffer into residency_set_, but
// nothing here removes it again - confirm whether an erase is needed before
// the buffer is released.
void MetalAllocator::release(Buffer buffer) {
  auto* mtl_buf = static_cast<MTL::Buffer*>(buffer.ptr());
  if (!mtl_buf) {
    return;
  }

  {
    // Only the counters need the lock; drop it before talking to Metal.
    std::unique_lock guard(mutex_);
    active_memory_ -= mtl_buf->length();
    --num_resources_;
  }

  auto pool = metal::new_scoped_memory_pool();
  mtl_buf->release();
}
|
||||||
|
|
||||||
MetalAllocator& allocator() {
|
MetalAllocator& allocator() {
|
||||||
// By creating the |allocator_| on heap, the destructor of MetalAllocator
|
// By creating the |allocator_| on heap, the destructor of MetalAllocator
|
||||||
// will not be called on exit and buffers in the cache will be leaked. This
|
// will not be called on exit and buffers in the cache will be leaked. This
|
||||||
|
|||||||
@@ -21,6 +21,9 @@ class MetalAllocator : public allocator::Allocator {
|
|||||||
virtual Buffer malloc(size_t size) override;
|
virtual Buffer malloc(size_t size) override;
|
||||||
virtual void free(Buffer buffer) override;
|
virtual void free(Buffer buffer) override;
|
||||||
virtual size_t size(Buffer buffer) const override;
|
virtual size_t size(Buffer buffer) const override;
|
||||||
|
virtual Buffer make_buffer(void* ptr, size_t size) override;
|
||||||
|
virtual void release(Buffer buffer) override;
|
||||||
|
|
||||||
size_t get_active_memory() {
|
size_t get_active_memory() {
|
||||||
return active_memory_;
|
return active_memory_;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ class CommonAllocator : public Allocator {
|
|||||||
virtual Buffer malloc(size_t size) override;
|
virtual Buffer malloc(size_t size) override;
|
||||||
virtual void free(Buffer buffer) override;
|
virtual void free(Buffer buffer) override;
|
||||||
virtual size_t size(Buffer buffer) const override;
|
virtual size_t size(Buffer buffer) const override;
|
||||||
|
|
||||||
size_t get_active_memory() const {
|
size_t get_active_memory() const {
|
||||||
return active_memory_;
|
return active_memory_;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -4,6 +4,11 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
|
||||||
|
|
||||||
|
if(MLX_BUILD_CPU AND NOT WIN32)
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "mlx/backend/cuda/cuda.h"
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/distributed/jaccl/jaccl.h"
|
||||||
#include "mlx/distributed/mpi/mpi.h"
|
#include "mlx/distributed/mpi/mpi.h"
|
||||||
#include "mlx/distributed/nccl/nccl.h"
|
#include "mlx/distributed/nccl/nccl.h"
|
||||||
#include "mlx/distributed/ring/ring.h"
|
#include "mlx/distributed/ring/ring.h"
|
||||||
@@ -102,7 +103,27 @@ class EmptyGroup : public GroupImpl {
|
|||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return mpi::is_available() || ring::is_available() || nccl::is_available();
|
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
|
||||||
|
jaccl::is_available();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_available(const std::string& bk) {
|
||||||
|
if (bk == "any") {
|
||||||
|
return is_available();
|
||||||
|
}
|
||||||
|
if (bk == "mpi") {
|
||||||
|
return mpi::is_available();
|
||||||
|
}
|
||||||
|
if (bk == "ring") {
|
||||||
|
return ring::is_available();
|
||||||
|
}
|
||||||
|
if (bk == "nccl") {
|
||||||
|
return nccl::is_available();
|
||||||
|
}
|
||||||
|
if (bk == "jaccl") {
|
||||||
|
return jaccl::is_available();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
int Group::rank() const {
|
int Group::rank() const {
|
||||||
@@ -135,6 +156,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = ring::init(strict);
|
group = ring::init(strict);
|
||||||
} else if (bk == "nccl") {
|
} else if (bk == "nccl") {
|
||||||
group = nccl::init(strict);
|
group = nccl::init(strict);
|
||||||
|
} else if (bk == "jaccl") {
|
||||||
|
group = jaccl::init(strict);
|
||||||
} else if (bk == "any") {
|
} else if (bk == "any") {
|
||||||
if (mlx::core::cu::is_available()) {
|
if (mlx::core::cu::is_available()) {
|
||||||
group = nccl::init(false);
|
group = nccl::init(false);
|
||||||
@@ -148,13 +171,17 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = mpi::init(false);
|
group = mpi::init(false);
|
||||||
bk_ = "mpi";
|
bk_ = "mpi";
|
||||||
}
|
}
|
||||||
|
if (group == nullptr) {
|
||||||
|
group = jaccl::init(false);
|
||||||
|
bk_ = "jaccl";
|
||||||
|
}
|
||||||
if (group == nullptr && strict) {
|
if (group == nullptr && strict) {
|
||||||
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
|
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
|
||||||
<< "and 'ring' but '" << bk << "' was provided.";
|
<< "'jaccl' and 'ring' but '" << bk << "' was provided.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ class GroupImpl;
|
|||||||
|
|
||||||
/* Check if a communication backend is available */
|
/* Check if a communication backend is available */
|
||||||
bool is_available();
|
bool is_available();
|
||||||
|
bool is_available(const std::string& bk);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A distributed::Group represents a group of independent mlx processes that
|
* A distributed::Group represents a group of independent mlx processes that
|
||||||
|
|||||||
8
mlx/distributed/jaccl/CMakeLists.txt
Normal file
8
mlx/distributed/jaccl/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# Build the real jaccl backend (linked against `rdma`) only for CPU builds on
# Darwin with a new enough macOS SDK; every other configuration compiles the
# stub in no_jaccl.cpp instead.
#
# The SDK check uses VERSION_GREATER_EQUAL so that the value is compared
# component-wise as a version: a plain GREATER_EQUAL compares real numbers,
# which orders 26.10 before 26.2 and rejects values such as 26.2.1.
if(MLX_BUILD_CPU
   AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
   AND MACOS_SDK_VERSION VERSION_GREATER_EQUAL 26.2)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
  target_link_libraries(mlx PRIVATE rdma)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
endif()
|
||||||
1123
mlx/distributed/jaccl/jaccl.cpp
Normal file
1123
mlx/distributed/jaccl/jaccl.cpp
Normal file
File diff suppressed because it is too large
Load Diff
12
mlx/distributed/jaccl/jaccl.h
Normal file
12
mlx/distributed/jaccl/jaccl.h
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.

// Guard against multiple inclusion: `init` carries a default argument, which
// may not be specified again in a second declaration in the same scope.
#pragma once

#include <memory>

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed::jaccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

// Whether the jaccl backend can be used in this build/environment.
bool is_available();

// Initializes the jaccl backend and returns its group implementation.
// `strict` selects how failure is reported; see the definitions in jaccl.cpp
// and no_jaccl.cpp for the exact behavior.
std::shared_ptr<GroupImpl> init(bool strict = false);

} // namespace mlx::core::distributed::jaccl
|
||||||
20
mlx/distributed/jaccl/no_jaccl.cpp
Normal file
20
mlx/distributed/jaccl/no_jaccl.cpp
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.

// Stub implementation of the jaccl backend, compiled when the real backend
// is not built.

#include <memory>
#include <stdexcept>

#include "mlx/distributed/jaccl/jaccl.h"

namespace mlx::core::distributed::jaccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

// The stub backend is never available.
bool is_available() {
  return false;
}

// Always fails: throws std::runtime_error when `strict` is set, otherwise
// returns a null group.
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  if (strict) {
    throw std::runtime_error("Cannot initialize jaccl distributed backend.");
  }
  return nullptr;
}

} // namespace mlx::core::distributed::jaccl
|
||||||
38
mlx/distributed/reduction_ops.h
Normal file
38
mlx/distributed/reduction_ops.h
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.

// Element-wise accumulation functors. The header defines class templates, so
// it must be guarded against multiple inclusion, and it includes what it uses
// (std::max / std::min / size_t) so that it compiles on its own.
#pragma once

#include <algorithm>
#include <cstddef>

namespace mlx::core::distributed::detail {

// output[i] += input[i] for i in [0, N).
template <typename T>
struct SumOp {
  void operator()(const T* input, T* output, size_t N) const {
    for (size_t i = 0; i < N; ++i) {
      output[i] += input[i];
    }
  }
};

// output[i] = max(output[i], input[i]) for i in [0, N).
template <typename T>
struct MaxOp {
  void operator()(const T* input, T* output, size_t N) const {
    for (size_t i = 0; i < N; ++i) {
      output[i] = std::max(output[i], input[i]);
    }
  }
};

// output[i] = min(output[i], input[i]) for i in [0, N).
template <typename T>
struct MinOp {
  void operator()(const T* input, T* output, size_t N) const {
    for (size_t i = 0; i < N; ++i) {
      output[i] = std::min(output[i], input[i]);
    }
  }
};

} // namespace mlx::core::distributed::detail
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include <arpa/inet.h>
|
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <netdb.h>
|
|
||||||
#include <netinet/in.h>
|
|
||||||
#include <netinet/tcp.h>
|
#include <netinet/tcp.h>
|
||||||
#include <sys/socket.h>
|
#include <sys/socket.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
@@ -22,6 +19,8 @@
|
|||||||
#include "mlx/backend/cpu/encoder.h"
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/distributed/reduction_ops.h"
|
||||||
|
#include "mlx/distributed/utils.h"
|
||||||
#include "mlx/threadpool.h"
|
#include "mlx/threadpool.h"
|
||||||
|
|
||||||
#ifndef SOL_TCP
|
#ifndef SOL_TCP
|
||||||
@@ -94,6 +93,7 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
|
|||||||
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
||||||
constexpr const int CONN_ATTEMPTS = 5;
|
constexpr const int CONN_ATTEMPTS = 5;
|
||||||
constexpr const int CONN_WAIT = 1000;
|
constexpr const int CONN_WAIT = 1000;
|
||||||
|
constexpr const char* RING_TAG = "[ring]";
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
@@ -296,55 +296,6 @@ class CommunicationThreads {
|
|||||||
std::unordered_map<int, SocketThread> threads_;
|
std::unordered_map<int, SocketThread> threads_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct address_t {
|
|
||||||
sockaddr_storage addr;
|
|
||||||
socklen_t len;
|
|
||||||
|
|
||||||
const sockaddr* get() const {
|
|
||||||
return (struct sockaddr*)&addr;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a sockaddr from an ip and port provided as strings.
|
|
||||||
*/
|
|
||||||
address_t parse_address(const std::string& ip, const std::string& port) {
|
|
||||||
struct addrinfo hints, *res;
|
|
||||||
memset(&hints, 0, sizeof(hints));
|
|
||||||
hints.ai_family = AF_UNSPEC;
|
|
||||||
hints.ai_socktype = SOCK_STREAM;
|
|
||||||
|
|
||||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
|
||||||
if (status != 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Can't parse address " << ip << ":" << port;
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
address_t result;
|
|
||||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
|
||||||
result.len = res->ai_addrlen;
|
|
||||||
freeaddrinfo(res);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
|
||||||
*/
|
|
||||||
address_t parse_address(const std::string& ip_port) {
|
|
||||||
auto colon = ip_port.find(":");
|
|
||||||
if (colon == std::string::npos) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Can't parse address " << ip_port;
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
|
||||||
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
|
||||||
|
|
||||||
return parse_address(ip, port);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load all addresses from the json hostfile. The hostfile is a list of
|
* Load all addresses from the json hostfile. The hostfile is a list of
|
||||||
* addresses in order of rank. For each rank there can be many addresses so
|
* addresses in order of rank. For each rank there can be many addresses so
|
||||||
@@ -357,15 +308,15 @@ address_t parse_address(const std::string& ip_port) {
|
|||||||
* ["ip3:5000", "ip3:5001"],
|
* ["ip3:5000", "ip3:5001"],
|
||||||
* ]
|
* ]
|
||||||
*/
|
*/
|
||||||
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
|
||||||
std::vector<std::vector<address_t>> nodes;
|
std::vector<std::vector<detail::address_t>> nodes;
|
||||||
std::ifstream f(hostfile);
|
std::ifstream f(hostfile);
|
||||||
|
|
||||||
json hosts = json::parse(f);
|
json hosts = json::parse(f);
|
||||||
for (auto& h : hosts) {
|
for (auto& h : hosts) {
|
||||||
std::vector<address_t> host;
|
std::vector<detail::address_t> host;
|
||||||
for (auto& ips : h) {
|
for (auto& ips : h) {
|
||||||
host.push_back(parse_address(ips.get<std::string>()));
|
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
|
||||||
}
|
}
|
||||||
nodes.push_back(std::move(host));
|
nodes.push_back(std::move(host));
|
||||||
}
|
}
|
||||||
@@ -377,73 +328,15 @@ std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
|||||||
* Create a socket and accept one connection for each of the provided
|
* Create a socket and accept one connection for each of the provided
|
||||||
* addresses.
|
* addresses.
|
||||||
*/
|
*/
|
||||||
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
std::vector<int> accept_connections(
|
||||||
|
const std::vector<detail::address_t>& addresses) {
|
||||||
std::vector<int> sockets;
|
std::vector<int> sockets;
|
||||||
int success;
|
int success;
|
||||||
|
|
||||||
for (auto& address : addresses) {
|
for (auto& address : addresses) {
|
||||||
// Create the socket to wait for connections from the peers
|
detail::TCPSocket socket(RING_TAG);
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
socket.listen(RING_TAG, address);
|
||||||
if (sock < 0) {
|
sockets.push_back(socket.accept(RING_TAG).detach());
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure we can launch immediately after shutdown by setting the
|
|
||||||
// reuseaddr option so that we don't get address already in use errors
|
|
||||||
int enable = 1;
|
|
||||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bind the socket to the address and port
|
|
||||||
success = bind(sock, address.get(), address.len);
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for connections
|
|
||||||
success = listen(sock, 0);
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't listen (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
int peer_socket = accept(sock, nullptr, nullptr);
|
|
||||||
if (peer_socket < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Accept failed (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the listening socket
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
|
|
||||||
sockets.push_back(peer_socket);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sockets;
|
return sockets;
|
||||||
@@ -454,93 +347,42 @@ std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
|||||||
* provided addresses.
|
* provided addresses.
|
||||||
*/
|
*/
|
||||||
std::vector<int> make_connections(
|
std::vector<int> make_connections(
|
||||||
const std::vector<address_t>& addresses,
|
const std::vector<detail::address_t>& addresses,
|
||||||
bool verbose) {
|
bool verbose) {
|
||||||
std::vector<int> sockets;
|
std::vector<int> sockets;
|
||||||
int success;
|
int success;
|
||||||
|
|
||||||
for (auto& address : addresses) {
|
for (auto& address : addresses) {
|
||||||
int sock;
|
sockets.push_back(detail::TCPSocket::connect(
|
||||||
|
RING_TAG,
|
||||||
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
|
address,
|
||||||
// backoff. TODO: Do we need that?
|
CONN_ATTEMPTS,
|
||||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
CONN_WAIT,
|
||||||
// Create the socket
|
[verbose](int attempt, int wait) {
|
||||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
log_info(
|
||||||
if (sock < 0) {
|
verbose,
|
||||||
std::ostringstream msg;
|
"Attempt",
|
||||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
attempt,
|
||||||
throw std::runtime_error(msg.str());
|
"waiting",
|
||||||
}
|
wait,
|
||||||
|
"ms (error:",
|
||||||
if (attempt > 0) {
|
errno,
|
||||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
")");
|
||||||
log_info(
|
})
|
||||||
verbose,
|
.detach());
|
||||||
"Attempt",
|
|
||||||
attempt,
|
|
||||||
"wait",
|
|
||||||
wait,
|
|
||||||
"ms (error:",
|
|
||||||
errno,
|
|
||||||
")");
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
|
||||||
}
|
|
||||||
|
|
||||||
success = connect(sock, address.get(), address.len);
|
|
||||||
if (success == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (success < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't connect (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
sockets.push_back(sock);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sockets;
|
return sockets;
|
||||||
}
|
}
|
||||||
template <typename T>
|
|
||||||
struct SumOp {
|
|
||||||
void operator()(const T* input, T* output, size_t N) {
|
|
||||||
while (N-- > 0) {
|
|
||||||
*output += *input;
|
|
||||||
input++;
|
|
||||||
output++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct MaxOp {
|
|
||||||
void operator()(const T* input, T* output, size_t N) {
|
|
||||||
while (N-- > 0) {
|
|
||||||
*output = std::max(*output, *input);
|
|
||||||
input++;
|
|
||||||
output++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct MinOp {
|
|
||||||
void operator()(const T* input, T* output, size_t N) {
|
|
||||||
while (N-- > 0) {
|
|
||||||
*output = std::min(*output, *input);
|
|
||||||
input++;
|
|
||||||
output++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class RingGroup : public GroupImpl {
|
class RingGroup : public GroupImpl {
|
||||||
public:
|
public:
|
||||||
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
|
RingGroup(
|
||||||
|
int rank,
|
||||||
|
std::vector<std::vector<detail::address_t>> nodes,
|
||||||
|
bool verbose)
|
||||||
: rank_(rank), verbose_(verbose), pool_(0) {
|
: rank_(rank), verbose_(verbose), pool_(0) {
|
||||||
if (rank_ > 0 && rank_ >= nodes.size()) {
|
if (rank_ > 0 && rank_ >= nodes.size()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@@ -633,17 +475,17 @@ class RingGroup : public GroupImpl {
|
|||||||
|
|
||||||
void all_sum(const array& input, array& output, Stream stream) override {
|
void all_sum(const array& input, array& output, Stream stream) override {
|
||||||
SWITCH_TYPE(
|
SWITCH_TYPE(
|
||||||
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
|
output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void all_max(const array& input, array& output, Stream stream) override {
|
void all_max(const array& input, array& output, Stream stream) override {
|
||||||
SWITCH_TYPE(
|
SWITCH_TYPE(
|
||||||
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
|
output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void all_min(const array& input, array& output, Stream stream) override {
|
void all_min(const array& input, array& output, Stream stream) override {
|
||||||
SWITCH_TYPE(
|
SWITCH_TYPE(
|
||||||
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
|
output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
|
|||||||
204
mlx/distributed/utils.cpp
Normal file
204
mlx/distributed/utils.cpp
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <netdb.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <cstring>
|
||||||
|
#include <sstream>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "mlx/distributed/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::detail {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr from an ip and port provided as strings.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip, const std::string& port) {
|
||||||
|
struct addrinfo hints, *res;
|
||||||
|
std::memset(&hints, 0, sizeof(hints));
|
||||||
|
hints.ai_family = AF_UNSPEC;
|
||||||
|
hints.ai_socktype = SOCK_STREAM;
|
||||||
|
|
||||||
|
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||||
|
if (status != 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Can't parse address " << ip << ":" << port;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
address_t result;
|
||||||
|
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||||
|
result.len = res->ai_addrlen;
|
||||||
|
freeaddrinfo(res);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip_port) {
|
||||||
|
auto colon = ip_port.find(":");
|
||||||
|
if (colon == std::string::npos) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Can't parse address " << ip_port;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
||||||
|
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
||||||
|
|
||||||
|
return parse_address(ip, port);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Open a new AF_INET / SOCK_STREAM socket.
 *
 * `tag` is used as the prefix of the error message. Throws
 * std::runtime_error if the socket cannot be created.
 */
TCPSocket::TCPSocket(const char* tag)
    : sock_(socket(AF_INET, SOCK_STREAM, 0)) {
  if (sock_ >= 0) {
    return;
  }

  std::ostringstream msg;
  msg << tag << " Couldn't create socket (error: " << errno << ")";
  throw std::runtime_error(msg.str());
}
|
||||||
|
|
||||||
|
/**
 * Move constructor: take over the descriptor held by `s` and leave `s`
 * empty (-1) so that its destructor does not close it.
 */
TCPSocket::TCPSocket(TCPSocket&& s) : sock_(s.sock_) {
  s.sock_ = -1;
}
|
||||||
|
|
||||||
|
/**
 * Move assignment: release the descriptor currently held (if any), then take
 * over the one held by `s` and leave `s` empty (-1).
 *
 * Self-assignment is a no-op.
 */
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
  if (this != &s) {
    // Close what we own first; simply overwriting sock_ would leak it.
    if (sock_ >= 0) {
      shutdown(sock_, 2);
      close(sock_);
    }
    sock_ = s.sock_;
    s.sock_ = -1;
  }
  return *this;
}
|
||||||
|
|
||||||
|
/**
 * Wrap an already-open descriptor; the new object becomes responsible for it.
 */
TCPSocket::TCPSocket(int s) {
  sock_ = s;
}
|
||||||
|
|
||||||
|
/**
 * Shut down and close the descriptor if one is still held.
 *
 * An empty socket is marked with -1 (see detach and the move operations), so
 * every non-negative value, including 0, is a descriptor that must be closed.
 */
TCPSocket::~TCPSocket() {
  if (sock_ >= 0) {
    shutdown(sock_, 2);
    close(sock_);
  }
}
|
||||||
|
|
||||||
|
/**
 * Give up the descriptor without closing it.
 *
 * Returns the descriptor and marks this object as empty (-1), so the
 * destructor will leave the descriptor alone.
 */
int TCPSocket::detach() {
  const int released = sock_;
  sock_ = -1;
  return released;
}
|
||||||
|
|
||||||
|
/**
 * Bind this socket to `addr` and put it in the listening state.
 *
 * SO_REUSEADDR and SO_REUSEPORT are enabled before binding so that a process
 * launched right after a shutdown does not hit "address already in use".
 *
 * `tag` prefixes every error message. Throws std::runtime_error if any of
 * the setsockopt, bind or listen calls fails.
 */
void TCPSocket::listen(const char* tag, const address_t& addr) {
  // All failures are reported as "<tag><what><errno>)".
  auto fail = [tag](const char* what) {
    std::ostringstream msg;
    msg << tag << what << errno << ")";
    throw std::runtime_error(msg.str());
  };

  int enable = 1;
  if (setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)) < 0) {
    fail(" Couldn't enable reuseaddr (error: ");
  }
  if (setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int)) < 0) {
    fail(" Couldn't enable reuseport (error: ");
  }

  // Bind the socket to the address and port
  if (bind(sock_, addr.get(), addr.len) < 0) {
    fail(" Couldn't bind socket (error: ");
  }

  // Prepare waiting for connections
  if (::listen(sock_, 0) < 0) {
    fail(" Couldn't listen (error: ");
  }
}
|
||||||
|
|
||||||
|
/**
 * Accept one incoming connection on this socket.
 *
 * Returns a TCPSocket owning the peer descriptor. `tag` prefixes the error
 * message. Throws std::runtime_error if ::accept fails.
 */
TCPSocket TCPSocket::accept(const char* tag) {
  const int peer = ::accept(sock_, nullptr, nullptr);
  if (peer >= 0) {
    return TCPSocket(peer);
  }

  std::ostringstream msg;
  msg << tag << " Accept failed (error: " << errno << ")";
  throw std::runtime_error(msg.str());
}
|
||||||
|
|
||||||
|
void TCPSocket::send(const char* tag, const void* data, size_t len) {
|
||||||
|
while (len > 0) {
|
||||||
|
auto n = ::send(sock_, data, len, 0);
|
||||||
|
if (n <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Send failed with errno=" << errno;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
len -= n;
|
||||||
|
data = static_cast<const char*>(data) + n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TCPSocket::recv(const char* tag, void* data, size_t len) {
|
||||||
|
while (len > 0) {
|
||||||
|
auto n = ::recv(sock_, data, len, 0);
|
||||||
|
if (n <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Recv failed with errno=" << errno;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
len -= n;
|
||||||
|
data = static_cast<char*>(data) + n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Connect to `addr`, retrying with exponential backoff.
 *
 * @param tag          Prefix for error messages.
 * @param addr         Address to connect to.
 * @param num_retries  Maximum number of connection attempts.
 * @param wait         Milliseconds to sleep after the first failed attempt;
 *                     doubled after every failed attempt.
 * @param cb           Optional callback invoked after each failed attempt
 *                     with (attempt index, wait in ms). errno holds the
 *                     error of the failed ::connect while it runs. May be
 *                     empty.
 * @return A TCPSocket owning the connected descriptor.
 *
 * Throws std::runtime_error if a socket cannot be created or if no attempt
 * succeeds (including when `num_retries` is not positive).
 */
TCPSocket TCPSocket::connect(
    const char* tag,
    const address_t& addr,
    int num_retries,
    int wait,
    std::function<void(int, int)> cb) {
  int last_errno = 0;

  for (int attempt = 0; attempt < num_retries; attempt++) {
    // A fresh socket is used for every attempt.
    int sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock < 0) {
      std::ostringstream msg;
      msg << tag << " Couldn't create socket to connect (error: " << errno
          << ")";
      throw std::runtime_error(msg.str());
    }

    if (::connect(sock, addr.get(), addr.len) == 0) {
      return TCPSocket(sock);
    }

    // Release the failed socket so retries do not leak descriptors. close()
    // may overwrite errno, so keep the connect error and restore it for the
    // callback and for the final error message.
    last_errno = errno;
    close(sock);
    errno = last_errno;

    // The default argument is an empty std::function; calling it would throw
    // std::bad_function_call.
    if (cb) {
      cb(attempt, wait);
    }
    if (wait > 0) {
      std::this_thread::sleep_for(std::chrono::milliseconds(wait));
    }

    wait <<= 1;
  }

  std::ostringstream msg;
  msg << tag << " Couldn't connect (error: " << last_errno << ")";
  throw std::runtime_error(msg.str());
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::detail
|
||||||
67
mlx/distributed/utils.h
Normal file
67
mlx/distributed/utils.h
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <functional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::detail {
|
||||||
|
|
||||||
|
/**
 * A socket address together with its length, stored in a sockaddr_storage.
 */
struct address_t {
  sockaddr_storage addr;
  socklen_t len;

  // View of the stored address as the generic sockaddr pointer taken by the
  // socket calls.
  const sockaddr* get() const {
    return reinterpret_cast<const sockaddr*>(&addr);
  }
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr from an ip and port provided as strings.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip, const std::string& port);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip_port);
|
||||||
|
|
||||||
|
/**
 * Small wrapper over a TCP socket to simplify initiating connections.
 *
 * The object owns a single socket descriptor. It is movable but not
 * copyable; an object that holds no descriptor stores -1.
 *
 * Every operation that can fail takes a `tag` that is used as the prefix of
 * the std::runtime_error message thrown on failure.
 */
class TCPSocket {
 public:
  // Create a new AF_INET stream socket.
  TCPSocket(const char* tag);
  TCPSocket(const TCPSocket&) = delete;
  TCPSocket& operator=(const TCPSocket&) = delete;
  // Moving transfers the descriptor and leaves the source holding -1.
  TCPSocket(TCPSocket&& s);
  TCPSocket& operator=(TCPSocket&&);
  // Shuts down and closes the descriptor if one is held.
  ~TCPSocket();

  // Enable address/port reuse, bind to `addr` and start listening.
  void listen(const char* tag, const address_t& addr);
  // Accept one connection and return it as a new TCPSocket.
  TCPSocket accept(const char* tag);

  // Send / receive exactly `len` bytes, looping over partial transfers.
  void send(const char* tag, const void* data, size_t len);
  void recv(const char* tag, void* data, size_t len);

  // Return the descriptor and stop owning it (it will not be closed by the
  // destructor).
  int detach();

  // Implicit access to the raw descriptor; ownership stays with the object.
  operator int() const {
    return sock_;
  }

  // Connect to `addr`, making up to `num_retries` attempts. After a failed
  // attempt `cb` is called with (attempt index, wait), the call sleeps for
  // `wait` milliseconds when wait > 0, and `wait` is doubled.
  static TCPSocket connect(
      const char* tag,
      const address_t& addr,
      int num_retries = 1,
      int wait = 0,
      std::function<void(int, int)> cb = nullptr);

 private:
  // Wrap an existing descriptor (used by accept and connect).
  TCPSocket(int sock);

  int sock_;
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::detail
|
||||||
95
python/mlx/_distributed_utils/common.py
Normal file
95
python/mlx/_distributed_utils/common.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import ipaddress
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Host:
    # Rank given to this entry by the parser that created it.
    rank: int
    # Name passed to `ssh` to reach the machine.
    ssh_hostname: str
    # IPs to communicate over (see the hostfile format in parse_hostfile).
    ips: list[str]
    # RDMA device names as listed under "rdma" in the hostfile; entries may
    # be None. Presumably one entry per peer host -- confirm with the caller.
    rdma: list[Optional[str]]
|
||||||
|
|
||||||
|
|
||||||
|
class OptionalBoolAction(argparse.Action):
    """Argparse action that stores a boolean depending on the flag spelling.

    An option string starting with ``--no-`` stores False in the destination,
    any other option string stores True.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        negated = option_string.startswith("--no-")
        setattr(namespace, self.dest, not negated)
|
||||||
|
|
||||||
|
|
||||||
|
def positive_number(x):
    """Convert ``x`` to an int and require it to be strictly positive.

    Raises:
        ValueError: if the conversion fails or the value is <= 0.
    """
    value = int(x)
    if value > 0:
        return value
    raise ValueError("Number should be positive")
|
||||||
|
|
||||||
|
|
||||||
|
def log(verbose, *args, **kwargs):
    """Print an ``[INFO]`` message in green to stderr when ``verbose`` is set.

    ``args`` and ``kwargs`` are forwarded to ``print``; any ``file`` keyword
    is overridden with ``sys.stderr``.
    """
    if verbose:
        kwargs["file"] = sys.stderr
        print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def log_warning(*args, **kwargs):
    """Print a ``[WARN]`` message in yellow to stderr.

    ``args`` and ``kwargs`` are forwarded to ``print``; any ``file`` keyword
    is overridden with ``sys.stderr``.
    """
    print("\033[33m[WARN]", *args, "\033[0m", **{**kwargs, "file": sys.stderr})
|
||||||
|
|
||||||
|
|
||||||
|
def log_error(*args, **kwargs):
    """Print an ``[ERROR]`` message in red to stderr.

    ``args`` and ``kwargs`` are forwarded to ``print``; any ``file`` keyword
    is overridden with ``sys.stderr``.
    """
    print("\033[31m[ERROR]", *args, "\033[0m", **{**kwargs, "file": sys.stderr})
|
||||||
|
|
||||||
|
|
||||||
|
def parse_hostlist(parser, hostlist, repeats):
    """Turn a comma separated list of hostnames into a list of ``Host``.

    Every hostname produces ``repeats`` consecutive entries. Ranks are
    assigned sequentially over the whole result (0, 1, 2, ...), so each entry
    gets a distinct rank. A hostname that is a valid IP address is also
    recorded as the entry's only ip; otherwise the ip list starts empty.

    Args:
        parser: Unused; kept so the signature matches ``parse_hostfile``.
        hostlist (str): Comma separated hostnames or IP addresses.
        repeats (int): Number of entries to create per hostname.

    Raises:
        ValueError: if one of the hostnames is empty.
    """
    hosts = []
    for h in hostlist.split(","):
        if h == "":
            raise ValueError("Hostname cannot be empty")
        try:
            ipaddress.ip_address(h)
            ips = [h]
        except ValueError:
            ips = []
        for _ in range(repeats):
            # The rank is the position in the final list. Each entry gets its
            # own copy of the ip list so that appending to one host's ips
            # does not silently change its repeats.
            hosts.append(Host(len(hosts), h, list(ips), []))
    return hosts
|
||||||
|
|
||||||
|
|
||||||
|
def parse_hostfile(parser, hostfile):
    """Read the json hostfile describing the machines to use.

    Each entry gives the hostname to ssh into and, optionally, the ips to
    communicate over when using the ring backend and the rdma devices.

    Example:

        [
            {"ssh": "hostname1", "ips": ["123.123.123.1"], "rdma": [null, "rdma_en2", "rdma_en3"]},
            {"ssh": "hostname2", "ips": ["123.123.123.2"], "rdma": ["rdma_en2", null, "rdma_en3"]},
            ...
            {"ssh": "hostnameN", "ips": ["123.123.123.N"], "rdma": ["rdma_en2", "rdma_en3", null]},
        ]

    Missing "ips" or "rdma" keys default to empty lists; the rank of a host
    is its position in the file.

    Args:
        parser: Object whose ``error`` method is used to report problems.
        hostfile (str): The path to the json file containing the host
            information

    Returns:
        list[Host]: one entry per element of the json array.
    """
    hostfile = Path(hostfile)
    if not hostfile.exists():
        parser.error(f"Hostfile {str(hostfile)} doesn't exist")

    try:
        with open(hostfile) as f:
            entries = json.load(f)
            return [
                Host(rank, entry["ssh"], entry.get("ips", []), entry.get("rdma", []))
                for rank, entry in enumerate(entries)
            ]
    except Exception as e:
        parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
|
||||||
570
python/mlx/_distributed_utils/config.py
Normal file
570
python/mlx/_distributed_utils/config.py
Normal file
@@ -0,0 +1,570 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
import json
import shlex
import sys
import threading
from collections import defaultdict
from dataclasses import dataclass
from subprocess import DEVNULL, run
from typing import Optional

import mlx.core as mx

from .common import (
    Host,
    OptionalBoolAction,
    log,
    log_error,
    log_warning,
    parse_hostfile,
    parse_hostlist,
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SSHInfo:
    # Whether an ssh command could be run on the host.
    can_ssh: bool
    # Whether a sudo command could be run on the host over ssh.
    has_sudo: bool

    def __bool__(self):
        # Truthiness only reflects ssh reachability, not sudo availability.
        return self.can_ssh
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ThunderboltPort:
    # Network interface name of the port (looked up from the hardware port list).
    iface: str
    # The "domain_uuid_key" reported for this port.
    uuid: str
    # The "domain_uuid_key" of the device plugged into this port, or None
    # when nothing with a uuid is attached.
    connected_to: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ThunderboltHost:
    # The "device_name_key" reported by the host's thunderbolt profile.
    name: str
    # Thunderbolt ports of the host.
    ports: list[ThunderboltPort]
|
||||||
|
|
||||||
|
|
||||||
|
def add_ethernet_ips(hosts, verbose=False):
    """Append the address of interface ``en0`` to every host's ip list.

    The address is obtained by running ``ipconfig getifaddr en0`` on the host
    over ssh; the stripped output is appended as is.
    """
    for host in hosts:
        log(verbose, "Getting the ip from", host.ssh_hostname)
        query = run(
            ["ssh", host.ssh_hostname, "ipconfig", "getifaddr", "en0"],
            capture_output=True,
            text=True,
        )
        host.ips.append(query.stdout.strip())
|
||||||
|
|
||||||
|
|
||||||
|
def check_rdma(hosts, verbose=False):
    """Warn about hosts that do not list any RDMA device.

    ``ibv_devices`` is run on every host over ssh; a host is considered
    RDMA capable when the output contains a token starting with ``rdma_``.
    Nothing is returned, the function only emits warnings.
    """
    missing = False
    for host in hosts:
        log(verbose, "Checking that", host.ssh_hostname, "supports RDMA")
        listing = run(
            ["ssh", host.ssh_hostname, "ibv_devices"],
            capture_output=True,
            text=True,
        )
        tokens = listing.stdout.strip().split()
        if not any(tok.startswith("rdma_") for tok in tokens):
            log_warning(host.ssh_hostname, "does not seem to have RDMA enabled")
            missing = True

    if not missing:
        return

    log_warning()
    log_warning(
        "Some of the hosts don't have RDMA enabled or they don't support RDMA."
    )
    log_warning()
    log_warning(
        "See https://ml-explore.github.io/mlx/build/html/usage/distributed.html"
    )
    log_warning("for instructions on how to enable RDMA.")
|
||||||
|
|
||||||
|
|
||||||
|
def can_auto_setup(hosts, sshinfo, auto_setup=False):
    """Report whether every host has passwordless sudo.

    When ``auto_setup`` was requested but some host lacks sudo, a warning
    listing those hosts is printed.

    Returns:
        bool: True when all entries of ``sshinfo`` have ``has_sudo`` set.
    """
    everyone_has_sudo = all(info.has_sudo for info in sshinfo)
    if auto_setup and not everyone_has_sudo:
        log_warning(
            "Automatic setup requested but the following hosts do not have passwordless sudo"
        )
        for host, info in zip(hosts, sshinfo):
            if info.has_sudo:
                continue
            log_warning(" - ", host.ssh_hostname)
    return everyone_has_sudo
|
||||||
|
|
||||||
|
|
||||||
|
class IPConfigurator:
    """Assign a private point-to-point subnet to every thunderbolt link.

    ``self.ips[a, b]`` lists ``(interface, ip)`` pairs that node ``a`` uses to
    talk to node ``b``; the i-th entry of ``self.ips[b, a]`` is the other end
    of the same link.
    """

    def __init__(self, hosts, tb_hosts, uuid_reverse_index):
        done = set()
        table = defaultdict(list)
        third = 0
        fourth = 0
        for src_node, host in enumerate(tb_hosts):
            for src_port, port in enumerate(host.ports):
                # Skip unconnected ports, ports leading outside the known
                # hosts and ports already handled from the other end.
                skip = (
                    not port.connected_to
                    or port.connected_to not in uuid_reverse_index
                    or (src_node, src_port) in done
                )
                if skip:
                    continue

                dst_node, dst_port = uuid_reverse_index[port.connected_to]
                peer_iface = tb_hosts[dst_node].ports[dst_port].iface

                table[src_node, dst_node].append(
                    (port.iface, f"192.168.{third}.{fourth + 1}")
                )
                table[dst_node, src_node].append(
                    (peer_iface, f"192.168.{third}.{fourth + 2}")
                )
                done.add((src_node, src_port))
                done.add((dst_node, dst_port))

                # Each link consumes a block of 4 addresses (see the netmask
                # used in setup).
                fourth += 4
                if fourth > 255:
                    third += 1
                    fourth = 0
                if third > 255:
                    raise ValueError("Ran out of available local IPs")

        self.ips = table
        self.hosts = hosts
        self.tb_hosts = tb_hosts

    def setup(self, verbose=False, auto_setup=False):
        """Configure the interfaces, or print the commands to do so.

        With ``auto_setup`` the commands are run on each host over ssh;
        otherwise they are printed and the user is asked to press enter.
        """
        netmask = "255.255.255.252"
        for i, host in enumerate(self.hosts):
            lines = ["sudo ifconfig bridge0 down"]
            for j in range(len(self.hosts)):
                if i == j or (i, j) not in self.ips:
                    continue
                for (iface, ip), (_, peer) in zip(self.ips[i, j], self.ips[j, i]):
                    lines.append(f"sudo ifconfig {iface} inet {ip} netmask {netmask}")
                    lines.append(f"sudo route change {peer} -interface {iface}")
            command = "".join(line + "\n" for line in lines)

            if not auto_setup:
                msg = f"Setup for {host.ssh_hostname}"
                print(msg)
                print("=" * len(msg))
                print(command)
                input("Enter to continue")
            else:
                print(f"Running auto setup for {host.ssh_hostname}")
                remote = command.strip().replace("\n", " ; ")
                argv = ["ssh", host.ssh_hostname, remote]
                log(verbose, shlex.join(argv))
                run(argv)
            print()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_hardware_ports(ports_string):
    """Map hardware port names to device names.

    ``ports_string`` is the utf-8 encoded text of a hardware port listing
    where a ``Hardware Port: <name>`` line is followed by a
    ``Device: <device>`` line.

    Returns:
        dict: ``{port name: device name}``.
    """
    port_prefix = "Hardware Port:"
    device_prefix = "Device:"
    mapping = {}
    current = None
    for line in ports_string.decode("utf-8").split("\n"):
        if line.startswith(port_prefix):
            # Drop the prefix and the space that follows it.
            current = line.strip()[len(port_prefix) + 1 :]
        elif line.startswith(device_prefix):
            mapping[current] = line.strip()[len(device_prefix) + 1 :]
            current = None
    return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def extract_connectivity(hosts, verbose):
    """Query the hosts over ssh for their thunderbolt connectivity.

    Returns:
        tuple: ``(tb_hosts, uuid_reverse_index)`` where ``tb_hosts`` has one
        ``ThunderboltHost`` per host (ports sorted by interface name) and
        ``uuid_reverse_index`` maps a port uuid to ``(host index, port index)``.
    """

    def remote_stdout(host, *remote_command):
        return run(
            ["ssh", host.ssh_hostname, *remote_command], capture_output=True
        ).stdout

    # Extract the current connectivity from the remote hosts
    thunderbolt_connections = []
    for host in hosts:
        log(verbose, "Getting connectivity from", host.ssh_hostname)
        profile = remote_stdout(
            host, "system_profiler", "SPThunderboltDataType", "-json"
        )
        thunderbolt_connections.append(json.loads(profile))

    interface_maps = []
    for host in hosts:
        log(verbose, "Getting interface names from", host.ssh_hostname)
        listing = remote_stdout(host, "networksetup", "-listallhardwareports")
        interface_maps.append(parse_hardware_ports(listing))

    # Parse the connectivity into some simple dataclasses
    tb_hosts = []
    for profile, iface_map in zip(thunderbolt_connections, interface_maps):
        name = ""
        ports = []
        for entry in profile["SPThunderboltDataType"]:
            uuid = entry.get("domain_uuid_key")
            if uuid is None:
                continue
            name = entry["device_name_key"]
            tag = entry["receptacle_1_tag"]["receptacle_id_key"]
            attached = [
                item for item in entry.get("_items", []) if "domain_uuid_key" in item
            ]
            if attached:
                connected_to = attached[0]["domain_uuid_key"]
            else:
                connected_to = None
            iface = iface_map[f"Thunderbolt {tag}"]
            ports.append(ThunderboltPort(iface, uuid, connected_to))
        tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))

    # Create a reverse index to be able to map uuids to (host, port) quickly
    uuid_reverse_index = {
        port.uuid: (i, j)
        for i, tb_host in enumerate(tb_hosts)
        for j, port in enumerate(tb_host.ports)
    }

    return tb_hosts, uuid_reverse_index
|
||||||
|
|
||||||
|
|
||||||
|
def make_connectivity_matrix(tb_hosts, uuid_reverse_index):
    """Count the thunderbolt links between every pair of hosts.

    Returns:
        list[list[int]]: ``m[i][j]`` is the number of ports of host ``i``
        connected to a port of host ``j``. Ports whose peer is not in
        ``uuid_reverse_index`` are ignored.
    """
    size = len(tb_hosts)
    matrix = []
    for tb_host in tb_hosts:
        row = [0] * size
        for port in tb_host.ports:
            if port.connected_to not in uuid_reverse_index:
                continue
            peer_node, _ = uuid_reverse_index[port.connected_to]
            row[peer_node] += 1
        matrix.append(row)
    return matrix
|
||||||
|
|
||||||
|
|
||||||
|
def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index):
    """Print the thunderbolt connectivity as a graphviz (dot) graph.

    Every host becomes a node labelled with its ssh hostname and every link
    between two known hosts becomes one edge labelled with the interface
    names on both ends. Ports that are unconnected, or connected to a device
    that is not one of the hosts, are skipped.

    Args:
        hosts: Hosts providing ``ssh_hostname`` for the node labels.
        tb_hosts: Thunderbolt description of each host (same order).
        uuid_reverse_index: Maps a port uuid to ``(host index, port index)``.
    """
    # Make ids per node
    names = []
    for i in range(len(tb_hosts)):
        n = ""
        j = i
        while True:
            n += chr(97 + j % 26)
            j //= 26
            if j == 0:
                break
        names.append(n)

    print("graph G {")
    print("  node [shape=rectangle];")
    for i, h in enumerate(hosts):
        print(f'  {names[i]} [label="{h.ssh_hostname}"];')
    for i, h in enumerate(tb_hosts):
        for p in h.ports:
            if not p.connected_to:
                continue
            # The peer may be a device that is not part of the host list
            # (e.g. a dock or display); it has no entry in the index.
            if p.connected_to not in uuid_reverse_index:
                continue
            dst = uuid_reverse_index[p.connected_to]
            # Each link is seen from both ends; print it only once.
            if dst[0] < i:
                continue
            print(f"  {names[i]} -- {names[dst[0]]}", end="")
            print(f' [label="{p.iface}/{tb_hosts[dst[0]].ports[dst[1]].iface}"]')
    print("}")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_rings(connectivity):
    """Find the rings (cycles) present in a connectivity matrix.

    Cycles are enumerated with a depth first search from every node. Two
    cycles over the same set of nodes count as the same ring and only the
    first one found is kept.

    Returns:
        list: ``(nodes, count)`` tuples sorted by decreasing ring length,
        where ``count`` is the smallest number of links between consecutive
        nodes of the ring.
    """
    size = len(connectivity)
    cycles = []

    def walk(origin, current, trail, seen):
        trail.append(current)
        seen.add(current)
        for nxt in range(size):
            if connectivity[current][nxt] <= 0:
                continue
            if nxt == origin:
                # Back at the start: the current trail closes a cycle.
                cycles.append(trail[:])
            if nxt not in seen:
                walk(origin, nxt, trail, seen)
        trail.pop()
        seen.remove(current)

    for origin in range(size):
        walk(origin, origin, [], set())

    rings = []
    known = set()
    for cycle in cycles:
        key = tuple(sorted(cycle))
        if key in known:
            continue
        hops = len(cycle)
        links = min(
            connectivity[cycle[k]][cycle[(k + 1) % hops]] for k in range(hops)
        )
        rings.append((cycle, links))
        known.add(key)

    return sorted(rings, key=lambda ring: -len(ring[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def check_valid_mesh(hosts, connectivity, strict=True):
    """Check that every host is directly connected to every other host.

    Args:
        hosts: Hosts, used for the hostnames in the error message.
        connectivity: Matrix of link counts between hosts.
        strict (bool): When True an incomplete mesh is reported as an error
            and the process exits with status 1.

    Returns:
        bool: True for a complete mesh, False otherwise (non strict only).
    """
    size = len(connectivity)
    for i in range(size):
        for j in range(size):
            if i == j or connectivity[i][j] > 0:
                continue
            if not strict:
                return False
            log_error(
                f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
            )
            log_error()
            log_error("Try passing --dot to visualize the connectivity")
            sys.exit(1)
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def check_ssh_connections(hosts):
    """Check, in parallel, that every host accepts non-interactive ssh.

    For reachable hosts it is also checked whether ``sudo`` can be used. If
    some host cannot be reached the failing hosts are reported and the
    process exits with status 1.

    Returns:
        list[SSHInfo]: one entry per host, in the order of ``hosts``.
    """
    results = [None] * len(hosts)

    def _remote_ok(hostname, *remote_command):
        probe = run(
            [
                "ssh",
                "-o",
                "BatchMode=yes",
                "-o",
                "ConnectTimeout=5",
                hostname,
                *remote_command,
            ],
            stdout=DEVNULL,
            stderr=DEVNULL,
        )
        return probe.returncode == 0

    def _check(hostname, i):
        info = SSHInfo(False, False)
        results[i] = info

        # Check for ssh
        info.can_ssh = _remote_ok(hostname, "echo", "success")
        if info.can_ssh:
            # Check for sudo
            info.has_sudo = _remote_ok(hostname, "sudo", "ls")

    workers = []
    for i, host in enumerate(hosts):
        workers.append(threading.Thread(target=_check, args=(host.ssh_hostname, i)))
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    unreachable = [host for host, info in zip(hosts, results) if not info]
    if unreachable:
        log_error("Could not ssh to the following hosts:")
        for host in unreachable:
            log_error("  - ", host.ssh_hostname)
        log_error()
        log_error("Maybe they are not set-up for password-less ssh?")
        sys.exit(1)

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_ethernet_hostfile(args, hosts):
    """Create a hostfile for communication over ethernet.

    The ethernet ip of every host is looked up and the resulting hostfile is
    written to ``args.output_hostfile`` if set, otherwise printed.
    """
    log(args.verbose, "Preparing an ethernet hostfile")
    add_ethernet_ips(hosts, args.verbose)

    hostfile = [dict(ssh=host.ssh_hostname, ips=host.ips) for host in hosts]

    if not args.output_hostfile:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
    else:
        with open(args.output_hostfile, "w") as out:
            json.dump(hostfile, out, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_ring(args, hosts, ips, ring, sshinfo):
    """Create the hostfile for a ring and set up the interfaces it uses.

    Args:
        args: Parsed command line arguments.
        hosts: All hosts.
        ips (IPConfigurator): The ip assignment of the thunderbolt links.
        ring: A ``(nodes, count)`` tuple as returned by ``extract_rings``.
        sshinfo: The ssh/sudo information of every host.
    """
    log(args.verbose, "Prepare a ring hostfile")
    order, links = ring

    hostfile = []
    for position, node in enumerate(order):
        # The ips listed for a node are the ones it uses towards the node
        # that precedes it in the ring.
        previous = order[position - 1]
        node_ips = [ips.ips[node, previous][c][1] for c in range(links)]
        hostfile.append(
            {"ssh": hosts[node].ssh_hostname, "ips": node_ips, "rdma": []}
        )

    has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
    ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)

    if not args.output_hostfile:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
    else:
        with open(args.output_hostfile, "w") as out:
            json.dump(hostfile, out, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_jaccl(args, hosts, ips, sshinfo):
    """Create the hostfile for the jaccl backend and set up the interfaces.

    For every host the hostfile lists its ips and, per peer, the rdma device
    named after the first interface connecting the two (None for the host
    itself).
    """
    log(args.verbose, "Prepare a jaccl hostfile")
    check_rdma(hosts, args.verbose)
    add_ethernet_ips(hosts, args.verbose)

    hostfile = []
    for i, host in enumerate(hosts):
        rdma = [
            None if i == j else f"rdma_{ips.ips[i, j][0][0]}"
            for j in range(len(hosts))
        ]
        hostfile.append({"ssh": host.ssh_hostname, "ips": host.ips, "rdma": rdma})

    has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
    ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)

    if not args.output_hostfile:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
    else:
        with open(args.output_hostfile, "w") as out:
            json.dump(hostfile, out, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_tb_hostfile(args, hosts, sshinfo):
    """Discover the thunderbolt connectivity and configure a backend for it.

    When ``args.dot`` is set the topology is written in DOT format and nothing
    else happens. Otherwise the backend is picked from ``args.backend``:

    * ``None``: use a ring if one spanning all hosts exists, otherwise a
      mesh (jaccl); exit with status 1 if neither is available.
    * ``"ring"``: require a ring spanning all hosts, exit with status 1
      otherwise.
    * ``"jaccl"``: validate the mesh and configure jaccl.

    Args:
        args: Parsed command line arguments.
        hosts: The hosts to configure.
        sshinfo: Forwarded to the ``configure_*`` functions.
    """
    log(args.verbose, "Preparing for communication over thunderbolt")
    tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)

    if args.dot:
        tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index)
        return

    ips = IPConfigurator(hosts, tb_hosts, uuid_reverse_index)
    connectivity = make_connectivity_matrix(tb_hosts, uuid_reverse_index)

    if args.backend is None:
        rings = extract_rings(connectivity)
        has_mesh = check_valid_mesh(hosts, connectivity, False)
        # Each ring is a (nodes, count) pair, so rings[0][0] is the node list
        has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)

        if not has_ring and not has_mesh:
            log_error("Neither thunderbolt mesh nor ring found.")
            log_error("Perhaps run with --dot to generate a plot of the connectivity.")
            sys.exit(1)

        elif has_ring:
            configure_ring(args, hosts, ips, rings[0], sshinfo)

        else:
            configure_jaccl(args, hosts, ips, sshinfo)

    elif args.backend == "ring":
        rings = extract_rings(connectivity)
        has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
        if not has_ring:
            log_error("Could not find a full ring.")
            log_error()
            log_error("Try passing --dot to visualize the connectivity")
            if len(rings) > 0:
                log_error("Rings found:")
                for r in rings:
                    # r is a (nodes, count) pair; only the node list holds
                    # host indices. Iterating r itself would index hosts
                    # with a list and raise a TypeError.
                    log_error(f" - {','.join(hosts[i].ssh_hostname for i in r[0])}")
            sys.exit(1)
        configure_ring(args, hosts, ips, rings[0], sshinfo)

    elif args.backend == "jaccl":
        check_valid_mesh(hosts, connectivity)
        configure_jaccl(args, hosts, ips, sshinfo)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command line entry point that configures hosts for MLX distributed.

    Parses the arguments, resolves the list of hosts, checks ssh access and
    then prepares a hostfile for either ethernet or thunderbolt.
    """
    parser = argparse.ArgumentParser(
        description="Configure remote machines for use with MLX distributed"
    )
    add = parser.add_argument
    add("--verbose", action="store_true", help="Print debug messages in stdout")
    add("--hosts", default="127.0.0.1", help="A comma separated list of hosts")
    add("--hostfile", help="The file containing the hosts")
    add(
        "--over",
        choices=["thunderbolt", "ethernet"],
        default="thunderbolt",
        help="What type of connectivity to configure",
        required=True,
    )
    add("--output-hostfile", help="If provided, save the hostfile to this path")
    add(
        "--auto-setup",
        "--no-auto-setup",
        action=OptionalBoolAction,
        nargs=0,
        dest="auto_setup",
        default=None,
    )
    add("--dot", action="store_true", help="Output the topology in DOT format and exit")
    add(
        "--backend",
        choices=["ring", "jaccl"],
        default=None,
        help="Which distributed backend to configure",
    )
    args = parser.parse_args()

    # A hostfile takes precedence over the comma separated host list
    if args.hostfile is None:
        hosts = parse_hostlist(parser, args.hosts, 1)
    else:
        hosts = parse_hostfile(parser, args.hostfile)

    # Check that we can ssh
    names = ", ".join(h.ssh_hostname for h in hosts)
    log(args.verbose, f"Checking for ssh access for {names}")
    sshinfo = check_ssh_connections(hosts)

    if args.over != "ethernet":
        # Configure the macs for communication over thunderbolt, both via
        # RDMA and IP
        prepare_tb_hostfile(args, hosts, sshinfo)
    else:
        # Prepare a hostfile for communication over ethernet using the ips
        # of the provided hostnames
        prepare_ethernet_hostfile(args, hosts)
|
||||||
557
python/mlx/_distributed_utils/launch.py
Normal file
557
python/mlx/_distributed_utils/launch.py
Normal file
@@ -0,0 +1,557 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from collections import Counter
|
||||||
|
from itertools import chain
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Empty as QueueEmpty
|
||||||
|
from queue import Queue
|
||||||
|
from select import select
|
||||||
|
from subprocess import PIPE, Popen, run
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from .common import log, log_warning, parse_hostfile, parse_hostlist, positive_number
|
||||||
|
|
||||||
|
|
||||||
|
class CommandProcess:
    """Interface of a launched command. Every member must be overridden."""

    @property
    def process(self):
        """The Popen object that refers to the current command."""
        raise NotImplementedError

    @property
    def exit_status(self):
        """A tuple (returncode, killed) for the command.

        It should be (None, None) while the command is running normally.
        """
        raise NotImplementedError

    def preprocess_output(self, data: str, is_stdout=False):
        """Hook applied to every chunk of output of the command so that extra
        data can be captured or the format changed on the fly."""
        raise NotImplementedError

    def terminate(self):
        """Terminate or return the exit code."""
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteProcess(CommandProcess):
    """Run a command through a bash launch script, locally or over ssh.

    The launch script prints the path of a file holding its PID as its first
    line of standard output. That path is captured by ``preprocess_output``
    and used by ``terminate`` to kill the launched program.
    """

    def __init__(self, rank, host, python, cwd, files, env, command):
        """Start the command.

        Args:
            rank: Exported to the command as ``MLX_RANK``.
            host: ``"127.0.0.1"`` runs locally, anything else goes via ssh.
            python: Unused here; kept for interface compatibility.
            cwd: Working directory for the command or ``None``.
            files: Mapping of environment variable name to file content.
            env: List of ``KEY=VALUE`` strings to export.
            command: The command and its arguments as a list.
        """
        is_local = host == "127.0.0.1"
        cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command)
        if not is_local:
            cmd = f"ssh -tt -o LogLevel=QUIET {host} {shlex.quote(cmd)}"

        self._host = host
        self._pidfile = None
        self._is_local = is_local
        self._process = Popen(
            cmd,
            shell=True,
            executable="/bin/bash",
            stdin=PIPE,
            stdout=PIPE,
            stderr=PIPE,
        )

        self._killed = False

    @property
    def process(self):
        return self._process

    @property
    def exit_status(self):
        return self._process.poll(), self._killed

    def preprocess_output(self, data, is_stdout=False):
        """Strip the pidfile line from the first chunk of standard output.

        The launch script prints the pidfile on stdout, so anything that
        arrives on stderr first (e.g. a complaint from ``stty``) must not be
        mistaken for it.
        """
        if self._pidfile is None and is_stdout:
            pidfile, *rest = data.split("\n", maxsplit=1)
            # A pseudo-terminal may turn "\n" into "\r\n"; drop the stray "\r"
            self._pidfile = pidfile.strip()
            return rest[0] if rest else ""

        return data

    def terminate(self):
        """Terminate the local process and kill the launched program.

        Sets the ``killed`` part of ``exit_status`` when the launched program
        was still alive and a kill signal was sent to it.
        """
        if self._killed:
            return

        self._process.terminate()
        self._process.wait()

        # Without a pidfile we don't know which program to kill. Building the
        # kill script anyway would only operate on a bogus path.
        if not self._pidfile:
            return

        # Kill the remote program if possible
        cmd = RemoteProcess.make_kill_script(self._pidfile)
        if not self._is_local:
            cmd = f"ssh {self._host} {shlex.quote(cmd)}"
        c = run(
            cmd,
            check=True,
            shell=True,
            executable="/bin/bash",
            capture_output=True,
            text=True,
        )

        self._killed = c.stdout.strip() == "1"

    @staticmethod
    def make_launch_script(rank, cwd, files, env, command):
        """Return the bash script that prepares the environment and then
        replaces itself with ``command``."""
        # Disable echo
        script = "stty -echo; "

        # Write the PID to a file so we can kill the process if needed
        script += "pidfile=$(mktemp); "
        script += "echo $$ > $pidfile; "
        script += 'printf "%s\\n" $pidfile; '

        # Change the working directory if one was requested. Otherwise attempt to
        # change to the current one but don't fail if it wasn't possible.
        # shlex.quote keeps paths with spaces or quotes as a single shell word.
        d = shlex.quote(cwd or os.getcwd())
        script += f"if [[ -d {d} ]]; then "
        script += f" cd {d}; "
        if cwd is not None:
            script += "else "
            # >&2 sends the message to stderr (>2 would create a file named 2)
            script += f" echo 'Failed to change directory to' {d} >&2; "
        script += "fi; "

        # Add the environment variables that were requested
        for e in env:
            key, *value = e.split("=", maxsplit=1)
            value = shlex.quote(value[0]) if len(value) > 0 else ""
            if not key or not all(c.isalnum() or c == "_" for c in key):
                log_warning(
                    f"'{e}' is an invalid environment variable so it is ignored"
                )
                continue
            script += f"export {key}={value}; "

        # Make the temporary files
        for env_name, content in files.items():
            script += "fname=$(mktemp); "
            script += f"echo {shlex.quote(content)} >$fname; "
            script += f"export {env_name}=$fname; "

        # Finally add the rank
        script += f"export MLX_RANK={rank}; "

        # Replace the process with the script
        script += f"cmd=({' '.join(map(shlex.quote, command))}); "
        script += 'exec "${cmd[@]}"'

        return script

    @staticmethod
    def make_kill_script(pidfile):
        """Return the bash script that kills the PID stored in ``pidfile``.

        The script prints 1 if the process was alive and a kill was sent and
        0 otherwise; the pidfile is removed in both cases.
        """
        pidfile = shlex.quote(pidfile)
        script = ""
        script += f"pid=$(cat {pidfile}); "
        script += "if ps -p $pid >/dev/null; then "
        script += " kill $pid; "
        script += " echo 1; "
        script += "else "
        script += " echo 0; "
        script += "fi; "
        script += f"rm {pidfile}"

        return script
|
||||||
|
|
||||||
|
|
||||||
|
def _launch_with_io(command_class, arguments, verbose):
    """Launch one command per entry of ``arguments`` and relay their I/O.

    Every command runs in its own thread which forwards queued stdin data to
    the process and queues the process output. The calling thread broadcasts
    its own stdin to all commands and writes the queued output to its stdout
    and stderr.

    Args:
        command_class: Called as ``command_class(rank, *args, **kwargs)``. The
            result must provide ``process``, ``exit_status``,
            ``preprocess_output`` and ``terminate``.
        arguments: A list of ``(args, kwargs)`` pairs, one per command. The
            first element of ``args`` is received as ``rank`` by the thread
            and is used to index the exit codes.
        verbose: Passed to ``log`` when a command completes successfully.
    """
    # Set by the main loop below when a command failed; read by the threads
    stop = False
    exit_codes = [(None, None)] * len(arguments)

    def _thread_fn(rank, *args, **kwargs):
        # The queues are for this thread only, the rest goes to the command
        stdin_queue = kwargs.pop("stdin_queue")
        stdout_queue = kwargs.pop("stdout_queue")
        stderr_queue = kwargs.pop("stderr_queue")

        command = command_class(rank, *args, **kwargs)
        p = command.process
        os.set_blocking(p.stdout.fileno(), False)
        os.set_blocking(p.stderr.fileno(), False)
        os.set_blocking(p.stdin.fileno(), False)

        to_read = [p.stdout.fileno(), p.stderr.fileno()]
        to_write = [p.stdin.fileno()]

        stdin_buffer = b""
        while p.poll() is None:
            try:
                stdin_buffer += stdin_queue.get_nowait()
            except QueueEmpty:
                pass
            # NOTE(review): the stdin pipe is in the write set even when there
            # is nothing to send, so select presumably returns right away most
            # of the time instead of waiting for the timeout -- confirm.
            rlist, wlist, _ = select(to_read, to_write, [], 1.0)
            for fd in rlist:
                is_stdout = fd == p.stdout.fileno()
                msg = os.read(fd, 8192).decode(errors="ignore")
                msg = command.preprocess_output(msg, is_stdout)
                if is_stdout:
                    stdout_queue.put(msg.encode())
                else:
                    stderr_queue.put(msg.encode())
            for fd in wlist:
                if len(stdin_buffer) > 0:
                    # Keep whatever could not be written for the next round
                    n = os.write(fd, stdin_buffer)
                    stdin_buffer = stdin_buffer[n:]
            if stop:
                command.terminate()
                break
        exit_codes[rank] = command.exit_status

        if exit_codes[rank][1]:
            log_warning(f"Node with rank {rank} was killed")
        elif exit_codes[rank][0] != 0:
            log_warning(f"Node with rank {rank} exited with code {exit_codes[rank][0]}")
        else:
            log(verbose, f"Node with rank {rank} completed")

    stdin_queues = []
    stdout_queues = []
    stderr_queues = []
    threads = []
    for i, (args, kwargs) in enumerate(arguments):
        stdin_queues.append(Queue())
        stdout_queues.append(Queue())
        stderr_queues.append(Queue())
        t = threading.Thread(
            target=_thread_fn,
            args=args,
            kwargs=kwargs
            | {
                "stdin_queue": stdin_queues[-1],
                "stdout_queue": stdout_queues[-1],
                "stderr_queue": stderr_queues[-1],
            },
        )
        t.start()
        threads.append(t)

    os.set_blocking(sys.stdin.fileno(), False)
    os.set_blocking(sys.stdout.fileno(), True)
    os.set_blocking(sys.stderr.fileno(), True)
    # Keep going after a stop was requested until all queued output is written
    while not stop or any(not q.empty() for q in chain(stdout_queues, stderr_queues)):
        # Broadcast user input to the jobs
        rlist, _, _ = select([sys.stdin.fileno()], [], [], 0.1)
        for fd in rlist:
            stdin_buffer = os.read(fd, 8192)
            for q in stdin_queues:
                q.put(stdin_buffer)

        # Gather job output
        for q in stdout_queues:
            try:
                while not q.empty():
                    sys.stdout.buffer.write(q.get_nowait())
            except QueueEmpty:
                pass
        for q in stderr_queues:
            try:
                while not q.empty():
                    sys.stderr.buffer.write(q.get_nowait())
            except QueueEmpty:
                pass
        sys.stdout.buffer.flush()
        sys.stderr.buffer.flush()

        # Check if all are running and terminate otherwise
        if any(t.is_alive() for t in threads):
            for i, t in enumerate(threads):
                if not t.is_alive():
                    # A finished thread with a non-zero return code stops
                    # all the others
                    if exit_codes[i][0] != 0:
                        stop = True
                        break
        else:
            # No thread is alive anymore so there is nothing left to relay
            break

    # Wait for the jobs to finish
    for t in threads:
        t.join()

    # Process any remaining outputs
    for q in stdout_queues:
        while not q.empty():
            sys.stdout.buffer.write(q.get())
    for q in stderr_queues:
        while not q.empty():
            sys.stderr.buffer.write(q.get())
    sys.stdout.buffer.flush()
    sys.stderr.buffer.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def launch_ring(parser, hosts, args, command):
    """Launch ``command`` on every host using the ring backend.

    Each host gets ``connections_per_ip`` ``ip:port`` entries per IP, with
    ports increasing by one starting from ``args.starting_port``.

    Args:
        parser: Used to report an error if a host has no IPs.
        hosts: The hosts to launch on; ``ips`` and ``ssh_hostname`` are read.
        args: Parsed command line arguments.
        command: The command and its arguments as a list.
    """
    if any(len(h.ips) == 0 for h in hosts):
        parser.error(
            "The ring backend requires IPs to be provided instead of hostnames"
        )

    port = args.starting_port
    ring_hosts = []
    for h in hosts:
        node = []
        for ip in h.ips:
            for _ in range(args.connections_per_ip):
                node.append(f"{ip}:{port}")
                port += 1
        ring_hosts.append(node)
    hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""

    # NOTE(review): with a single host the content is empty but the
    # MLX_HOSTFILE entry is still passed on -- confirm that is intended.
    files = {"MLX_HOSTFILE": hostfile}
    # Work on a copy so that the caller's args.env is not modified
    env = list(args.env)
    if args.verbose:
        env.append("MLX_RING_VERBOSE=1")
    cwd = args.cwd

    log(args.verbose, "Running", shlex.join(command))

    _launch_with_io(
        RemoteProcess,
        [
            ((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
            for rank, h in enumerate(hosts)
        ],
        args.verbose,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def launch_nccl(parser, hosts, args, command):
    """Launch ``command`` on every host using the NCCL backend.

    The first IP of the first host together with ``args.nccl_port`` is
    exported to every rank, and each rank gets
    ``CUDA_VISIBLE_DEVICES=rank % args.repeat_hosts``.

    Args:
        parser: Unused; kept for a uniform launcher signature.
        hosts: The hosts to launch on.
        args: Parsed command line arguments.
        command: The command and its arguments as a list.

    Raises:
        ValueError: If the first host has no IPs.
    """
    if not hosts[0].ips:
        raise ValueError("Rank 0 should have an IP reachable from all other ranks")

    master_host = hosts[0].ips[0]
    master_port = args.nccl_port
    world_size = len(hosts)

    # Work on a copy so that the caller's args.env is not modified
    env = list(args.env)
    cwd = args.cwd
    if args.verbose:
        env.append("NCCL_DEBUG=INFO")
    env.append(f"NCCL_HOST_IP={master_host}")
    env.append(f"NCCL_PORT={master_port}")
    env.append(f"MLX_WORLD_SIZE={world_size}")

    log(args.verbose, "Running", shlex.join(command))

    _launch_with_io(
        RemoteProcess,
        [
            (
                (
                    rank,
                    h.ssh_hostname,
                    args.python,
                    cwd,
                    {},
                    env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
                    command,
                ),
                {},
            )
            for rank, h in enumerate(hosts)
        ],
        args.verbose,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def launch_jaccl(parser, hosts, args, command):
    """Launch ``command`` on every host using the jaccl backend.

    Every host must carry an ``rdma`` list with one entry per host whose
    entry at the host's own position is ``None``.

    Args:
        parser: Unused; kept for a uniform launcher signature.
        hosts: The hosts to launch on.
        args: Parsed command line arguments.
        command: The command and its arguments as a list.

    Raises:
        ValueError: If the first host has no IPs or the ``rdma`` lists are
            malformed.
    """
    if not hosts[0].ips:
        raise ValueError("Rank 0 should have an IP reachable from all other ranks")

    have_rdmas = all(len(h.rdma) == len(hosts) for h in hosts)
    # Only index into the lists once their length is known to be right,
    # otherwise a short list raises IndexError instead of the error below.
    have_nulls = have_rdmas and all(h.rdma[i] is None for i, h in enumerate(hosts))
    if not have_rdmas or not have_nulls:
        raise ValueError("Malformed hostfile for jaccl backend")

    coordinator = hosts[0].ips[0]
    # Work on a copy so that the caller's args.env is not modified
    env = list(args.env)
    cwd = args.cwd
    env.append(f"MLX_JACCL_COORDINATOR={coordinator}:{args.starting_port}")
    files = {"MLX_IBV_DEVICES": json.dumps([h.rdma for h in hosts])}

    log(args.verbose, "Running", shlex.join(command))

    _launch_with_io(
        RemoteProcess,
        [
            ((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
            for rank, h in enumerate(hosts)
        ],
        args.verbose,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_mpi_libname():
    """Return the name of the MPI library that ``ompi_info`` links against.

    The dependencies of the ``ompi_info`` executable are listed with
    ``otool -L`` on macOS and ``ldd`` elsewhere, and the first word of the
    first line mentioning ``libmpi`` is returned without a leading
    ``@rpath/``.

    Returns:
        The library name, or ``None`` if it could not be determined.
    """
    # Imported here so the lookup below does not rely on a module level
    # import of platform being present.
    import platform

    try:
        ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
        ompi_info = ompi_info.stdout.strip().decode()

        if platform.system() == "Darwin":
            otool_output = run(
                ["otool", "-L", ompi_info], check=True, capture_output=True
            )
        else:
            otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
        otool_output = otool_output.stdout.decode()

        # StopIteration if not found
        libmpi_line = next(
            filter(lambda line: "libmpi" in line, otool_output.splitlines())
        )
        return libmpi_line.strip().split()[0].removeprefix("@rpath/")
    except Exception:
        # Best effort: missing tools, failing commands and a missing libmpi
        # line all mean that no library name is available. Unlike a bare
        # except this lets KeyboardInterrupt and SystemExit through.
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def launch_mpi(parser, hosts, args, command):
    """Launch ``command`` through ``mpirun``.

    A temporary hostfile with one ``<host> slots=<n>`` line per distinct ssh
    hostname is written and passed to ``mpirun`` together with the working
    directory, the environment variables and any extra ``--mpi-arg`` values.

    Args:
        parser: Unused; kept for a uniform launcher signature.
        hosts: The hosts to launch on; only ``ssh_hostname`` is read.
        args: Parsed command line arguments. ``args.env`` is replaced by an
            extended list when the MPI library name could be determined.
        command: The command and its arguments as a list.
    """
    located = run(["which", "mpirun"], check=True, capture_output=True)
    mpirun = located.stdout.strip().decode()

    # Compatibility with homebrew and pip installs
    mpi_libname = get_mpi_libname()
    if mpi_libname is not None:
        dyld = Path(mpirun).parent.parent / "lib"
        extra_env = [
            f"DYLD_LIBRARY_PATH={str(dyld)}",
            f"MLX_MPI_LIBNAME={mpi_libname}",
        ]
        args.env = extra_env + args.env

    log(args.verbose, f"Using '{mpirun}'")
    with tempfile.NamedTemporaryFile(mode="w") as f:
        # One line per distinct host with the number of times it appears
        slots = Counter(h.ssh_hostname for h in hosts)
        for name, n in slots.items():
            print(f"{name} slots={n}", file=f)
        f.flush()

        cmd = [
            mpirun,
            "--output",
            ":raw",  # do not line buffer output
            "--hostfile",
            f.name,
        ]
        if args.cwd:
            cmd += ["-cwd", args.cwd]
        for e in args.env:
            cmd += ["-x", e]
        for arg in args.mpi_arg:
            cmd += shlex.split(arg)
        cmd.append("--")
        cmd += command

        log(args.verbose, "Running", " ".join(cmd))
        try:
            run(cmd)
        except KeyboardInterrupt:
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command line entry point that launches an MLX distributed program.

    Everything that is not a known option is treated as the script or
    command to run, optionally separated from the options by ``--``.
    """
    parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
    parser.add_argument(
        "--print-python",
        action="store_true",
        help="Print the path to the current python executable and exit",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Print debug messages in stdout"
    )
    parser.add_argument(
        "--hosts", default="127.0.0.1", help="A comma separated list of hosts"
    )
    parser.add_argument(
        "--repeat-hosts",
        "-n",
        type=positive_number,
        default=1,
        help="Repeat each host a given number of times",
    )
    parser.add_argument("--hostfile", help="The file containing the hosts")
    parser.add_argument(
        "--backend",
        choices=["ring", "mpi", "nccl", "jaccl"],
        default="nccl" if mx.cuda.is_available() else "ring",
        help="Which distributed backend to launch",
    )
    parser.add_argument(
        "--env",
        action="append",
        default=[],
        help="Set environment variables for the jobs",
    )
    parser.add_argument(
        "--mpi-arg",
        action="append",
        default=[],
        help="Arguments to pass directly to mpirun",
    )
    parser.add_argument(
        "--connections-per-ip",
        default=1,
        type=int,
        help="How many connections per ip to use for the ring backend",
    )
    parser.add_argument(
        "--starting-port",
        "-p",
        type=int,
        default=32323,
        help="For the ring backend listen on this port increasing by 1 per rank and IP",
    )
    parser.add_argument(
        "--cwd", help="Set the working directory on each node to the provided one"
    )
    parser.add_argument(
        "--nccl-port",
        type=int,
        default=12345,
        help="The port to use for the NCCL communication (only for nccl backend)",
    )
    parser.add_argument(
        "--no-verify-script",
        action="store_false",
        dest="verify_script",
        help="Do not verify that the script exists",
    )
    parser.add_argument(
        "--python", default=sys.executable, help="Use this python on the remote hosts"
    )

    args, rest = parser.parse_known_args()

    if args.print_python:
        print(args.python)
        return

    # Drop the separator first so that a lone "--" is reported as a missing
    # script instead of failing on rest[0] further down.
    if rest and rest[0] == "--":
        rest.pop(0)
    if len(rest) == 0:
        parser.error("No script is provided")

    # Try to extract a list of hosts and corresponding ips
    if args.hostfile is not None:
        hosts = parse_hostfile(parser, args.hostfile)
    else:
        hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)

    # Check if the script is a file and convert it to a full path
    if (script := Path(rest[0])).exists() and script.is_file():
        rest[0:1] = [args.python, str(script.resolve())]
    elif (command := shutil.which(rest[0])) is not None:
        rest[0] = command
    elif args.verify_script:
        raise ValueError(f"Invalid script or command {rest[0]}")

    # Launch
    if args.backend == "ring":
        launch_ring(parser, hosts, args, rest)
    elif args.backend == "mpi":
        launch_mpi(parser, hosts, args, rest)
    elif args.backend == "nccl":
        launch_nccl(parser, hosts, args, rest)
    elif args.backend == "jaccl":
        launch_jaccl(parser, hosts, args, rest)
|
||||||
@@ -1,909 +0,0 @@
|
|||||||
# Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import base64
|
|
||||||
import ipaddress
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import shlex
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from collections import Counter
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from queue import Empty as QueueEmpty
|
|
||||||
from queue import Queue
|
|
||||||
from select import select
|
|
||||||
from subprocess import PIPE, Popen, run
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class Host:
    """A machine that takes part in a distributed job."""

    # Index given to the host when the host list or hostfile is parsed
    rank: int
    # The "ssh" entry of the hostfile or the name given in the host list
    ssh_hostname: str
    # The "ips" entry of the hostfile (empty if absent) or, for a host list
    # entry that is itself an IP address, a list holding just that address
    ips: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class ThunderboltPort:
    """A single thunderbolt port of a host."""

    # Presumably the network interface name of the port -- TODO confirm
    iface: str
    # Presumably the identifier of this port -- TODO confirm
    uuid: str
    # Key of the peer port in the connectivity index, None if the port is
    # not connected to anything
    connected_to: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class ThunderboltHost:
    """A host together with its thunderbolt ports."""

    # Presumably the name the host is known by -- TODO confirm
    name: str
    # The ports of the host; positions in this list are used as port indices
    ports: list[ThunderboltPort]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_hardware_ports(ports_string):
    """Map hardware port names to device names.

    The input is scanned line by line for ``Hardware Port: <name>`` lines,
    each followed later by a ``Device: <device>`` line.

    Args:
        ports_string (bytes): The UTF-8 encoded listing, presumably the output
            of a command that lists the hardware ports.

    Returns:
        A dict from port name to device name. A ``Device:`` line that has no
        pending port name is stored under the key ``None``.
    """
    port_prefix = "Hardware Port:"
    device_prefix = "Device:"
    # The value starts one character (the space) after the prefix
    port_offset = len(port_prefix) + 1
    device_offset = len(device_prefix) + 1

    mapping = {}
    pending = None
    for line in ports_string.decode("utf-8").split("\n"):
        if line.startswith(port_prefix):
            pending = line.strip()[port_offset:]
            continue
        if line.startswith(device_prefix):
            mapping[pending] = line.strip()[device_offset:]
            pending = None
    return mapping
|
|
||||||
|
|
||||||
|
|
||||||
def get_num_nvidia_gpus():
    """Return the number of GPUs listed by ``nvidia-smi -L``.

    Every non-empty output line counts as one GPU, so an empty listing gives
    0 rather than 1.

    Raises:
        subprocess.CalledProcessError: If ``nvidia-smi`` exits with an error.
        FileNotFoundError: If ``nvidia-smi`` is not installed.
    """
    result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
    return sum(1 for line in result.stdout.splitlines() if line.strip())
|
|
||||||
|
|
||||||
|
|
||||||
def extract_rings(hosts, index):
    """Extract rings (cycles) of hosts from the thunderbolt connectivity.

    For every start node the longest cycle that can still be built from
    unused ports is searched repeatedly. A port is used by at most one ring.

    Args:
        hosts: The hosts; ``hosts[i].ports[j].connected_to`` is either
            ``None`` or a key of ``index``.
        index: Maps a ``connected_to`` value to a ``(node, port)`` pair.

    Returns:
        A list of rings. Each ring is a list of
        ``((node, port), (peer_node, peer_port))`` links.

    Raises:
        RuntimeError: If the first cycle found cannot be mapped to ports. If
            that happens after at least one ring was extracted the rings
            found so far are returned instead.
    """

    def usable_port(i, j, used_ports):
        return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None

    # Depth first search that yields every path from start_node that can be
    # closed back to start_node using usable ports.
    def dfs(start_node, node, path, visited, used_ports):
        path.append(node)
        visited.add(node)
        for j, p in enumerate(hosts[node].ports):
            if not usable_port(node, j, used_ports):
                continue
            next_node, _ = index[p.connected_to]
            if next_node == start_node:
                yield path[:]
            if next_node not in visited:
                yield from dfs(start_node, next_node, path, visited, used_ports)
        path.pop()
        visited.remove(node)

    # Concretize maps the found cycle to real thunderbolt ports. It also adds
    # those ports to the used set so next cycles can't use them again.
    def concretize(cycle, used_ports):
        concrete_path = []
        for n1, n2 in zip(cycle, cycle[1:] + cycle[:1]):
            for j, p in enumerate(hosts[n1].ports):
                if not usable_port(n1, j, used_ports):
                    continue
                n2_hat, nj = index[p.connected_to]
                if n2 == n2_hat:
                    concrete_path.append(((n1, j), (n2, nj)))
                    used_ports.add((n1, j))
                    used_ports.add((n2, nj))
                    break
            else:
                # No free port links n1 to n2. Raising here, instead of
                # inspecting the last link, also covers the case where no
                # link was added at all.
                raise RuntimeError("Couldn't concretize the cycle")
        return concrete_path

    # Normalize tries to ensure that the cycles have the same direction so we can
    # use them together. We achieve this by selecting the direction such that
    # the smallest rank hosts connect to larger rank hosts.
    def normalize(path):
        small_to_large = sum(1 for p in path if p[0][0] < p[1][0])
        if small_to_large > len(path) - small_to_large:
            return path
        else:
            return [(p[1], p[0]) for p in path]

    rings = []
    used_ports = set()
    for start_node in range(len(hosts)):
        while True:
            ring = []
            for r in dfs(start_node, start_node, [], set(), used_ports):
                if len(r) > len(ring):
                    ring = r
                    # Break early since we won't find a bigger ring no matter what
                    if len(ring) == len(hosts):
                        break
            if not ring:
                break
            try:
                rings.append(normalize(concretize(ring, used_ports)))
            except RuntimeError:
                if len(rings) > 0:
                    return rings
                raise

    return rings
|
|
||||||
|
|
||||||
|
|
||||||
def positive_number(x):
    """Convert ``x`` to an int and make sure it is strictly positive.

    Raises:
        ValueError: If ``x`` is not a valid integer or is not positive.
    """
    value = int(x)
    if value > 0:
        return value
    raise ValueError("Number should be positive")
|
|
||||||
|
|
||||||
|
|
||||||
def log(verbose, *args, **kwargs):
    """Print an [INFO] message in green, but only when ``verbose`` is set.

    The remaining arguments are handed to ``print``.
    """
    if verbose:
        print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def log_warning(*args, **kwargs):
    """Print a [WARN] message in yellow to stderr.

    The arguments are handed to ``print``; any ``file`` given is overridden.
    """
    print("\033[33m[WARN]", *args, "\033[0m", **{**kwargs, "file": sys.stderr})
|
|
||||||
|
|
||||||
|
|
||||||
def log_error(*args, **kwargs):
    """Print an [ERROR] message in red to stderr.

    The arguments are handed to ``print``; any ``file`` given is overridden.
    """
    print("\033[31m[ERROR]", *args, "\033[0m", **{**kwargs, "file": sys.stderr})
|
|
||||||
|
|
||||||
|
|
||||||
def parse_hostfile(parser, hostfile):
    """Read the hosts from a json hostfile.

    The file holds a list of objects, each with the hostname to ssh into and
    optionally the ips to communicate over when using the ring backend:

        [
            {"ssh": "hostname1", "ips": ["123.123.123.1"]},
            {"ssh": "hostname2", "ips": ["123.123.123.2"]},
            ...
            {"ssh": "hostnameN", "ips": ["123.123.123.N"]},
        ]

    Args:
        parser: Its ``error`` method is called if the file is missing or
            cannot be parsed.
        hostfile (str): The path to the json file containing the host
            information

    Returns:
        A list of ``Host`` ranked in the order they appear in the file.
    """
    hostfile = Path(hostfile)
    if not hostfile.exists():
        parser.error(f"Hostfile {str(hostfile)} doesn't exist")

    try:
        with open(hostfile) as f:
            return [
                Host(rank, entry["ssh"], entry.get("ips", []))
                for rank, entry in enumerate(json.load(f))
            ]
    except Exception as e:
        parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_hostlist(parser, hostlist, repeats):
    """Turn a comma separated list of hosts into a list of ``Host``.

    A host that is a valid IP address also becomes its own (only) ip, any
    other hostname gets no ips. Every host is repeated ``repeats`` times and
    the ranks are the positions in the resulting list.

    Args:
        parser: Unused; kept for symmetry with ``parse_hostfile``.
        hostlist (str): The comma separated hostnames or ips.
        repeats (int): How many times to repeat each host.

    Raises:
        ValueError: If one of the hostnames is empty.
    """
    hosts = []
    for h in hostlist.split(","):
        if h == "":
            raise ValueError("Hostname cannot be empty")
        try:
            ipaddress.ip_address(h)
            ips = [h]
        except ValueError:
            ips = []
        for _ in range(repeats):
            # The rank is the global position. Reusing the repeat counter
            # here would hand the same ranks to every host.
            hosts.append(Host(len(hosts), h, ips))
    return hosts
|
|
||||||
|
|
||||||
|
|
||||||
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
    """Return the source of a python script that prepares the environment
    and then replaces itself with ``command``.

    The script writes its PID to a temporary file and prints the path of
    that file as its first line of output.

    Args:
        rank: Value of ``MLX_RANK`` for the command.
        hostfile (str): Content of the ring hostfile; when empty none of the
            ring related variables are set.
        cwd: Directory to change to. If ``None`` the current directory is
            tried instead and a failure to change to it is ignored.
        env: List of ``KEY=VALUE`` strings to add to the environment.
        command: The command and its arguments as a list.
        verbose: Whether to set ``MLX_RING_VERBOSE``.
    """
    # Imports that are used throughout
    script = ""
    script += "import os\n"
    script += "import sys\n"
    script += "import tempfile\n"
    script += "from pathlib import Path\n"

    # Write the PID to a file so we can kill the process if needed
    script += "_, pidfile = tempfile.mkstemp() \n"
    script += "open(pidfile, 'w').write(str(os.getpid()))\n"
    script += "print(pidfile, flush=True)\n"

    # Change the working directory if one was requested. Otherwise attempt to
    # change to the current one but don't fail if it wasn't possible.
    d = cwd or os.getcwd()
    script += f"if Path({repr(d)}).exists():\n"
    script += f" os.chdir({repr(d)})\n"
    if cwd is not None:
        script += "else:\n"
        script += (
            f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
        )
        script += f" sys.exit(1)\n"

    # Add the environment variables that were given to us
    script += "env = dict(os.environ)\n"
    for e in env:
        key, *value = e.split("=", maxsplit=1)
        # The value ends up in the dict given to os.execve and never passes
        # through a shell, so it must not be shell quoted (the quotes would
        # become part of the value).
        value = value[0] if len(value) > 0 else ""
        if not all(c.isalnum() or c == "_" for c in key):
            log_warning(f"'{e}' is an invalid environment variable so it is ignored")
            continue
        script += f"env[{repr(key)}] = {repr(value)}\n"

    # Add the environment variables to enable the ring distributed backend
    if hostfile != "":
        script += "_, hostfile = tempfile.mkstemp()\n"
        script += "with open(hostfile, 'w') as f:\n"
        script += f" f.write({repr(hostfile)})\n"
        if verbose:
            script += "env['MLX_RING_VERBOSE'] = '1'\n"
        script += "env['MLX_HOSTFILE'] = hostfile\n"
        script += f"env['MLX_RANK'] = '{rank}'\n"
        script += "\n"

    # Replace the process with the script
    script += f"command = [{','.join(map(repr, command))}]\n"
    script += "os.execve(command[0], command, env)\n"

    return script
|
|
||||||
|
|
||||||
|
|
||||||
def launch_ring(parser, hosts, args, command):
|
|
||||||
stop = False
|
|
||||||
exit_codes = [None] * len(hosts)
|
|
||||||
|
|
||||||
def node_thread(rank, host, hostfile, input_queue):
|
|
||||||
is_local = host == "127.0.0.1"
|
|
||||||
script = make_monitor_script(
|
|
||||||
rank, hostfile, args.cwd, args.env, command, args.verbose
|
|
||||||
)
|
|
||||||
script_b64 = base64.b64encode(script.encode()).decode()
|
|
||||||
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
|
|
||||||
if not is_local:
|
|
||||||
cmd = f"ssh {host} '{cmd}'"
|
|
||||||
p = Popen(
|
|
||||||
cmd,
|
|
||||||
shell=True,
|
|
||||||
stdin=PIPE,
|
|
||||||
stdout=PIPE,
|
|
||||||
stderr=PIPE,
|
|
||||||
)
|
|
||||||
os.set_blocking(p.stdout.fileno(), False)
|
|
||||||
os.set_blocking(p.stderr.fileno(), False)
|
|
||||||
os.set_blocking(p.stdin.fileno(), False)
|
|
||||||
|
|
||||||
# Repeat the stdout and stderr to the local machine
|
|
||||||
to_read = [p.stdout.fileno(), p.stderr.fileno()]
|
|
||||||
to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()]
|
|
||||||
pidfile = ""
|
|
||||||
stdin_buffer = b""
|
|
||||||
stdout_buffer = b""
|
|
||||||
stderr_buffer = b""
|
|
||||||
while p.poll() is None:
|
|
||||||
try:
|
|
||||||
stdin_buffer += input_queue.get_nowait()
|
|
||||||
except QueueEmpty:
|
|
||||||
pass
|
|
||||||
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
|
|
||||||
for fd in rlist:
|
|
||||||
msg = os.read(fd, 8192).decode(errors="ignore")
|
|
||||||
|
|
||||||
# Fetch the PID file first if we haven't already
|
|
||||||
if pidfile == "":
|
|
||||||
pidfile, *msg = msg.split("\n", maxsplit=1)
|
|
||||||
msg = msg[0] if msg else ""
|
|
||||||
|
|
||||||
is_stdout = fd == p.stdout.fileno()
|
|
||||||
if is_stdout:
|
|
||||||
stdout_buffer += msg.encode()
|
|
||||||
else:
|
|
||||||
stderr_buffer += msg.encode()
|
|
||||||
for fd in wlist:
|
|
||||||
if fd == p.stdin.fileno() and len(stdin_buffer) > 0:
|
|
||||||
n = os.write(fd, stdin_buffer)
|
|
||||||
stdin_buffer = stdin_buffer[n:]
|
|
||||||
elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0:
|
|
||||||
n = os.write(fd, stdout_buffer)
|
|
||||||
stdout_buffer = stdout_buffer[n:]
|
|
||||||
elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0:
|
|
||||||
n = os.write(fd, stderr_buffer)
|
|
||||||
stderr_buffer = stderr_buffer[n:]
|
|
||||||
if stop:
|
|
||||||
p.terminate()
|
|
||||||
break
|
|
||||||
p.wait()
|
|
||||||
exit_codes[rank] = p.returncode
|
|
||||||
|
|
||||||
# Kill the remote program if possible
|
|
||||||
cmd = ""
|
|
||||||
cmd += f"pid=$(cat {pidfile}); "
|
|
||||||
cmd += "if ps -p $pid >/dev/null; then "
|
|
||||||
cmd += " kill $pid; "
|
|
||||||
cmd += " echo 1; "
|
|
||||||
cmd += "else "
|
|
||||||
cmd += " echo 0; "
|
|
||||||
cmd += "fi; "
|
|
||||||
cmd += f"rm {pidfile}"
|
|
||||||
if not is_local:
|
|
||||||
cmd = f"ssh {host} '{cmd}'"
|
|
||||||
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
|
|
||||||
if c.stdout.strip() == "1":
|
|
||||||
log_warning(f"Node with rank {rank} was killed")
|
|
||||||
elif p.returncode != 0:
|
|
||||||
log_warning(f"Node with rank {rank} exited with code {p.returncode}")
|
|
||||||
else:
|
|
||||||
log(args.verbose, f"Node with rank {rank} completed")
|
|
||||||
|
|
||||||
if all(len(h.ips) == 0 for h in hosts):
|
|
||||||
parser.error(
|
|
||||||
"The ring backend requires IPs to be provided instead of hostnames"
|
|
||||||
)
|
|
||||||
|
|
||||||
port = args.starting_port
|
|
||||||
ring_hosts = []
|
|
||||||
for h in hosts:
|
|
||||||
node = []
|
|
||||||
for ip in h.ips:
|
|
||||||
for i in range(args.connections_per_ip):
|
|
||||||
node.append(f"{ip}:{port}")
|
|
||||||
port += 1
|
|
||||||
ring_hosts.append(node)
|
|
||||||
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
|
|
||||||
|
|
||||||
log(args.verbose, "Running", shlex.join(command))
|
|
||||||
|
|
||||||
input_queues = []
|
|
||||||
threads = []
|
|
||||||
for i, h in enumerate(hosts):
|
|
||||||
if i + 1 == len(hosts):
|
|
||||||
time.sleep(1.0)
|
|
||||||
input_queues.append(Queue())
|
|
||||||
t = threading.Thread(
|
|
||||||
target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
|
|
||||||
)
|
|
||||||
t.start()
|
|
||||||
threads.append(t)
|
|
||||||
|
|
||||||
os.set_blocking(sys.stdin.fileno(), False)
|
|
||||||
while not stop:
|
|
||||||
rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
|
|
||||||
for fd in rlist:
|
|
||||||
stdin_buffer = os.read(fd, 8192)
|
|
||||||
for q in input_queues:
|
|
||||||
q.put(stdin_buffer)
|
|
||||||
if any(t.is_alive() for t in threads):
|
|
||||||
for i, t in enumerate(threads):
|
|
||||||
if not t.is_alive():
|
|
||||||
if exit_codes[i] != 0:
|
|
||||||
stop = True
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
for t in threads:
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
|
|
||||||
def get_mpi_libname():
|
|
||||||
try:
|
|
||||||
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
|
|
||||||
ompi_info = ompi_info.stdout.strip().decode()
|
|
||||||
|
|
||||||
if platform.system() == "Darwin":
|
|
||||||
otool_output = run(
|
|
||||||
["otool", "-L", ompi_info], check=True, capture_output=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
|
|
||||||
otool_output = otool_output.stdout.decode()
|
|
||||||
|
|
||||||
# StopIteration if not found
|
|
||||||
libmpi_line = next(
|
|
||||||
filter(lambda line: "libmpi" in line, otool_output.splitlines())
|
|
||||||
)
|
|
||||||
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
|
|
||||||
except:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def launch_mpi(parser, hosts, args, command):
|
|
||||||
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
|
|
||||||
mpirun = mpirun.stdout.strip().decode()
|
|
||||||
|
|
||||||
# Compatibility with homebrew and pip installs
|
|
||||||
mpi_libname = get_mpi_libname()
|
|
||||||
if mpi_libname is not None:
|
|
||||||
dyld = Path(mpirun).parent.parent / "lib"
|
|
||||||
args.env = [
|
|
||||||
f"DYLD_LIBRARY_PATH={str(dyld)}",
|
|
||||||
f"MLX_MPI_LIBNAME={mpi_libname}",
|
|
||||||
] + args.env
|
|
||||||
|
|
||||||
log(args.verbose, f"Using '{mpirun}'")
|
|
||||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
|
||||||
hosts = Counter((h.ssh_hostname for h in hosts))
|
|
||||||
for h, n in hosts.items():
|
|
||||||
print(f"{h} slots={n}", file=f)
|
|
||||||
f.flush()
|
|
||||||
|
|
||||||
cmd = [
|
|
||||||
mpirun,
|
|
||||||
"--output",
|
|
||||||
":raw", # do not line buffer output
|
|
||||||
"--hostfile",
|
|
||||||
f.name,
|
|
||||||
*(["-cwd", args.cwd] if args.cwd else []),
|
|
||||||
*sum((["-x", e] for e in args.env), []),
|
|
||||||
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
|
|
||||||
"--",
|
|
||||||
*command,
|
|
||||||
]
|
|
||||||
log(args.verbose, "Running", " ".join(cmd))
|
|
||||||
try:
|
|
||||||
run(cmd)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def launch_nccl(parser, hosts, args, command):
|
|
||||||
master_host = hosts[0].ips[0]
|
|
||||||
|
|
||||||
if master_host != "127.0.0.1":
|
|
||||||
raise ValueError("The NCCL backend only supports localhost for now.")
|
|
||||||
master_port = args.nccl_port
|
|
||||||
world_size = len(hosts)
|
|
||||||
|
|
||||||
base_env = os.environ.copy()
|
|
||||||
base_env.update(
|
|
||||||
{
|
|
||||||
"NCCL_DEBUG": base_env.get(
|
|
||||||
"NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
|
|
||||||
),
|
|
||||||
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
|
|
||||||
"NCCL_HOST_IP": master_host,
|
|
||||||
"NCCL_PORT": str(master_port),
|
|
||||||
"MLX_WORLD_SIZE": str(world_size),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
procs = []
|
|
||||||
num_gpus = get_num_nvidia_gpus()
|
|
||||||
if num_gpus == 0:
|
|
||||||
raise RuntimeError("Cannot run NCCL backend with no GPUs.")
|
|
||||||
if args.repeat_hosts > num_gpus:
|
|
||||||
raise RuntimeError("NCCL requires a separate GPU per process.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
for rank in range(world_size):
|
|
||||||
env = base_env.copy()
|
|
||||||
mlx_rank = str(rank % args.repeat_hosts)
|
|
||||||
env["MLX_RANK"] = mlx_rank
|
|
||||||
env["CUDA_VISIBLE_DEVICES"] = mlx_rank
|
|
||||||
p = Popen(command, env=env)
|
|
||||||
procs.append(p)
|
|
||||||
|
|
||||||
for p in procs:
|
|
||||||
ret = p.wait()
|
|
||||||
if ret != 0:
|
|
||||||
raise RuntimeError(f"Rank process exited with {ret}")
|
|
||||||
|
|
||||||
except (RuntimeError, KeyboardInterrupt) as err:
|
|
||||||
for p in procs:
|
|
||||||
if p.poll() is None:
|
|
||||||
try:
|
|
||||||
p.kill()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def check_ssh_connections(hosts):
|
|
||||||
results = [False] * len(hosts)
|
|
||||||
|
|
||||||
def _check(hostname, i):
|
|
||||||
result = run(
|
|
||||||
[
|
|
||||||
"ssh",
|
|
||||||
"-o",
|
|
||||||
"BatchMode=yes",
|
|
||||||
"-o",
|
|
||||||
"ConnectTimeout=5",
|
|
||||||
hostname,
|
|
||||||
"echo",
|
|
||||||
"success",
|
|
||||||
],
|
|
||||||
stdout=PIPE,
|
|
||||||
stderr=PIPE,
|
|
||||||
)
|
|
||||||
results[i] = result.returncode == 0
|
|
||||||
|
|
||||||
threads = [
|
|
||||||
threading.Thread(target=_check, args=(h.ssh_hostname, i))
|
|
||||||
for i, h in enumerate(hosts)
|
|
||||||
]
|
|
||||||
for t in threads:
|
|
||||||
t.start()
|
|
||||||
for t in threads:
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
if not all(results):
|
|
||||||
log_error("Could not ssh to the following hosts:")
|
|
||||||
for i, h in enumerate(hosts):
|
|
||||||
if not results[i]:
|
|
||||||
log_error(" - ", h.ssh_hostname)
|
|
||||||
log_error()
|
|
||||||
log_error("Maybe they are not set-up for password-less ssh?")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_tb_ring(args, hosts):
|
|
||||||
log(
|
|
||||||
args.verbose,
|
|
||||||
f"Preparing a thunderbolt ring for {', '.join(h.ssh_hostname for h in hosts)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that we can ssh
|
|
||||||
check_ssh_connections(hosts)
|
|
||||||
if args.auto_setup and args.verbose:
|
|
||||||
log_warning(
|
|
||||||
"--auto-setup is requested which requires password-less sudo",
|
|
||||||
"on the remote hosts",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract the current connectivity from the remote hosts
|
|
||||||
thunderbolt_connections = []
|
|
||||||
for h in hosts:
|
|
||||||
log(args.verbose, "Getting connectivity from", h.ssh_hostname)
|
|
||||||
thunderbolt_connections.append(
|
|
||||||
json.loads(
|
|
||||||
run(
|
|
||||||
[
|
|
||||||
"ssh",
|
|
||||||
h.ssh_hostname,
|
|
||||||
"system_profiler",
|
|
||||||
"SPThunderboltDataType",
|
|
||||||
"-json",
|
|
||||||
],
|
|
||||||
capture_output=True,
|
|
||||||
).stdout
|
|
||||||
)
|
|
||||||
)
|
|
||||||
interface_maps = []
|
|
||||||
for h in hosts:
|
|
||||||
log(args.verbose, "Getting interface names from", h.ssh_hostname)
|
|
||||||
interface_maps.append(
|
|
||||||
parse_hardware_ports(
|
|
||||||
run(
|
|
||||||
[
|
|
||||||
"ssh",
|
|
||||||
h.ssh_hostname,
|
|
||||||
"networksetup",
|
|
||||||
"-listallhardwareports",
|
|
||||||
],
|
|
||||||
capture_output=True,
|
|
||||||
).stdout
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Parse the connectivity into some simple dataclasses
|
|
||||||
tb_hosts = []
|
|
||||||
for c, iface_map in zip(thunderbolt_connections, interface_maps):
|
|
||||||
name = ""
|
|
||||||
ports = []
|
|
||||||
for t in c["SPThunderboltDataType"]:
|
|
||||||
uuid = t.get("domain_uuid_key")
|
|
||||||
if uuid is None:
|
|
||||||
continue
|
|
||||||
name = t["device_name_key"]
|
|
||||||
tag = t["receptacle_1_tag"]["receptacle_id_key"]
|
|
||||||
items = t.get("_items", [])
|
|
||||||
connected_items = [item for item in items if "domain_uuid_key" in item]
|
|
||||||
connected_to = (
|
|
||||||
connected_items[0]["domain_uuid_key"] if connected_items else None
|
|
||||||
)
|
|
||||||
iface = iface_map[f"Thunderbolt {tag}"]
|
|
||||||
ports.append(ThunderboltPort(iface, uuid, connected_to))
|
|
||||||
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
|
|
||||||
|
|
||||||
# Create a reverse index to be able to map uuids to (host, port) quickly
|
|
||||||
uuid_reverse_index = {}
|
|
||||||
for i, h in enumerate(tb_hosts):
|
|
||||||
for j, p in enumerate(h.ports):
|
|
||||||
uuid_reverse_index[p.uuid] = (i, j)
|
|
||||||
|
|
||||||
# Find the rings by simply walking and marking visited (host, port) tuples
|
|
||||||
# and keeping the largest rings greedily.
|
|
||||||
log(args.verbose, "Extracting rings from the parsed connectivity")
|
|
||||||
rings = extract_rings(tb_hosts, uuid_reverse_index)
|
|
||||||
|
|
||||||
# Just output a DOT graphical representation of the found rings
|
|
||||||
if args.dot:
|
|
||||||
names = []
|
|
||||||
for i in range(len(tb_hosts)):
|
|
||||||
n = ""
|
|
||||||
j = i
|
|
||||||
while True:
|
|
||||||
n += chr(97 + j % 26)
|
|
||||||
j //= 26
|
|
||||||
if j == 0:
|
|
||||||
break
|
|
||||||
names.append(n)
|
|
||||||
|
|
||||||
print("graph G {")
|
|
||||||
print(" node [shape=rectangle];")
|
|
||||||
for i, h in enumerate(hosts):
|
|
||||||
print(f' {names[i]} [label="{h.ssh_hostname}"];')
|
|
||||||
for r in rings:
|
|
||||||
for (i, _), (j, _) in r:
|
|
||||||
print(f" {names[i]} -- {names[j]};")
|
|
||||||
print("}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Assign IPs to each interface such that the interfaces can communicate
|
|
||||||
ips = {}
|
|
||||||
pairs = {}
|
|
||||||
expecting = set()
|
|
||||||
ip0 = 0
|
|
||||||
ip1 = 0
|
|
||||||
netmask = "255.255.255.252"
|
|
||||||
for r in rings:
|
|
||||||
for a, b in r:
|
|
||||||
ips[a] = f"192.168.{ip0}.{ip1 + 1}"
|
|
||||||
ips[b] = f"192.168.{ip0}.{ip1 + 2}"
|
|
||||||
pairs[a] = b
|
|
||||||
pairs[b] = a
|
|
||||||
expecting.add(b)
|
|
||||||
ip1 += 4
|
|
||||||
if ip1 > 255:
|
|
||||||
ip0 += 1
|
|
||||||
ip1 = 0
|
|
||||||
if ip0 > 255:
|
|
||||||
raise ValueError("Ran out of available local IPs for the ring")
|
|
||||||
|
|
||||||
# Extract the host order from the first ring
|
|
||||||
hostmap = dict((r[0][0], r[1][0]) for r in rings[0])
|
|
||||||
first_host = min(hostmap.keys())
|
|
||||||
order = [first_host]
|
|
||||||
while hostmap[order[-1]] != first_host:
|
|
||||||
order.append(hostmap[order[-1]])
|
|
||||||
|
|
||||||
# Create the hostfile
|
|
||||||
hostfile = []
|
|
||||||
for i in order:
|
|
||||||
h = hosts[i]
|
|
||||||
host = {
|
|
||||||
"ssh": h.ssh_hostname,
|
|
||||||
"ips": [
|
|
||||||
ips[i, j]
|
|
||||||
for j, p in enumerate(tb_hosts[i].ports)
|
|
||||||
if (i, j) in expecting
|
|
||||||
],
|
|
||||||
}
|
|
||||||
hostfile.append(host)
|
|
||||||
|
|
||||||
if not args.hostfile_only:
|
|
||||||
for i, h in enumerate(hosts):
|
|
||||||
command = ""
|
|
||||||
command += "sudo ifconfig bridge0 down\n"
|
|
||||||
for j, p in enumerate(tb_hosts[i].ports):
|
|
||||||
if (i, j) not in ips:
|
|
||||||
continue
|
|
||||||
iface = p.iface
|
|
||||||
ip = ips[i, j]
|
|
||||||
peer = ips[pairs[i, j]]
|
|
||||||
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
|
|
||||||
command += f"sudo route change {peer} -interface {iface}\n"
|
|
||||||
if args.auto_setup:
|
|
||||||
print(f"Running auto setup for {h.ssh_hostname}")
|
|
||||||
command = command.strip().replace("\n", " && ")
|
|
||||||
command = ["ssh", h.ssh_hostname, command]
|
|
||||||
log(args.verbose, shlex.join(command))
|
|
||||||
run(command)
|
|
||||||
else:
|
|
||||||
msg = f"Setup for {h.ssh_hostname}"
|
|
||||||
print(msg)
|
|
||||||
print("=" * len(msg))
|
|
||||||
print(command)
|
|
||||||
input("Enter to continue")
|
|
||||||
print()
|
|
||||||
|
|
||||||
if args.output_hostfile:
|
|
||||||
with open(args.output_hostfile, "w") as f:
|
|
||||||
json.dump(hostfile, f, indent=4)
|
|
||||||
else:
|
|
||||||
print("Hostfile")
|
|
||||||
print("========")
|
|
||||||
print(json.dumps(hostfile, indent=4))
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_hostfile(args, hosts):
|
|
||||||
log(
|
|
||||||
args.verbose,
|
|
||||||
f"Preparing an ethernet hostfile for {', '.join(h.ssh_hostname for h in hosts)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that we can ssh
|
|
||||||
check_ssh_connections(hosts)
|
|
||||||
|
|
||||||
# Get the ips for each host
|
|
||||||
for h in hosts:
|
|
||||||
log(args.verbose, "Getting the ip from", h.ssh_hostname)
|
|
||||||
h.ips.append(
|
|
||||||
run(
|
|
||||||
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
).stdout.strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
hostfile = []
|
|
||||||
for h in hosts:
|
|
||||||
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
|
|
||||||
|
|
||||||
if args.output_hostfile:
|
|
||||||
with open(args.output_hostfile, "w") as f:
|
|
||||||
json.dump(hostfile, f, indent=4)
|
|
||||||
else:
|
|
||||||
print("Hostfile")
|
|
||||||
print("========")
|
|
||||||
print(json.dumps(hostfile, indent=4))
|
|
||||||
|
|
||||||
|
|
||||||
def distributed_config():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Configure remote machines for use with MLX distributed"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--verbose", action="store_true", help="Print debug messages in stdout"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--backend",
|
|
||||||
choices=["ring", "mpi", "nccl"],
|
|
||||||
default="nccl" if mx.cuda.is_available() else "ring",
|
|
||||||
help="Which distributed backend to configure",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--over",
|
|
||||||
choices=["thunderbolt", "ethernet"],
|
|
||||||
default="thunderbolt",
|
|
||||||
help="What type of connectivity to configure",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
|
|
||||||
)
|
|
||||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
|
||||||
parser.add_argument(
|
|
||||||
"--dot", action="store_true", help="Output the topology in DOT format and exit"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--hostfile-only", action="store_true", help="If set only compute the hostfile"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-hostfile", help="If provided, save the hostfile to this path"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--auto-setup",
|
|
||||||
action="store_true",
|
|
||||||
help="If set we will attempt to automatically configure the machines via ssh",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.backend == "mpi" and args.over == "thunderbolt":
|
|
||||||
raise ValueError(
|
|
||||||
(
|
|
||||||
"The configuration of MPI over thunderbolt is "
|
|
||||||
"not supported yet by mlx.distributed_config"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.hostfile is not None:
|
|
||||||
hosts = parse_hostfile(parser, args.hostfile)
|
|
||||||
else:
|
|
||||||
hosts = parse_hostlist(parser, args.hosts, 1)
|
|
||||||
|
|
||||||
if args.over == "thunderbolt":
|
|
||||||
prepare_tb_ring(args, hosts)
|
|
||||||
else:
|
|
||||||
prepare_hostfile(args, hosts)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
|
|
||||||
parser.add_argument(
|
|
||||||
"--print-python",
|
|
||||||
action="store_true",
|
|
||||||
help="Print the path to the current python executable and exit",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--verbose", action="store_true", help="Print debug messages in stdout"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--repeat-hosts",
|
|
||||||
"-n",
|
|
||||||
type=positive_number,
|
|
||||||
default=1,
|
|
||||||
help="Repeat each host a given number of times",
|
|
||||||
)
|
|
||||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
|
||||||
parser.add_argument(
|
|
||||||
"--backend",
|
|
||||||
choices=["ring", "mpi", "nccl"],
|
|
||||||
default="nccl" if mx.cuda.is_available() else "ring",
|
|
||||||
help="Which distributed backend to launch",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--env",
|
|
||||||
action="append",
|
|
||||||
default=[],
|
|
||||||
help="Set environment variables for the jobs",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--mpi-arg",
|
|
||||||
action="append",
|
|
||||||
default=[],
|
|
||||||
help="Arguments to pass directly to mpirun",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--connections-per-ip",
|
|
||||||
default=1,
|
|
||||||
type=int,
|
|
||||||
help="How many connections per ip to use for the ring backend",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--starting-port",
|
|
||||||
"-p",
|
|
||||||
type=int,
|
|
||||||
default=5000,
|
|
||||||
help="For the ring backend listen on this port increasing by 1 per rank and IP",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--cwd", help="Set the working directory on each node to the provided one"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--nccl-port",
|
|
||||||
type=int,
|
|
||||||
default=12345,
|
|
||||||
help="The port to use for the NCCL communication (only for nccl backend)",
|
|
||||||
)
|
|
||||||
|
|
||||||
args, rest = parser.parse_known_args()
|
|
||||||
|
|
||||||
if args.print_python:
|
|
||||||
print(sys.executable)
|
|
||||||
return
|
|
||||||
|
|
||||||
if len(rest) == 0:
|
|
||||||
parser.error("No script is provided")
|
|
||||||
if rest[0] == "--":
|
|
||||||
rest.pop(0)
|
|
||||||
|
|
||||||
# Try to extract a list of hosts and corresponding ips
|
|
||||||
if args.hostfile is not None:
|
|
||||||
hosts = parse_hostfile(parser, args.hostfile)
|
|
||||||
else:
|
|
||||||
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
|
|
||||||
|
|
||||||
# Check if the script is a file and convert it to a full path
|
|
||||||
if (script := Path(rest[0])).exists():
|
|
||||||
rest[0:1] = [sys.executable, str(script.resolve())]
|
|
||||||
elif (command := shutil.which(rest[0])) is not None:
|
|
||||||
rest[0] = command
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid script or command {rest[0]}")
|
|
||||||
|
|
||||||
# Launch
|
|
||||||
if args.backend == "ring":
|
|
||||||
launch_ring(parser, hosts, args, rest)
|
|
||||||
if args.backend == "mpi":
|
|
||||||
launch_mpi(parser, hosts, args, rest)
|
|
||||||
if args.backend == "nccl":
|
|
||||||
launch_nccl(parser, hosts, args, rest)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
auditwheel repair dist/* \
|
auditwheel repair dist/* \
|
||||||
--plat manylinux_2_35_x86_64 \
|
--plat manylinux_2_35_${1} \
|
||||||
--exclude libcublas* \
|
--exclude libcublas* \
|
||||||
--exclude libnvrtc* \
|
--exclude libnvrtc* \
|
||||||
--exclude libcuda* \
|
--exclude libcuda* \
|
||||||
|
|||||||
@@ -52,9 +52,25 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"is_available",
|
"is_available",
|
||||||
&mx::distributed::is_available,
|
[](const std::string& backend) {
|
||||||
|
return mx::distributed::is_available(backend);
|
||||||
|
},
|
||||||
|
"backend"_a = "any",
|
||||||
|
nb::sig("def is_available(backend: str = 'any') -> bool"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Check if a communication backend is available.
|
Check if a communication backend is available.
|
||||||
|
|
||||||
|
Note, this function returns whether MLX has the capability of
|
||||||
|
instantiating that distributed backend not whether it is possible to
|
||||||
|
create a communication group. For that purpose one should use
|
||||||
|
``init(strict=True)``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend (str, optional): The name of the backend to check for availability.
|
||||||
|
It takes the same values as :func:`init()`. Default: ``"any"``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Whether the distributed backend is available.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
@@ -79,10 +95,10 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
in case ``mx.distributed.is_available()`` returns False otherwise
|
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||||
it throws a runtime error. Default: ``False``
|
it throws a runtime error. Default: ``False``
|
||||||
backend (str, optional): Which distributed backend to initialize.
|
backend (str, optional): Which distributed backend to initialize.
|
||||||
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
|
Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
|
||||||
available backends are tried and the first one that succeeds
|
set to ``any`` all available backends are tried and the first one
|
||||||
becomes the global group which will be returned in subsequent
|
that succeeds becomes the global group which will be returned in
|
||||||
calls. Default: ``any``
|
subsequent calls. Default: ``any``
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Group: The group representing all the launched processes.
|
Group: The group representing all the launched processes.
|
||||||
|
|||||||
@@ -210,6 +210,14 @@ class TestReduce(mlx_tests.MLXTestCase):
|
|||||||
ref = getattr(np, op)(np_arr, axis=axis)
|
ref = getattr(np, op)(np_arr, axis=axis)
|
||||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||||
|
|
||||||
|
def test_long_column(self):
|
||||||
|
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
|
||||||
|
b = mx.array(a)
|
||||||
|
|
||||||
|
c1 = a.sum(0)
|
||||||
|
c2 = b.sum(0)
|
||||||
|
self.assertTrue(np.all(c1 == c2))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner(failfast=True)
|
mlx_tests.MLXTestRunner(failfast=True)
|
||||||
|
|||||||
44
setup.py
44
setup.py
@@ -7,13 +7,21 @@ import re
|
|||||||
import subprocess
|
import subprocess
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from subprocess import run
|
|
||||||
|
|
||||||
from setuptools import Command, Extension, find_namespace_packages, setup
|
from setuptools import Command, Extension, find_namespace_packages, setup
|
||||||
from setuptools.command.bdist_wheel import bdist_wheel
|
from setuptools.command.bdist_wheel import bdist_wheel
|
||||||
from setuptools.command.build_ext import build_ext
|
from setuptools.command.build_ext import build_ext
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_toolkit_major_version():
|
||||||
|
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
|
||||||
|
text = out.decode()
|
||||||
|
m = re.search(r"release (\d+)", text)
|
||||||
|
if m:
|
||||||
|
return int(m.group(1))
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_version():
|
def get_version():
|
||||||
with open("mlx/version.h", "r") as fid:
|
with open("mlx/version.h", "r") as fid:
|
||||||
for l in fid:
|
for l in fid:
|
||||||
@@ -31,7 +39,7 @@ def get_version():
|
|||||||
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
|
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
|
||||||
if not pypi_release and not dev_release:
|
if not pypi_release and not dev_release:
|
||||||
git_hash = (
|
git_hash = (
|
||||||
run(
|
subprocess.run(
|
||||||
"git rev-parse --short HEAD".split(),
|
"git rev-parse --short HEAD".split(),
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
check=True,
|
check=True,
|
||||||
@@ -257,8 +265,8 @@ if __name__ == "__main__":
|
|||||||
}
|
}
|
||||||
entry_points = {
|
entry_points = {
|
||||||
"console_scripts": [
|
"console_scripts": [
|
||||||
"mlx.launch = mlx.distributed_run:main",
|
"mlx.launch = mlx._distributed_utils.launch:main",
|
||||||
"mlx.distributed_config = mlx.distributed_run:distributed_config",
|
"mlx.distributed_config = mlx._distributed_utils.config:main",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
install_requires = []
|
install_requires = []
|
||||||
@@ -284,7 +292,11 @@ if __name__ == "__main__":
|
|||||||
install_requires.append(
|
install_requires.append(
|
||||||
f'mlx-metal=={version}; platform_system == "Darwin"'
|
f'mlx-metal=={version}; platform_system == "Darwin"'
|
||||||
)
|
)
|
||||||
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
|
extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
|
||||||
|
for toolkit in [12, 13]:
|
||||||
|
extras[f"cuda{toolkit}"] = [
|
||||||
|
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
|
||||||
|
]
|
||||||
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
||||||
|
|
||||||
_setup(
|
_setup(
|
||||||
@@ -299,13 +311,25 @@ if __name__ == "__main__":
|
|||||||
if build_macos:
|
if build_macos:
|
||||||
name = "mlx-metal"
|
name = "mlx-metal"
|
||||||
elif build_cuda:
|
elif build_cuda:
|
||||||
name = "mlx-cuda"
|
toolkit = cuda_toolkit_major_version()
|
||||||
|
name = f"mlx-cuda-{toolkit}"
|
||||||
|
if toolkit == 12:
|
||||||
|
install_requires += [
|
||||||
|
"nvidia-cublas-cu12==12.9.*",
|
||||||
|
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||||
|
]
|
||||||
|
elif toolkit == 13:
|
||||||
|
install_requires += [
|
||||||
|
"nvidia-cublas-cu13",
|
||||||
|
"nvidia-cuda-nvrtc-cu13",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown toolkit {toolkit}")
|
||||||
install_requires += [
|
install_requires += [
|
||||||
"nvidia-cublas-cu12==12.9.*",
|
f"nvidia-cudnn-cu{toolkit}==9.*",
|
||||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
f"nvidia-nccl-cu{toolkit}",
|
||||||
"nvidia-cudnn-cu12==9.*",
|
|
||||||
"nvidia-nccl-cu12",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
name = "mlx-cpu"
|
name = "mlx-cpu"
|
||||||
_setup(
|
_setup(
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include <climits>
|
#include <climits>
|
||||||
|
|
||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
@@ -608,3 +607,24 @@ TEST_CASE("test make empty array") {
|
|||||||
CHECK_EQ(a.size(), 0);
|
CHECK_EQ(a.size(), 0);
|
||||||
CHECK_EQ(a.dtype(), bool_);
|
CHECK_EQ(a.dtype(), bool_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test make array from user buffer") {
|
||||||
|
int size = 4096;
|
||||||
|
std::vector<int> buffer(size, 0);
|
||||||
|
|
||||||
|
int count = 0;
|
||||||
|
auto deleter = [&count](void*) { count++; };
|
||||||
|
|
||||||
|
{
|
||||||
|
auto a = array(buffer.data(), Shape{size}, int32, deleter);
|
||||||
|
if (metal::is_available()) {
|
||||||
|
CHECK_EQ(buffer.data(), a.data<int>());
|
||||||
|
}
|
||||||
|
auto b = a + array(1);
|
||||||
|
eval(b);
|
||||||
|
auto expected = ones({4096});
|
||||||
|
CHECK(array_equal(b, expected).item<bool>());
|
||||||
|
}
|
||||||
|
// deleter should always get called
|
||||||
|
CHECK_EQ(count, 1);
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user