Compare commits

..

34 Commits

Author SHA1 Message Date
Angelos Katharopoulos
d2bc340df4 Update the mlx.launch and mlx.distributed_config docs 2025-12-12 17:03:38 -08:00
Angelos Katharopoulos
fabc947df4 Possible doc fix 2025-12-12 15:34:59 -08:00
Angelos Katharopoulos
5523087cfb Finish the distributed docs 2025-12-12 14:16:16 -08:00
Angelos Katharopoulos
2f939acefa Progress with the docs 2025-12-12 04:36:46 -08:00
Angelos Katharopoulos
3b416a2e36 Comments 2025-12-12 01:21:43 -08:00
Angelos Katharopoulos
753c6a4d0f Disable echo for interactive distributed 2025-12-11 17:47:42 -08:00
Angelos Katharopoulos
d3a754c8aa Fix config when more cables are connected 2025-12-10 02:14:29 -08:00
Angelos Katharopoulos
595a4ad206 Improve interactivity of mlx.launch 2025-12-10 00:54:27 -08:00
Angelos Katharopoulos
a4dc1fac6c Set the shell to bash explicitly 2025-12-09 15:17:09 -08:00
Angelos Katharopoulos
ebda161a86 Remove old joined script 2025-12-09 13:39:57 -08:00
Angelos Katharopoulos
fa31a4b295 Add more checks and improve errors 2025-12-09 13:36:17 -08:00
Angelos Katharopoulos
9d707ba3b5 Remove python from the launch script 2025-12-09 13:04:37 -08:00
Angelos Katharopoulos
405d30b6e5 Refactor distributed config 2025-12-09 05:58:44 -08:00
Angelos Katharopoulos
cd4b12ce1b Refactoring launcher 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
425043ccca Change the name to a fun pun 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
95d92af8a0 Add headers for gcc 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
bfdddd644b Expose per-backend availability in C++ and python 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
1216afdc91 Add a no_ibv 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
04e94d78bb Add empty sum_scatter 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
60d4e8b2a8 Add send/recv 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
c5745fddd2 Make sure that there is space for work completions 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
e937a8033f Add working reduce and semi-working all gather 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
4dfe02d7c6 Fix ring 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
5c2cff9329 Fix side channel initialization for more than 2 peers 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
325dab9559 All gather 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
67e454ab0a Initial working all reduce 2025-12-08 15:50:05 -08:00
Awni Hannun
27232db1ba [CUDA] Enable more graphs to be updatable (#2883)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-08 06:18:01 -08:00
Awni Hannun
a4b3bc969b Try not to fail when there should be memory available (#2869)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-07 06:11:00 -08:00
Awni Hannun
667c0f3bb9 [Metal] No copy array init (#2875)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 13:36:45 -08:00
Cheng
6245824d42 Make allocator::malloc throw on allocation failure (#2874)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 17:44:38 +09:00
Awni Hannun
39289ef025 [CUDA] Release build for cuda 13 (#2872) 2025-12-04 21:42:26 -08:00
Awni Hannun
aefc9bd3f6 [CUDA] Faster general copy (#2873) 2025-12-04 21:42:15 -08:00
Angelos Katharopoulos
997cfc7699 Add a 2-pass col reduce for CUDA (#2863)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-04 15:53:59 -08:00
Awni Hannun
1fa8dc5797 Do a PyPi release for cuda on arm (#2866) 2025-12-04 15:28:29 -08:00
34 changed files with 1606 additions and 1170 deletions

View File

@@ -1,6 +1,15 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
@@ -12,4 +21,4 @@ runs:
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}

View File

@@ -15,6 +15,7 @@ runs:
using: "composite"
steps:
- name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}

View File

@@ -128,7 +128,11 @@ jobs:
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
@@ -136,9 +140,11 @@ jobs:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
toolkit: ${{ matrix.toolkit }}
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:

View File

@@ -119,10 +119,6 @@ if(MLX_BUILD_METAL)
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

View File

@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install mlx[cuda]
pip install mlx[cuda12]
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux)
^^^^^^^^^^^^^^^^

View File

@@ -7,22 +7,29 @@ Distributed Communication
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support three different communication backends:
moment we support several different communication backends introduced below.
.. list-table::
:widths: 20 80
:header-rows: 1
* - Backend
- Description
* - :ref:`MPI <mpi_section>`
- A full featured and mature distributed communications library.
* - :ref:`RING <ring_section>`
- Ring all reduce and all gather over TCP sockets. Always available and
usually faster than MPI.
* - :ref:`JACCL <ring_section>`
- Low latency communication with RDMA over thunderbolt. Necessary for
things like tensor parallelism.
* - :ref:`NCCL <nccl_section>`
- The backend of choice for CUDA environments.
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets. It should be
faster for thunderbolt connections, but it also works over Ethernet.
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
.. note::
Some operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
@@ -85,7 +92,7 @@ Selecting Backend
^^^^^^^^^^^^^^^^^
You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
one of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
available backends. If they all fail then a singleton group is created.
.. note::
@@ -110,6 +117,8 @@ The following examples aim to clarify the backend initialization logic in MLX:
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init() # same as MPI because it was initialized first!
.. _training_example:
Training Example
----------------
@@ -192,16 +201,273 @@ almost identical to the example above:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
.. _ring_section:
Getting Started with Ring
-------------------------
The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
and so on and so forth. As a result :func:`send` and :func:`recv` with
arbitrary sender and receiver is not supported in the ring backend.
Defining a Ring
^^^^^^^^^^^^^^^
The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.
.. code:: json
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
]
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node, run the script which will listen for connections in each of the provided
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4`` and so on and so forth.
Thunderbolt Ring
^^^^^^^^^^^^^^^^
Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
utility as follows:
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
If you want to go through the process manually, the steps are as follows:
* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.
.. _jaccl_section:
Getting Started with RDMA over Thunderbolt
------------------------------------------
Starting from version 26.2 RDMA over thunderbolt is available in MacOS and
enables low-latency communication between Macs with thunderbolt 5. MLX provides
the JACCL backend that uses this functionality to achieve communication latency
an order of magnitude lower than the ring backend.
.. note::
The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective
Communication Library* and it is an obvious pun to Nvidia's NCCL but also
tribute to *Jack Beasley* who led the development of RDMA over Thunderbolt
at Apple.
Enabling RDMA
^^^^^^^^^^^^^
Until the feature matures, enabling RDMA over thunderbolt is slightly more
involved and **cannot** be done remotely even with sudo. In fact, it has to be
done in macOS recovery:
1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.
2. Open the Terminal by going to Utilities -> Terminal.
3. Run ``rdma_ctl enable``.
4. Reboot.
To verify that you have successfully enabled Thunderbolt RDMA you can run
``ibv_devices`` which should produce something like the following for an M3 Ultra.
.. code-block:: bash
~ % ibv_devices
device node GUID
------ ----------------
rdma_en2 8096a9d9edbaac05
rdma_en3 8196a9d9edbaac05
rdma_en5 8396a9d9edbaac05
rdma_en4 8296a9d9edbaac05
rdma_en6 8496a9d9edbaac05
rdma_en7 8596a9d9edbaac05
Defining a Mesh
^^^^^^^^^^^^^^^
The JACCL backend supports only fully connected topologies. Namely, there needs
to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
the following topology visualizations, the left one is valid because there is a
connection from any node to any other node, while for the one on the right M3
Ultra 1 is not connected to M3 Ultra 2.
.. raw:: html
<div style="display: flex; text-align: center; align-items: end; font-size: 80%;">
<div>
<img src="/_static/distributed/m3-ultra-mesh.png" alt="M3 Ultra thunderbolt mesh" style="width: 55%">
<p>Fully connected mesh of four M3 Ultra.</p>
</div>
<div>
<img src="/_static/distributed/m3-ultra-mesh-broken.png" alt="M3 Ultra broken thunderbolt mesh" style="width: 55%">
<p>Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).</p>
</div>
</div>
Similar to the ring backend, the easiest way to use JACCL with MLX is to write
a JSON hostfile that will be used by ``mlx.launch``. The hostfile needs to contain
- Hostnames to use for launching scripts via ssh
- An IP for rank 0 that is reachable by all nodes
- A list of rdma devices that connect each node to each other node
The following JSON defines the valid 4-node mesh from the image above.
.. code-block:: json
[
{
"ssh": "m3-ultra-1",
"ips": ["123.123.123.1"],
"rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]
},
{
"ssh": "m3-ultra-2",
"ips": [],
"rdma": ["rdma_en5", null, "rdma_en3", "rdma_en4"]
},
{
"ssh": "m3-ultra-3",
"ips": [],
"rdma": ["rdma_en4", "rdma_en3", null, "rdma_en5"]
},
{
"ssh": "m3-ultra-4",
"ips": [],
"rdma": ["rdma_en3", "rdma_en4", "rdma_en5", null]
}
]
Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
disabling the thunderbolt bridge is still required as well as setting up
isolated local networks for each thunderbolt connection.
All of the above can be done instead via ``mlx.distributed_config``. This helper
script will
- ssh into each node
- extract the thunderbolt connectivity
- check for a valid mesh
- provide the commands to configure each node (or run them if sudo is available)
- generate the hostfile to be used with ``mlx.launch``
Putting it All Together
^^^^^^^^^^^^^^^^^^^^^^^^
For example launching a distributed MLX script that uses JACCL is fairly simple
if the nodes are reachable via ssh and have password-less sudo.
First, connect all the thunderbolt cables. Then we can verify the connections
by using the ``mlx.distributed_config`` script to visualize them.
.. code-block::
mlx.distributed_config --verbose \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
--over thunderbolt --dot | dot -Tpng | open -f -a Preview
After making sure that everything looks right we can auto-configure the nodes
and save the hostfile to ``m3-ultra-jaccl.json`` by running:
.. code-block::
mlx.distributed_config --verbose \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
--over thunderbolt --backend jaccl \
--auto-setup --output m3-ultra-jaccl.json
And now we are ready to run a distributed MLX script such as distributed inference
of a gigantic model using MLX-LM.
.. code-block::
mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
--env MLX_METAL_FAST_SYNCH=1 -- \ # <--- important
/path/to/remote/python -m mlx_lm chat --model mlx-community/DeepSeek-V3.2-8bit --shard
.. note::
Defining the environment variable ``MLX_METAL_FAST_SYNCH=1`` enables a
different, faster way of synchronizing between the GPU and the CPU. It is
not specific to the JACCL backend and can be used in all cases where the CPU
and GPU need to collaborate for some computation and is pretty critical for
low-latency communication since the communication is done by the CPU.
.. _nccl_section:
Getting Started with NCCL
-------------------------
MLX on CUDA environments ships with the ability to talk to `NCCL
<https://developer.nvidia.com/nccl>`_ which is a high-performance collective
communication library that supports both multi-gpu and multi-node setups.
For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all
it takes to run a distributed job is
.. code-block::
mlx.launch -n 8 test.py
# perfect for interactive scripts
mlx.launch -n 8 python -m mlx_lm chat --model my-model --shard
You can also use ``mlx.launch`` to ssh to a remote node and launch a script
with the same ease
.. code-block::
mlx.launch --hosts my-cuda-node -n 8 test.py
In many cases you may not want to use ``mlx.launch`` with the NCCL backend
because the cluster scheduler will be the one launching the processes. You can
:ref:`see which environment variables need to be defined <no_mlx_launch>` in
order for the MLX NCCL backend to be initialized correctly.
.. _mpi_section:
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.
MLX already comes with the ability to "talk" to `MPI
<https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ if it is installed
on the machine. Launching distributed MLX programs that use MPI can be done
with ``mpirun`` as expected. However, in the following examples we will be
using ``mlx.launch --backend mpi`` which takes care of some nuisances such as
setting absolute paths for the ``mpirun`` executable and the ``libmpi.dyld``
shared library.
The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:
@@ -269,78 +535,116 @@ Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
Getting Started with Ring
-------------------------
.. _no_mlx_launch:
The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
and so on and so forth. As a result :func:`send` and :func:`recv` with
arbitrary sender and receiver is not supported in the ring backend.
Distributed Without ``mlx.launch``
----------------------------------
Defining a Ring
^^^^^^^^^^^^^^^
None of the implementations of the distributed backends require launching with
``mlx.launch``. The script simply connects to each host. Starts a process per
rank and sets up the necessary environment variables before delegating to your
MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
for more details.
The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.
For many use-cases this will be the easiest way to perform distributed
computations in MLX. However, there may be reasons that you cannot or should
not use ``mlx.launch``. A common such case is the use of a scheduler that
starts all the processes for you on machines undetermined at the time of
scheduling the job.
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.
Below we list the environment variables required to use each backend.
.. code:: json
Ring
^^^^^^
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
]
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node, run the script which will listen for connections in each of the provided
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4`` and so on and so forth.
**MLX_HOSTFILE** should contain the path to a json file that contains IPs and
ports for each rank to listen to, something like the following:
Thunderbolt Ring
^^^^^^^^^^^^^^^^
.. code-block:: json
Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
[
["123.123.1.1:5000", "123.123.1.2:5000"],
["123.123.2.1:5000", "123.123.2.2:5000"],
["123.123.3.1:5000", "123.123.3.2:5000"],
["123.123.4.1:5000", "123.123.4.2:5000"]
]
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
utility as follows:
**MLX_RING_VERBOSE** is optional and if set to 1 it enables some more logging
from the distributed backend.
.. code:: shell
JACCL
^^^^^
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 can listen
to all the other ranks connect to in order to establish the RDMA connections.
To validate your connection without configuring anything
``mlx.distributed_config`` can also plot the ring using DOT format.
**MLX_IBV_DEVICES** should contain the path to a json file that contains the
ibverbs device names that connect each node to each other node, something like
the following:
.. code:: shell
.. code-block:: json
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png
[
[null, "rdma_en5", "rdma_en4", "rdma_en3"],
["rdma_en5", null, "rdma_en3", "rdma_en4"],
["rdma_en4", "rdma_en3", null, "rdma_en5"],
["rdma_en3", "rdma_en4", "rdma_en5", null]
]
If you want to go through the process manually, the steps are as follows:
* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.
NCCL
^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_WORLD_SIZE** should contain the total number of processes that will be
launched.
**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
hosts can connect to to establish the NCCL communication.
**CUDA_VISIBLE_DEVICES** should contain the local index of the gpu that
corresponds to this process.
Of course any `other environment variable
<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is
used by NCCL can be set.
.. _tips_and_tricks:
Tips and Tricks
----------------
This is a small collection of tips to help you utilize better the distributed
communication capabilities of MLX.
- *Test locally first.*
You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a small
scale test on a single node first.
- *Batch your communication.*
As described in the :ref:`training example <training_example>`, performing a
lot of small communication can hurt performance. Copy the approach of
:func:`mlx.nn.average_gradients` to gather many small communications in a
single large one.
- *Visualize the connectivity.*
Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
visualize the connnections and make sure that the cables are connected
correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
- *Use the debugger.*
``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
processes and gathers stdout from all processes. This makes using ``pdb`` a
breeze.

View File

@@ -7,13 +7,106 @@ Launching Distributed Programs
.. currentmodule:: mlx.core.distributed
Installing the MLX python package provides a helper script ``mlx.launch`` that
can be used to run python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.
Installing the MLX python package provides two utilities to help you configure
your Macs for distributed computation and also launch distributed programs on
multiple nodes or with many processes in a single node. These utilities are aptly named
Usage
-----
- ``mlx.launch``
- ``mlx.distributed_config``
See the :doc:`distributed docs <distributed>` for an introduction and
getting-started guides to the various backends.
``mlx.distributed_config``
---------------------------
Unless you are launching distributed jobs locally for development or multi-gpu
CUDA environments, then you have several Macs that you need to configure for
distributed communication with MLX.
``mlx.distributed_config`` aims to automate the process of configuring the
network interfaces (especially for communication over thunderbolt) and also
creating the hostfile to be used with ``mlx.launch``.
We will analyse 3 cases of using ``mlx.distributed_config``
1. RDMA over thunderbolt using JACCL
2. TCP/IP over thunderbolt using the ring backend
3. TCP/IP over ethernet using the ring backend
JACCL
^^^^^^^
After following :ref:`the steps to enable RDMA <jaccl_section>` you can run the
following command to configure the nodes and create the hostfile.
.. code-block::
mlx.distributed_config --verbose --backend jaccl \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 --over thunderbolt \
--auto-setup --output m3-ultra-jaccl.json
Let's walk through the steps that the script takes to configure the nodes.
1. Ssh to all nodes to verify that they are reachable
2. Extract the thunderbolt connectivity. Namely run commands on each node to
calculate which node is connected to which other node.
3. Verify that we have a valid fully connected mesh
4. Check that RDMA is enabled
5. Extract the ethernet IP from interface en0
6. Disable the thunderbolt bridge and set up peer to peer networks for each
thunderbolt cable
7. Write the hostfile
Knowing the above steps allows you to manually configure the nodes but also
debug any configuration issue. For instance changing the Ethernet IP to a
different interface directly in the config is possible (as long as it is
reachable from all nodes).
The ``--auto-setup`` argument requires password-less sudo on each node. If it
isn't available then the configuration script will print commands to be run on
each node.
Ring over thunderbolt
^^^^^^^^^^^^^^^^^^^^^
Setting up a ring backend over thunderbolt only requires changing the
``--backend`` from ``jaccl`` to ``ring``.
The steps are very similar with the main difference being that instead of
verifying that the nodes are fully connected, the script attempts to identify a
ring topology (or multiple rings).
Ring over Ethernet
^^^^^^^^^^^^^^^^^^
Configuring the ring backend over ethernet doesn't require setting up network
interface and as such it simply extracts the ``en0`` IP from each node and
writes the hostfile.
Debugging cable connections
^^^^^^^^^^^^^^^^^^^^^^^^^^^
``mlx.distributed_config`` can help you debug the connectivity of your nodes
over thunderbolt by exporting a graph of the connections.
Running
.. code-block::
mlx.distributed_config --verbose \
--hosts host1,host2,host3,host4 \
--over thunderbolt --dot
will export a `GraphViz <https://graphviz.org>`_ representation of the
connections between the nodes which makes it very easy to figure out which
cable is not connected correctly.
See :ref:`the JACCL section <jaccl_section>` for an example.
``mlx.launch``
--------------
The minimal usage example of ``mlx.launch`` is simply
@@ -33,6 +126,10 @@ the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the output of each remote process to stdout
and stderr respectively.
Importantly, it also broadcasts stdin to each process which enables interactive
programs to work in distributed mode as well as debugging using the interactive
debugger.
Providing Hosts
^^^^^^^^^^^^^^^^
@@ -63,10 +160,62 @@ host and on the same path. A good checklist to debug errors is the following:
``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path
If you are launching from a node with a completely different setup than the
nodes that the program will run on, you can specify ``--no-verify-script`` so
that ``mlx.launch`` does not attempt to verify that the executable and script
exist locally before launching the distributed job.
.. _ring_specifics:
Ring Specifics
^^^^^^^^^^^^^^
The :ref:`ring <ring_section>` backend, which is also the default
backend, can be explicitly selected with the argument ``--backend ring``. The
ring backend has some specific requirements and arguments that are different to
other backends:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.
.. _jaccl_specifics:
JACCL Specifics
^^^^^^^^^^^^^^^^
The :ref:`JACCL <jaccl_section>` backend can be selected with the argument
``--backend jaccl``. A hostfile is necessary to launch with this backend
because it needs to contain the RDMA devices connecting each node to each other
node.
NCCL Specifics
^^^^^^^^^^^^^^
The :ref:`NCCL <nccl_section>` backend is the default backend for CUDA
environments. When launching from a Mac to a Linux machine with CUDA then the
backend should be selected using ``--backend nccl``.
The ``--repeat-hosts, -n`` argument should be used to launch multi-node and
multi-gpu jobs. For instance
.. code-block::
mlx.launch --backend nccl --hosts linux-1,linux-2 -n 8 --no-verify-script -- ./my-job.sh
will attempt to launch 16 processes, 8 on each node that will all run
``my-job.sh``.
.. _mpi_specifics:
MPI Specifics
-------------
^^^^^^^^^^^^^
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
@@ -83,23 +232,3 @@ to choose a specific interface for the byte-transfer-layer of MPI we can call
.. code:: shell
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
.. _ring_specifics:
Ring Specifics
--------------
The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that are different to MPI:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.

View File

@@ -1,7 +1,6 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp

View File

@@ -1,24 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -28,16 +28,16 @@ class Buffer {
};
};
Buffer malloc(size_t size);
void free(Buffer buffer);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
virtual Buffer make_buffer(void* ptr, size_t size) {
return Buffer{nullptr};
};
virtual void release(Buffer buffer) {}
Allocator() = default;
Allocator(const Allocator& other) = delete;
@@ -49,4 +49,25 @@ class Allocator {
Allocator& allocator();
inline Buffer malloc(size_t size) {
return allocator().malloc(size);
}
inline void free(Buffer buffer) {
allocator().free(buffer);
}
// Make a Buffer from a raw pointer of the given size without a copy. If a
// no-copy conversion is not possible then the returned buffer.ptr() will be
// nullptr. Any buffer created with this function must be released with
// release(buffer)
inline Buffer make_buffer(void* ptr, size_t size) {
return allocator().make_buffer(ptr, size);
};
// Release a buffer from the allocator made with make_buffer
inline void release(Buffer buffer) {
allocator().release(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -82,6 +82,28 @@ array::array(std::initializer_list<int> data, Dtype dtype)
init(data.begin());
}
array::array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
auto buffer = allocator::make_buffer(data, nbytes());
if (buffer.ptr() == nullptr) {
set_data(allocator::malloc(nbytes()));
auto ptr = static_cast<char*>(data);
std::copy(ptr, ptr + nbytes(), this->data<char>());
deleter(data);
} else {
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
auto ptr = buffer.ptr();
allocator::release(buffer);
return deleter(ptr);
};
set_data(buffer, std::move(wrapped_deleter));
}
}
/* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {

View File

@@ -57,6 +57,16 @@ class array {
Shape shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a raw pointer. The constructor will attempt to use the
* input data without a copy. The deleter will be called when the array no
* longer needs the underlying memory - after the array is destroyed in the
* no-copy case and after the copy otherwise. */
explicit array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter);
/* Build an array from a buffer */
explicit array(
allocator::Buffer data,

View File

@@ -20,6 +20,19 @@ constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
#if CUDART_VERSION >= 13000
inline cudaMemLocation cuda_mem_loc(int i) {
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
return loc;
}
#else
inline int cuda_mem_loc(int i) {
return i;
}
#endif // CUDART_VERSION >= 13000
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
@@ -35,13 +48,7 @@ SmallSizePool::SmallSizePool() {
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
auto loc = cuda_mem_loc(i);
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
@@ -90,9 +97,10 @@ CudaAllocator::CudaAllocator()
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.9;
size_t free;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
memory_limit_ = total_memory_ * 0.95;
free_limit_ = total_memory_ - memory_limit_;
max_pool_size_ = memory_limit_;
int device_count = 0;
@@ -104,6 +112,10 @@ CudaAllocator::CudaAllocator()
cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s);
cudaMemPool_t mem_pool;
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
mem_pools_.push_back(mem_pool);
}
CHECK_CUDA_ERROR(cudaSetDevice(curr));
}
@@ -154,23 +166,35 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
}
lock.unlock();
if (!buf) {
cudaError_t err;
void* data = nullptr;
if (device == -1) {
err = cudaMallocManaged(&data, size);
CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
} else {
err = cudaMallocAsync(&data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
}
if (!data) {
return Buffer{nullptr};
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
buf = new CudaBuffer{data, size, device};
}
lock.lock();
// If any cuda memory pool has too much reserved memory, clear some
// memory from the cache. This prevents graph / kernel execution failing
// from OOM
if (get_cache_memory() > 0) {
for (auto p : mem_pools_) {
size_t used = 0;
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
p, cudaMemPoolAttrReservedMemCurrent, &used));
if (used > (total_memory_ - free_limit_)) {
buffer_cache_.release_cached_buffers(free_limit_);
break;
}
}
}
}
active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_);

View File

@@ -71,11 +71,14 @@ class CudaAllocator : public allocator::Allocator {
std::mutex mutex_;
size_t memory_limit_;
size_t free_limit_;
size_t total_memory_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_;
std::vector<cudaMemPool_t> mem_pools_;
SmallSizePool scalar_pool_;
};

View File

@@ -95,11 +95,14 @@ void copy_general_input(
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
int work_per_thread = 1;
int work_per_thread = 8;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
if (dim0 >= 4 && dim0 < 8) {
work_per_thread = 4;
} else if (dim0 < 4) {
work_per_thread = 1;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
@@ -110,7 +113,10 @@ void copy_general_input(
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) {
if (work_per_thread == 8) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
} else if (work_per_thread == 4) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
}
@@ -127,7 +133,9 @@ void copy_general_input(
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
if (work_per_thread == 8) {
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
} else if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(

View File

@@ -318,46 +318,52 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
insert_graph_dependencies(GraphNode{node, "K"});
}
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
// CUDA graphs do not get updated correctly if a kernel node getting updated
// has a different cluster shape than the node it's being updated with.
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
// Constructs a key representing the nodes of a sub-graph.
// Also checks if the sub-graph is updatable as CUDA graphs do not get
// updated correctly if a kernel node getting updated has a different cluster
// shape than the node it's being updated with.
std::string key = "(";
size_t num_nodes = 0;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) {
return true;
return {key + ")", true};
}
bool is_updatable = true;
std::vector<cudaGraphNode_t> nodes(num_nodes);
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
for (const auto& node : nodes) {
if (!is_updatable) {
break;
}
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeGraph) {
// Try to be updatable for a structure like graph -> graph -> kernel
if (num_nodes > 1) {
return false;
}
cudaGraph_t child;
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
return is_graph_updatable(child, cluster_dim_x);
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
is_updatable &= sub_is_updatable;
key += subkey;
} else if (type == cudaGraphNodeTypeMemset) {
key += "M";
} else if (type != cudaGraphNodeTypeKernel) {
return false;
is_updatable = false;
} else {
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only dim.x can be greater than 1
// Only allow dim.x to be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
return false;
is_updatable = false;
} else {
key += "K";
key += std::to_string(cluster_dim.clusterDim.x);
}
// Only one child node allowed when subgraph uses clusters
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
}
cluster_dim_x = cluster_dim.clusterDim.x;
}
}
return true;
key += ")";
return {key, is_updatable};
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -370,11 +376,10 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return;
}
cudaGraphNode_t node;
int cluster_dim_x = 0;
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
is_graph_updatable_ &= is_updatable;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
insert_graph_dependencies(GraphNode{node, sub_graph_key});
}
bool CommandEncoder::needs_commit() {

View File

@@ -106,7 +106,7 @@ class CommandEncoder {
cudaGraphNode_t node;
// K = kernel
// E = empty
// G* = subgraph (with metadata)
// () = subgraph (with metadata)
// Symbols ':', '-' are reserved as separators
std::string node_type;
std::string id;

View File

@@ -89,9 +89,13 @@ template <
int NDIM,
int BM,
int BN,
int N_READS = 4>
__global__ void
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
int N_READS = 4,
int BLOCKS = 1>
__global__ void col_reduce_looped(
T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
int64_t out_size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
size_t tile_out = tile_y / out_size;
tile_y = tile_y % out_size;
// Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row;
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
totals[i] = ReduceInit<Op, T>::value();
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size;
size_t per_block, start, end;
if constexpr (BLOCKS > 1) {
per_block = (total + BLOCKS - 1) / BLOCKS;
start = tile_out * per_block + thread_y;
end = min((tile_out + 1) * per_block, total);
} else {
per_block = total;
start = thread_y;
end = total;
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
if (tile_x * BN + BN <= args.reduction_stride) {
if (args.reduction_stride % N_READS == 0) {
for (size_t r = thread_y; r < total; r += BM) {
for (size_t r = start; r < end; r += BM) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
for (size_t r = start; r < end; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
}
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
for (size_t r = start; r < end; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(
thread_x,
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
// Write result.
if (warp.thread_rank() == 0) {
if (BLOCKS > 1) {
out += tile_out * out_size * args.reduction_stride;
}
cub::StoreDirectBlocked(
warp.meta_group_rank(),
out + tile_y * args.reduction_stride + tile_x * BN,
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
inline auto output_grid_for_col_reduce(
const array& out,
const cu::ColReduceArgs& args,
int bn) {
int bn,
int outer = 1) {
int gx, gy = 1;
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
size_t n_outer_blocks = out.size() / args.reduction_stride;
size_t n_blocks = n_outer_blocks * n_inner_blocks;
size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
while (n_blocks / gy > INT32_MAX) {
gy *= 2;
}
@@ -277,7 +298,8 @@ void col_reduce_looped(
0,
indata,
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args));
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
@@ -320,6 +342,117 @@ void col_reduce_small(
});
}
void col_reduce_two_pass(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes, encoder);
// Allocate an intermediate array to hold the 1st pass result
constexpr int outer = 32;
Shape intermediate_shape;
intermediate_shape.push_back(outer);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
Strides intermediate_strides;
intermediate_strides.push_back(out.size());
intermediate_strides.insert(
intermediate_strides.end(), out.strides().begin(), out.strides().end());
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
auto [data_size, rc, cc] =
check_contiguity(intermediate_shape, intermediate_strides);
auto fl = out.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder),
data_size,
intermediate_strides,
fl,
allocator::free);
encoder.add_temporary(intermediate);
encoder.set_input_array(in);
encoder.set_output_array(intermediate);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
int blocks = BM * BN / N_READS;
auto kernel = cu::
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
indata,
gpu_ptr<U>(intermediate),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
// Prepare the reduction arguments for the 2nd pass
cu::ColReduceArgs second_args = args;
second_args.reduction_size = outer;
second_args.reduction_stride = out.size();
second_args.ndim = 0;
second_args.reduce_shape[0] = outer;
second_args.reduce_strides[0] = out.size();
second_args.reduce_ndim = 1;
second_args.non_col_reductions = 1;
encoder.set_input_array(intermediate);
encoder.set_output_array(out);
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
int blocks = BM * BN / N_READS;
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
gpu_ptr<T>(intermediate),
gpu_ptr<U>(out),
second_args,
second_args.reduction_stride);
});
});
});
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
@@ -334,6 +467,18 @@ void col_reduce(
// It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements.
//
// - col_reduce_small
//
// It is a column reduce for small columns. Each thread loops over the whole
// column without communicating with any other thread.
//
// - col_reduce_two_pass
//
// It is a reduce for long columns. To increase parallelism, we split the
// reduction in two passes. First we do a column reduce where many
// threadblocks operate on different parts of the reduced axis. Then we
// perform a final column reduce.
//
// Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend).
//
@@ -349,6 +494,14 @@ void col_reduce(
return;
}
// Long column with smallish row
size_t total_sums = args.non_col_reductions * args.reduction_size;
size_t approx_threads = out.size();
if (total_sums / approx_threads > 32) {
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}

View File

@@ -80,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) {
}
void CudaGraph::end_capture(cudaStream_t stream) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
}

View File

@@ -7,8 +7,6 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}

View File

@@ -149,7 +149,9 @@ Buffer MetalAllocator::malloc(size_t size) {
buf = device_->newBuffer(size, resource_options);
}
if (!buf) {
return Buffer{nullptr};
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
lk.lock();
num_resources_++;
@@ -201,6 +203,32 @@ size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
}
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
if (!buf) {
return Buffer{nullptr};
}
std::unique_lock lk(mutex_);
residency_set_.insert(buf);
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
num_resources_++;
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::release(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_);
active_memory_ -= buf->length();
num_resources_--;
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
}
MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This

View File

@@ -21,6 +21,9 @@ class MetalAllocator : public allocator::Allocator {
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
virtual Buffer make_buffer(void* ptr, size_t size) override;
virtual void release(Buffer buffer) override;
size_t get_active_memory() {
return active_memory_;
};

View File

@@ -25,6 +25,7 @@ class CommonAllocator : public Allocator {
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() const {
return active_memory_;
};

View File

@@ -1,5 +1,6 @@
# Copyright © 2025 Apple Inc.
import argparse
import ipaddress
import json
import sys
@@ -16,6 +17,14 @@ class Host:
rdma: list[Optional[str]]
class OptionalBoolAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if option_string.startswith("--no-"):
setattr(namespace, self.dest, False)
else:
setattr(namespace, self.dest, True)
def positive_number(x):
x = int(x)
if x <= 0:
@@ -26,6 +35,7 @@ def positive_number(x):
def log(verbose, *args, **kwargs):
if not verbose:
return
kwargs["file"] = sys.stderr
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
@@ -50,7 +60,7 @@ def parse_hostlist(parser, hostlist, repeats):
except ValueError:
ips = []
for i in range(repeats):
hosts.append(Host(i, h, ips))
hosts.append(Host(i, h, ips, []))
return hosts

View File

@@ -0,0 +1,570 @@
# Copyright © 2025 Apple Inc.
import argparse
import json
import shlex
import sys
import threading
from collections import defaultdict
from dataclasses import dataclass
from subprocess import DEVNULL, run
from typing import Optional
import mlx.core as mx
from .common import (
Host,
OptionalBoolAction,
log,
log_error,
parse_hostfile,
parse_hostlist,
)
@dataclass
class SSHInfo:
can_ssh: bool
has_sudo: bool
def __bool__(self):
return self.can_ssh
@dataclass
class ThunderboltPort:
iface: str
uuid: str
connected_to: Optional[str]
@dataclass
class ThunderboltHost:
name: str
ports: list[ThunderboltPort]
def add_ethernet_ips(hosts, verbose=False):
# Get the ips for each host
for h in hosts:
log(verbose, "Getting the ip from", h.ssh_hostname)
h.ips.append(
run(
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
capture_output=True,
text=True,
).stdout.strip()
)
def check_rdma(hosts, verbose=False):
# Check whether the hosts are capable of RDMA over thunderbolt
warn = False
for h in hosts:
log(verbose, "Checking that", h.ssh_hostname, "supports RDMA")
rdma_devs = (
run(["ssh", h.ssh_hostname, "ibv_devices"], capture_output=True, text=True)
.stdout.strip()
.split()
)
rdma_devs = [d for d in rdma_devs if d.startswith("rdma_")]
if not rdma_devs:
log_warning(h.ssh_hostname, "does not seem to have RDMA enabled")
warn = True
if warn:
log_warning()
log_warning(
"Some of the hosts don't have RDMA enabled or they don't support RDMA."
)
log_warning()
log_warning(
"See https://ml-explore.github.io/mlx/build/html/usage/distributed.html"
)
log_warning("for instructions on how to enable RDMA.")
def can_auto_setup(hosts, sshinfo, auto_setup=False):
has_sudo = all(info.has_sudo for info in sshinfo)
if not has_sudo and auto_setup:
log_warning(
"Automatic setup requested but the following hosts do not have passwordless sudo"
)
for h, i in zip(hosts, sshinfo):
if not i.has_sudo:
log_warning(" - ", h.ssh_hostname)
return has_sudo
class IPConfigurator:
def __init__(self, hosts, tb_hosts, uuid_reverse_index):
assigned = set()
ips = defaultdict(list)
ip0 = 0
ip1 = 0
for src_node, h in enumerate(tb_hosts):
for src_port, p in enumerate(h.ports):
if not p.connected_to:
continue
if p.connected_to not in uuid_reverse_index:
continue
if (src_node, src_port) in assigned:
continue
dst_node, dst_port = uuid_reverse_index[p.connected_to]
ip_src = f"192.168.{ip0}.{ip1 + 1}"
ip_dst = f"192.168.{ip0}.{ip1 + 2}"
iface_src = p.iface
iface_dst = tb_hosts[dst_node].ports[dst_port].iface
ips[src_node, dst_node].append((iface_src, ip_src))
ips[dst_node, src_node].append((iface_dst, ip_dst))
assigned.add((src_node, src_port))
assigned.add((dst_node, dst_port))
ip1 += 4
if ip1 > 255:
ip0 += 1
ip1 = 0
if ip0 > 255:
raise ValueError("Ran out of available local IPs")
self.ips = ips
self.hosts = hosts
self.tb_hosts = tb_hosts
def setup(self, verbose=False, auto_setup=False):
netmask = "255.255.255.252"
for i, (h, th) in enumerate(zip(self.hosts, self.tb_hosts)):
command = ""
command += "sudo ifconfig bridge0 down\n"
for j in range(len(self.hosts)):
if i == j or (i, j) not in self.ips:
continue
for (iface, ip), (_, peer) in zip(self.ips[i, j], self.ips[j, i]):
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
command += f"sudo route change {peer} -interface {iface}\n"
if auto_setup:
print(f"Running auto setup for {h.ssh_hostname}")
command = command.strip().replace("\n", " ; ")
command = ["ssh", h.ssh_hostname, command]
log(verbose, shlex.join(command))
run(command)
else:
msg = f"Setup for {h.ssh_hostname}"
print(msg)
print("=" * len(msg))
print(command)
input("Enter to continue")
print()
def parse_hardware_ports(ports_string):
ports = {}
port_name = None
for l in ports_string.decode("utf-8").split("\n"):
if l.startswith("Hardware Port:"):
port_name = l.strip()[15:]
elif l.startswith("Device:"):
ports[port_name] = l.strip()[8:]
port_name = None
return ports
def extract_connectivity(hosts, verbose):
# Extract the current connectivity from the remote hosts
thunderbolt_connections = []
for h in hosts:
log(verbose, "Getting connectivity from", h.ssh_hostname)
thunderbolt_connections.append(
json.loads(
run(
[
"ssh",
h.ssh_hostname,
"system_profiler",
"SPThunderboltDataType",
"-json",
],
capture_output=True,
).stdout
)
)
interface_maps = []
for h in hosts:
log(verbose, "Getting interface names from", h.ssh_hostname)
interface_maps.append(
parse_hardware_ports(
run(
[
"ssh",
h.ssh_hostname,
"networksetup",
"-listallhardwareports",
],
capture_output=True,
).stdout
)
)
# Parse the connectivity into some simple dataclasses
tb_hosts = []
for c, iface_map in zip(thunderbolt_connections, interface_maps):
name = ""
ports = []
for t in c["SPThunderboltDataType"]:
uuid = t.get("domain_uuid_key")
if uuid is None:
continue
name = t["device_name_key"]
tag = t["receptacle_1_tag"]["receptacle_id_key"]
items = t.get("_items", [])
connected_items = [item for item in items if "domain_uuid_key" in item]
connected_to = (
connected_items[0]["domain_uuid_key"] if connected_items else None
)
iface = iface_map[f"Thunderbolt {tag}"]
ports.append(ThunderboltPort(iface, uuid, connected_to))
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
# Create a reverse index to be able to map uuids to (host, port) quickly
uuid_reverse_index = {}
for i, h in enumerate(tb_hosts):
for j, p in enumerate(h.ports):
uuid_reverse_index[p.uuid] = (i, j)
return tb_hosts, uuid_reverse_index
def make_connectivity_matrix(tb_hosts, uuid_reverse_index):
connectivity = []
for i, h in enumerate(tb_hosts):
c = [0] * len(tb_hosts)
for p in h.ports:
if p.connected_to in uuid_reverse_index:
j, _ = uuid_reverse_index[p.connected_to]
c[j] += 1
connectivity.append(c)
return connectivity
def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index):
# Make ids per node
names = []
for i in range(len(tb_hosts)):
n = ""
j = i
while True:
n += chr(97 + j % 26)
j //= 26
if j == 0:
break
names.append(n)
print("graph G {")
print(" node [shape=rectangle];")
for i, h in enumerate(hosts):
print(f' {names[i]} [label="{h.ssh_hostname}"];')
for i, h in enumerate(tb_hosts):
for p in h.ports:
if not p.connected_to:
continue
dst = uuid_reverse_index[p.connected_to]
if dst[0] < i:
continue
print(f" {names[i]} -- {names[dst[0]]}", end="")
print(f' [label="{p.iface}/{tb_hosts[dst[0]].ports[dst[1]].iface}"]')
print("}")
def extract_rings(connectivity):
rings = []
existing_rings = set()
num_nodes = len(connectivity)
def dfs(start_node, node, path, visited):
path.append(node)
visited.add(node)
for j in range(num_nodes):
if connectivity[node][j] <= 0:
continue
if j == start_node:
yield path[:]
if j not in visited:
yield from dfs(start_node, j, path, visited)
path.pop()
visited.remove(node)
for start in range(num_nodes):
for r in dfs(start, start, [], set()):
cnt = min(connectivity[r[i]][r[(i + 1) % len(r)]] for i in range(len(r)))
rkey = tuple(sorted(r))
if rkey not in existing_rings:
rings.append((r, cnt))
existing_rings.add(rkey)
return sorted(rings, key=lambda x: -len(x[0]))
def check_valid_mesh(hosts, connectivity, strict=True):
num_nodes = len(connectivity)
for i in range(num_nodes):
for j in range(num_nodes):
if i == j:
continue
if connectivity[i][j] <= 0:
if strict:
log_error(
f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
)
log_error()
log_error("Try passing --dot to visualize the connectivity")
sys.exit(1)
else:
return False
return True
def check_ssh_connections(hosts):
results = [None] * len(hosts)
def _check(hostname, i):
info = SSHInfo(False, False)
results[i] = info
# Check for ssh
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"echo",
"success",
],
stdout=DEVNULL,
stderr=DEVNULL,
)
info.can_ssh = result.returncode == 0
if not info.can_ssh:
return
# Check for sudo
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"sudo",
"ls",
],
stdout=DEVNULL,
stderr=DEVNULL,
)
info.has_sudo = result.returncode == 0
threads = [
threading.Thread(target=_check, args=(h.ssh_hostname, i))
for i, h in enumerate(hosts)
]
for t in threads:
t.start()
for t in threads:
t.join()
if not all(results):
log_error("Could not ssh to the following hosts:")
for i, h in enumerate(hosts):
if not results[i]:
log_error(" - ", h.ssh_hostname)
log_error()
log_error("Maybe they are not set-up for password-less ssh?")
sys.exit(1)
return results
def prepare_ethernet_hostfile(args, hosts):
log(args.verbose, f"Preparing an ethernet hostfile")
add_ethernet_ips(hosts, args.verbose)
hostfile = []
for h in hosts:
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def configure_ring(args, hosts, ips, ring, sshinfo):
log(args.verbose, "Prepare a ring hostfile")
ring, count = ring
hostfile = []
for i, node in enumerate(ring):
h = hosts[node]
peer = ring[i - 1]
hostfile.append(
{
"ssh": h.ssh_hostname,
"ips": [ips.ips[node, peer][c][1] for c in range(count)],
"rdma": [],
}
)
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def configure_jaccl(args, hosts, ips, sshinfo):
log(args.verbose, "Prepare a jaccl hostfile")
check_rdma(hosts, args.verbose)
add_ethernet_ips(hosts, args.verbose)
hostfile = []
for i, h in enumerate(hosts):
rdma = []
for j in range(len(hosts)):
if i == j:
rdma.append(None)
else:
rdma.append(f"rdma_{ips.ips[i, j][0][0]}")
hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def prepare_tb_hostfile(args, hosts, sshinfo):
log(args.verbose, f"Preparing for communication over thunderbolt")
tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)
if args.dot:
tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index)
return
ips = IPConfigurator(hosts, tb_hosts, uuid_reverse_index)
connectivity = make_connectivity_matrix(tb_hosts, uuid_reverse_index)
if args.backend is None:
rings = extract_rings(connectivity)
has_mesh = check_valid_mesh(hosts, connectivity, False)
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
if not has_ring and not has_mesh:
log_error("Neither thunderbolt mesh nor ring found.")
log_error("Perhaps run with --dot to generate a plot of the connectivity.")
sys.exit(1)
elif has_ring:
configure_ring(args, hosts, ips, rings[0], sshinfo)
else:
configure_jaccl(args, hosts, ips, sshinfo)
elif args.backend == "ring":
rings = extract_rings(connectivity)
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
if not has_ring:
log_error("Could not find a full ring.")
log_error()
log_error("Try passing --dot to visualize the connectivity")
if len(rings) > 0:
log_error("Rings found:")
for r in rings:
log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}")
sys.exit(1)
configure_ring(args, hosts, ips, rings[0], sshinfo)
elif args.backend == "jaccl":
check_valid_mesh(hosts, connectivity)
configure_jaccl(args, hosts, ips, sshinfo)
def main():
parser = argparse.ArgumentParser(
description="Configure remote machines for use with MLX distributed"
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--over",
choices=["thunderbolt", "ethernet"],
default="thunderbolt",
help="What type of connectivity to configure",
required=True,
)
parser.add_argument(
"--output-hostfile", help="If provided, save the hostfile to this path"
)
parser.add_argument(
"--auto-setup",
"--no-auto-setup",
action=OptionalBoolAction,
nargs=0,
dest="auto_setup",
default=None,
)
parser.add_argument(
"--dot", action="store_true", help="Output the topology in DOT format and exit"
)
parser.add_argument(
"--backend",
choices=["ring", "jaccl"],
default=None,
help="Which distributed backend to configure",
)
args = parser.parse_args()
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, 1)
# Check that we can ssh
log(
args.verbose,
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
)
sshinfo = check_ssh_connections(hosts)
# Prepare a hostfile for communication over ethernet using the ips of the
# provided hostnames
if args.over == "ethernet":
prepare_ethernet_hostfile(args, hosts)
# Configure the macs for communication over thunderbolt, both via RDMA and IP
else:
prepare_tb_hostfile(args, hosts, sshinfo)

View File

@@ -1,911 +0,0 @@
# Copyright © 2025 Apple Inc.
import argparse
import base64
import ipaddress
import json
import os
import platform
import shlex
import shutil
import sys
import tempfile
import threading
import time
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
from typing import Optional
import mlx.core as mx
@dataclass
class Host:
rank: int
ssh_hostname: str
ips: list[str]
@dataclass
class ThunderboltPort:
iface: str
uuid: str
connected_to: Optional[str]
@dataclass
class ThunderboltHost:
name: str
ports: list[ThunderboltPort]
def parse_hardware_ports(ports_string):
ports = {}
port_name = None
for l in ports_string.decode("utf-8").split("\n"):
if l.startswith("Hardware Port:"):
port_name = l.strip()[15:]
elif l.startswith("Device:"):
ports[port_name] = l.strip()[8:]
port_name = None
return ports
def get_num_nvidia_gpus():
result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
return len(result.stdout.strip().split("\n"))
def extract_rings(hosts, index):
def usable_port(i, j, used_ports):
return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
def dfs(start_node, node, path, visited, used_ports):
path.append(node)
visited.add(node)
for j, p in enumerate(hosts[node].ports):
if not usable_port(node, j, used_ports):
continue
next_node, _ = index[p.connected_to]
if next_node == start_node:
yield path[:]
if next_node not in visited:
yield from dfs(start_node, next_node, path, visited, used_ports)
path.pop()
visited.remove(node)
# Concretize maps the found cycle to real thunderbolt ports. It also adds
# those ports to the used set so next cycles can't use them again.
def concretize(cycle, used_ports):
concrete_path = []
for n1, n2 in zip(cycle, cycle[1:] + cycle[:1]):
for j, p in enumerate(hosts[n1].ports):
if not usable_port(n1, j, used_ports):
continue
n2_hat, nj = index[p.connected_to]
if n2 == n2_hat:
concrete_path.append(((n1, j), (n2, nj)))
used_ports.add((n1, j))
used_ports.add((n2, nj))
break
if concrete_path[-1][0][0] != n1:
raise RuntimeError("Couldn't concretize the cycle")
return concrete_path
# Normalize tries to ensure that the cycles have the same direction so we can
# use them together. We achieve this by selecting the direction such that
# the smallest rank hosts connect to larger rank hosts.
def normalize(path):
small_to_large = sum(1 for p in path if p[0][0] < p[1][0])
if small_to_large > len(path) - small_to_large:
return path
else:
return [(p[1], p[0]) for p in path]
rings = []
used_ports = set()
for start_node in range(len(hosts)):
while True:
ring = []
for r in dfs(start_node, start_node, [], set(), used_ports):
if len(r) > len(ring):
ring = r
# Break early since we won't find a bigger ring no matter what
if len(ring) == len(hosts):
break
if not ring:
break
try:
rings.append(normalize(concretize(ring, used_ports)))
except RuntimeError:
if len(rings) > 0:
return rings
raise
return rings
def positive_number(x):
x = int(x)
if x <= 0:
raise ValueError("Number should be positive")
return x
def log(verbose, *args, **kwargs):
if not verbose:
return
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
def log_warning(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
def log_error(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
def parse_hostfile(parser, hostfile):
"""Parse the json hostfile that contains both the hostnames to ssh into and
the ips to communicate over when using the ring backend.
Example:
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
...
{"ssh": "hostnameN", "ips": ["123.123.123.N"]},
]
Args:
hostfile (str): The path to the json file containing the host
information
"""
hostfile = Path(hostfile)
if not hostfile.exists():
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
try:
hosts = []
with open(hostfile) as f:
for i, h in enumerate(json.load(f)):
hosts.append(Host(i, h["ssh"], h.get("ips", [])))
return hosts
except Exception as e:
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
def parse_hostlist(parser, hostlist, repeats):
hosts = []
for i, h in enumerate(hostlist.split(",")):
if h == "":
raise ValueError("Hostname cannot be empty")
try:
ipaddress.ip_address(h)
ips = [h]
except ValueError:
ips = []
for i in range(repeats):
hosts.append(Host(i, h, ips))
return hosts
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
# Imports that are used throughout
script = ""
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Write the PID to a file so we can kill the process if needed
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
if cwd is not None:
script += "else:\n"
script += (
f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
)
script += f" sys.exit(1)\n"
# Add the environment variables that were given to us
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
if not all(c.isalnum() or c == "_" for c in key):
log_warning(f"'{e}' is an invalid environment variable so it is ignored")
continue
script += f"env[{repr(key)}] = {repr(value)}\n"
# Add the environment variables to enable the ring distributed backend
if hostfile != "":
script += "_, hostfile = tempfile.mkstemp()\n"
script += "with open(hostfile, 'w') as f:\n"
script += f" f.write({repr(hostfile)})\n"
if verbose:
script += "env['MLX_RING_VERBOSE'] = '1'\n"
script += "env['MLX_HOSTFILE'] = hostfile\n"
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
# Replace the process with the script
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
return script
def launch_ring(parser, hosts, args, command):
stop = False
exit_codes = [None] * len(hosts)
def node_thread(rank, host, hostfile, input_queue):
is_local = host == "127.0.0.1"
script = make_monitor_script(
rank, hostfile, args.cwd, args.env, command, args.verbose
)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
if not is_local:
cmd = f"ssh {host} '{cmd}'"
p = Popen(
cmd,
shell=True,
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
os.set_blocking(p.stdin.fileno(), False)
# Repeat the stdout and stderr to the local machine
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()]
pidfile = ""
stdin_buffer = b""
stdout_buffer = b""
stderr_buffer = b""
while p.poll() is None:
try:
stdin_buffer += input_queue.get_nowait()
except QueueEmpty:
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
msg = os.read(fd, 8192).decode(errors="ignore")
# Fetch the PID file first if we haven't already
if pidfile == "":
pidfile, *msg = msg.split("\n", maxsplit=1)
msg = msg[0] if msg else ""
is_stdout = fd == p.stdout.fileno()
if is_stdout:
stdout_buffer += msg.encode()
else:
stderr_buffer += msg.encode()
for fd in wlist:
if fd == p.stdin.fileno() and len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0:
n = os.write(fd, stdout_buffer)
stdout_buffer = stdout_buffer[n:]
elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0:
n = os.write(fd, stderr_buffer)
stderr_buffer = stderr_buffer[n:]
if stop:
p.terminate()
break
p.wait()
exit_codes[rank] = p.returncode
# Kill the remote program if possible
cmd = ""
cmd += f"pid=$(cat {pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {pidfile}"
if not is_local:
cmd = f"ssh {host} '{cmd}'"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
if c.stdout.strip() == "1":
log_warning(f"Node with rank {rank} was killed")
elif p.returncode != 0:
log_warning(f"Node with rank {rank} exited with code {p.returncode}")
else:
log(args.verbose, f"Node with rank {rank} completed")
if all(len(h.ips) == 0 for h in hosts):
parser.error(
"The ring backend requires IPs to be provided instead of hostnames"
)
port = args.starting_port
ring_hosts = []
for h in hosts:
node = []
for ip in h.ips:
for i in range(args.connections_per_ip):
node.append(f"{ip}:{port}")
port += 1
ring_hosts.append(node)
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
log(args.verbose, "Running", shlex.join(command))
input_queues = []
threads = []
for i, h in enumerate(hosts):
if i + 1 == len(hosts):
time.sleep(1.0)
input_queues.append(Queue())
t = threading.Thread(
target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
)
t.start()
threads.append(t)
os.set_blocking(sys.stdin.fileno(), False)
while not stop:
rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
for fd in rlist:
stdin_buffer = os.read(fd, 8192)
for q in input_queues:
q.put(stdin_buffer)
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
if exit_codes[i] != 0:
stop = True
break
else:
break
for t in threads:
t.join()
def get_mpi_libname():
try:
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
ompi_info = ompi_info.stdout.strip().decode()
if platform.system() == "Darwin":
otool_output = run(
["otool", "-L", ompi_info], check=True, capture_output=True
)
else:
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
otool_output = otool_output.stdout.decode()
# StopIteration if not found
libmpi_line = next(
filter(lambda line: "libmpi" in line, otool_output.splitlines())
)
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
except:
return None
def launch_mpi(parser, hosts, args, command):
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
mpirun = mpirun.stdout.strip().decode()
# Compatibility with homebrew and pip installs
mpi_libname = get_mpi_libname()
if mpi_libname is not None:
dyld = Path(mpirun).parent.parent / "lib"
args.env = [
f"DYLD_LIBRARY_PATH={str(dyld)}",
f"MLX_MPI_LIBNAME={mpi_libname}",
] + args.env
log(args.verbose, f"Using '{mpirun}'")
with tempfile.NamedTemporaryFile(mode="w") as f:
hosts = Counter((h.ssh_hostname for h in hosts))
for h, n in hosts.items():
print(f"{h} slots={n}", file=f)
f.flush()
cmd = [
mpirun,
"--output",
":raw", # do not line buffer output
"--hostfile",
f.name,
*(["-cwd", args.cwd] if args.cwd else []),
*sum((["-x", e] for e in args.env), []),
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
"--",
*command,
]
log(args.verbose, "Running", " ".join(cmd))
try:
run(cmd)
except KeyboardInterrupt:
pass
def launch_nccl(parser, hosts, args, command):
master_host = hosts[0].ips[0]
if master_host != "127.0.0.1":
raise ValueError("The NCCL backend only supports localhost for now.")
master_port = args.nccl_port
world_size = len(hosts)
base_env = os.environ.copy()
base_env.update(
{
"NCCL_DEBUG": base_env.get(
"NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
),
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
"NCCL_HOST_IP": master_host,
"NCCL_PORT": str(master_port),
"MLX_WORLD_SIZE": str(world_size),
}
)
procs = []
num_gpus = get_num_nvidia_gpus()
if num_gpus == 0:
raise RuntimeError("Cannot run NCCL backend with no GPUs.")
if args.repeat_hosts > num_gpus:
raise RuntimeError("NCCL requires a separate GPU per process.")
try:
for rank in range(world_size):
env = base_env.copy()
mlx_rank = str(rank % args.repeat_hosts)
env["MLX_RANK"] = mlx_rank
env["CUDA_VISIBLE_DEVICES"] = mlx_rank
p = Popen(command, env=env)
procs.append(p)
for p in procs:
ret = p.wait()
if ret != 0:
raise RuntimeError(f"Rank process exited with {ret}")
except (RuntimeError, KeyboardInterrupt) as err:
for p in procs:
if p.poll() is None:
try:
p.kill()
except Exception:
pass
raise
def check_ssh_connections(hosts):
results = [False] * len(hosts)
def _check(hostname, i):
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"echo",
"success",
],
stdout=PIPE,
stderr=PIPE,
)
results[i] = result.returncode == 0
threads = [
threading.Thread(target=_check, args=(h.ssh_hostname, i))
for i, h in enumerate(hosts)
]
for t in threads:
t.start()
for t in threads:
t.join()
if not all(results):
log_error("Could not ssh to the following hosts:")
for i, h in enumerate(hosts):
if not results[i]:
log_error(" - ", h.ssh_hostname)
log_error()
log_error("Maybe they are not set-up for password-less ssh?")
sys.exit(1)
def prepare_tb_ring(args, hosts):
log(
args.verbose,
f"Preparing a thunderbolt ring for {', '.join(h.ssh_hostname for h in hosts)}",
)
# Check that we can ssh
check_ssh_connections(hosts)
if args.auto_setup and args.verbose:
log_warning(
"--auto-setup is requested which requires password-less sudo",
"on the remote hosts",
)
# Extract the current connectivity from the remote hosts
thunderbolt_connections = []
for h in hosts:
log(args.verbose, "Getting connectivity from", h.ssh_hostname)
thunderbolt_connections.append(
json.loads(
run(
[
"ssh",
h.ssh_hostname,
"system_profiler",
"SPThunderboltDataType",
"-json",
],
capture_output=True,
).stdout
)
)
interface_maps = []
for h in hosts:
log(args.verbose, "Getting interface names from", h.ssh_hostname)
interface_maps.append(
parse_hardware_ports(
run(
[
"ssh",
h.ssh_hostname,
"networksetup",
"-listallhardwareports",
],
capture_output=True,
).stdout
)
)
# Parse the connectivity into some simple dataclasses
tb_hosts = []
for c, iface_map in zip(thunderbolt_connections, interface_maps):
name = ""
ports = []
for t in c["SPThunderboltDataType"]:
uuid = t.get("domain_uuid_key")
if uuid is None:
continue
name = t["device_name_key"]
tag = t["receptacle_1_tag"]["receptacle_id_key"]
items = t.get("_items", [])
connected_items = [item for item in items if "domain_uuid_key" in item]
connected_to = (
connected_items[0]["domain_uuid_key"] if connected_items else None
)
iface = iface_map[f"Thunderbolt {tag}"]
ports.append(ThunderboltPort(iface, uuid, connected_to))
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
# Create a reverse index to be able to map uuids to (host, port) quickly
uuid_reverse_index = {}
for i, h in enumerate(tb_hosts):
for j, p in enumerate(h.ports):
uuid_reverse_index[p.uuid] = (i, j)
# Find the rings by simply walking and marking visited (host, port) tuples
# and keeping the largest rings greedily.
log(args.verbose, "Extracting rings from the parsed connectivity")
rings = extract_rings(tb_hosts, uuid_reverse_index)
# Just output a DOT graphical representation of the found rings
if args.dot:
names = []
for i in range(len(tb_hosts)):
n = ""
j = i
while True:
n += chr(97 + j % 26)
j //= 26
if j == 0:
break
names.append(n)
print("graph G {")
print(" node [shape=rectangle];")
for i, h in enumerate(hosts):
print(f' {names[i]} [label="{h.ssh_hostname}"];')
for r in rings:
for (i, _), (j, _) in r:
print(f" {names[i]} -- {names[j]};")
print("}")
return
# Assign IPs to each interface such that the interfaces can communicate
ips = {}
pairs = {}
expecting = set()
ip0 = 0
ip1 = 0
netmask = "255.255.255.252"
for r in rings:
for a, b in r:
ips[a] = f"192.168.{ip0}.{ip1 + 1}"
ips[b] = f"192.168.{ip0}.{ip1 + 2}"
pairs[a] = b
pairs[b] = a
expecting.add(b)
ip1 += 4
if ip1 > 255:
ip0 += 1
ip1 = 0
if ip0 > 255:
raise ValueError("Ran out of available local IPs for the ring")
# Extract the host order from the first ring
hostmap = dict((r[0][0], r[1][0]) for r in rings[0])
first_host = min(hostmap.keys())
order = [first_host]
while hostmap[order[-1]] != first_host:
order.append(hostmap[order[-1]])
# Create the hostfile
hostfile = []
for i in order:
h = hosts[i]
host = {
"ssh": h.ssh_hostname,
"ips": [
ips[i, j]
for j, p in enumerate(tb_hosts[i].ports)
if (i, j) in expecting
],
}
hostfile.append(host)
if not args.hostfile_only:
for i, h in enumerate(hosts):
command = ""
command += "sudo ifconfig bridge0 down\n"
for j, p in enumerate(tb_hosts[i].ports):
if (i, j) not in ips:
continue
iface = p.iface
ip = ips[i, j]
peer = ips[pairs[i, j]]
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
command += f"sudo route change {peer} -interface {iface}\n"
if args.auto_setup:
print(f"Running auto setup for {h.ssh_hostname}")
command = command.strip().replace("\n", " && ")
command = ["ssh", h.ssh_hostname, command]
log(args.verbose, shlex.join(command))
run(command)
else:
msg = f"Setup for {h.ssh_hostname}"
print(msg)
print("=" * len(msg))
print(command)
input("Enter to continue")
print()
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def prepare_hostfile(args, hosts):
log(
args.verbose,
f"Preparing an ethernet hostfile for {', '.join(h.ssh_hostname for h in hosts)}",
)
# Check that we can ssh
check_ssh_connections(hosts)
# Get the ips for each host
for h in hosts:
log(args.verbose, "Getting the ip from", h.ssh_hostname)
h.ips.append(
run(
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
capture_output=True,
text=True,
).stdout.strip()
)
hostfile = []
for h in hosts:
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def distributed_config():
parser = argparse.ArgumentParser(
description="Configure remote machines for use with MLX distributed"
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to configure",
)
parser.add_argument(
"--over",
choices=["thunderbolt", "ethernet"],
default="thunderbolt",
help="What type of connectivity to configure",
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--dot", action="store_true", help="Output the topology in DOT format and exit"
)
parser.add_argument(
"--hostfile-only", action="store_true", help="If set only compute the hostfile"
)
parser.add_argument(
"--output-hostfile", help="If provided, save the hostfile to this path"
)
parser.add_argument(
"--auto-setup",
action="store_true",
help="If set we will attempt to automatically configure the machines via ssh",
)
args = parser.parse_args()
if args.backend == "mpi" and args.over == "thunderbolt":
raise ValueError(
(
"The configuration of MPI over thunderbolt is "
"not supported yet by mlx.distributed_config"
)
)
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, 1)
if args.over == "thunderbolt":
prepare_tb_ring(args, hosts)
else:
prepare_hostfile(args, hosts)
def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument(
"--print-python",
action="store_true",
help="Print the path to the current python executable and exit",
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument(
"--repeat-hosts",
"-n",
type=positive_number,
default=1,
help="Repeat each host a given number of times",
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl", "jaccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
parser.add_argument(
"--env",
action="append",
default=[],
help="Set environment variables for the jobs",
)
parser.add_argument(
"--mpi-arg",
action="append",
default=[],
help="Arguments to pass directly to mpirun",
)
parser.add_argument(
"--connections-per-ip",
default=1,
type=int,
help="How many connections per ip to use for the ring backend",
)
parser.add_argument(
"--starting-port",
"-p",
type=int,
default=5000,
help="For the ring backend listen on this port increasing by 1 per rank and IP",
)
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
args, rest = parser.parse_known_args()
if args.print_python:
print(sys.executable)
return
if len(rest) == 0:
parser.error("No script is provided")
if rest[0] == "--":
rest.pop(0)
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
# Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists():
rest[0:1] = [sys.executable, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
else:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
if args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if args.backend == "jaccl":
launch_jaccl(parser, hosts, args, rest)
if __name__ == "__main__":
main()

View File

@@ -45,13 +45,11 @@ class CommandProcess:
class RemoteProcess(CommandProcess):
def __init__(self, rank, host, cwd, files, env, command):
def __init__(self, rank, host, python, cwd, files, env, command):
is_local = host == "127.0.0.1"
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command)
if not is_local:
cmd = f"ssh {host} '{cmd}'"
cmd = f"ssh -tt -o LogLevel=QUIET {host} {shlex.quote(cmd)}"
self._host = host
self._pidfile = None
@@ -59,6 +57,7 @@ class RemoteProcess(CommandProcess):
self._process = Popen(
cmd,
shell=True,
executable="/bin/bash",
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
@@ -90,47 +89,43 @@ class RemoteProcess(CommandProcess):
self._process.wait()
# Kill the remote program if possible
cmd = ""
cmd += f"pid=$(cat {self._pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {self._pidfile}"
cmd = RemoteProcess.make_kill_script(self._pidfile)
if not self._is_local:
cmd = f"ssh {self._host} '{cmd}'"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
cmd = f"ssh {self._host} {shlex.quote(cmd)}"
c = run(
cmd,
check=True,
shell=True,
executable="/bin/bash",
capture_output=True,
text=True,
)
self._killed = c.stdout.strip() == "1"
@staticmethod
def make_monitor_script(rank, cwd, files, env, command):
# Imports that are used throughout
def make_launch_script(rank, cwd, files, env, command):
script = ""
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Disable echo
script = "stty -echo; "
# Write the PID to a file so we can kill the process if needed
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
script += "pidfile=$(mktemp); "
script += "echo $$ > $pidfile; "
script += 'printf "%s\\n" $pidfile; '
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
script += f"if [[ -d {repr(d)} ]]; then "
script += f" cd {repr(d)}; "
if cwd is not None:
script += "else:\n"
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
script += f" sys.exit(1)\n"
script += "else "
script += f" echo 'Failed to change directory to' {repr(d)} >2; "
script += "fi; "
# Add the environment variables that were requested
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
@@ -139,22 +134,34 @@ class RemoteProcess(CommandProcess):
f"'{e}' is an invalid environment variable so it is ignored"
)
continue
script += f"env[{repr(key)}] = {repr(value)}\n"
script += f"export {key}={value}; "
# Make the temporary files
for env_name, content in files.items():
script += "_, fname = tempfile.mkstemp()\n"
script += "with open(fname, 'w') as f:\n"
script += f" f.write({repr(content)})\n"
script += f"env[{repr(env_name)}] = fname\n"
script += "fname=$(mktemp); "
script += f"echo {shlex.quote(content)} >$fname; "
script += f"export {env_name}=$fname; "
# Finally add the rank
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
script += f"export MLX_RANK={rank}; "
# Replace the process with the script
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
script += f"cmd=({' '.join(map(shlex.quote, command))}); "
script += 'exec "${cmd[@]}"'
return script
@staticmethod
def make_kill_script(pidfile):
script = ""
script += f"pid=$(cat {pidfile}); "
script += "if ps -p $pid >/dev/null; then "
script += " kill $pid; "
script += " echo 1; "
script += "else "
script += " echo 0; "
script += "fi; "
script += f"rm {pidfile}"
return script
@@ -309,7 +316,7 @@ def launch_ring(parser, hosts, args, command):
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, cwd, files, env, command), {})
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
@@ -341,6 +348,7 @@ def launch_nccl(parser, hosts, args, command):
(
rank,
h.ssh_hostname,
args.python,
cwd,
{},
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
@@ -374,7 +382,7 @@ def launch_jaccl(parser, hosts, args, command):
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, cwd, files, env, command), {})
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
@@ -503,11 +511,20 @@ def main():
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
parser.add_argument(
"--no-verify-script",
action="store_false",
dest="verify_script",
help="Do not verify that the script exists",
)
parser.add_argument(
"--python", default=sys.executable, help="Use this python on the remote hosts"
)
args, rest = parser.parse_known_args()
if args.print_python:
print(sys.executable)
print(args.python)
return
if len(rest) == 0:
@@ -523,10 +540,10 @@ def main():
# Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists() and script.is_file():
rest[0:1] = [sys.executable, str(script.resolve())]
rest[0:1] = [args.python, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
else:
elif args.verify_script:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch

View File

@@ -1,7 +1,7 @@
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \
--plat manylinux_2_35_${1} \
--exclude libcublas* \
--exclude libnvrtc* \
--exclude libcuda* \

View File

@@ -67,7 +67,7 @@ void init_distributed(nb::module_& parent_module) {
Args:
backend (str, optional): The name of the backend to check for availability.
It takes the same values as ``init()``. Default: ``any``.
It takes the same values as :func:`init()`. Default: ``"any"``.
Returns:
bool: Whether the distributed backend is available.

View File

@@ -210,6 +210,14 @@ class TestReduce(mlx_tests.MLXTestCase):
ref = getattr(np, op)(np_arr, axis=axis)
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
def test_long_column(self):
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
b = mx.array(a)
c1 = a.sum(0)
c2 = b.sum(0)
self.assertTrue(np.all(c1 == c2))
if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -7,13 +7,21 @@ import re
import subprocess
from functools import partial
from pathlib import Path
from subprocess import run
from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext
def cuda_toolkit_major_version():
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
text = out.decode()
m = re.search(r"release (\d+)", text)
if m:
return int(m.group(1))
return None
def get_version():
with open("mlx/version.h", "r") as fid:
for l in fid:
@@ -31,7 +39,7 @@ def get_version():
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
if not pypi_release and not dev_release:
git_hash = (
run(
subprocess.run(
"git rev-parse --short HEAD".split(),
capture_output=True,
check=True,
@@ -258,7 +266,7 @@ if __name__ == "__main__":
entry_points = {
"console_scripts": [
"mlx.launch = mlx._distributed_utils.launch:main",
# "mlx.distributed_config = mlx.distributed_run:distributed_config",
"mlx.distributed_config = mlx._distributed_utils.config:main",
]
}
install_requires = []
@@ -284,7 +292,11 @@ if __name__ == "__main__":
install_requires.append(
f'mlx-metal=={version}; platform_system == "Darwin"'
)
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
for toolkit in [12, 13]:
extras[f"cuda{toolkit}"] = [
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
]
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
_setup(
@@ -299,13 +311,25 @@ if __name__ == "__main__":
if build_macos:
name = "mlx-metal"
elif build_cuda:
name = "mlx-cuda"
toolkit = cuda_toolkit_major_version()
name = f"mlx-cuda-{toolkit}"
if toolkit == 12:
install_requires += [
"nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*",
]
elif toolkit == 13:
install_requires += [
"nvidia-cublas-cu13",
"nvidia-cuda-nvrtc-cu13",
]
else:
raise ValueError(f"Unknown toolkit {toolkit}")
install_requires += [
"nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*",
"nvidia-cudnn-cu12==9.*",
"nvidia-nccl-cu12",
f"nvidia-cudnn-cu{toolkit}==9.*",
f"nvidia-nccl-cu{toolkit}",
]
else:
name = "mlx-cpu"
_setup(

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#include <climits>
#include "doctest/doctest.h"
@@ -608,3 +607,24 @@ TEST_CASE("test make empty array") {
CHECK_EQ(a.size(), 0);
CHECK_EQ(a.dtype(), bool_);
}
TEST_CASE("test make array from user buffer") {
int size = 4096;
std::vector<int> buffer(size, 0);
int count = 0;
auto deleter = [&count](void*) { count++; };
{
auto a = array(buffer.data(), Shape{size}, int32, deleter);
if (metal::is_available()) {
CHECK_EQ(buffer.data(), a.data<int>());
}
auto b = a + array(1);
eval(b);
auto expected = ones({4096});
CHECK(array_equal(b, expected).item<bool>());
}
// deleter should always get called
CHECK_EQ(count, 1);
}