Compare commits


15 Commits

Author SHA1 Message Date
dependabot[bot]
c2764d1073 Bump actions/download-artifact from 6 to 7 (#2912)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-15 06:10:16 -08:00
dependabot[bot]
093a62d2ed Bump actions/upload-artifact from 5 to 6 (#2911)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-15 06:09:55 -08:00
Awni Hannun
1b591ec736 No VJP for mask or sinks in attention (#2909)
2025-12-13 19:48:39 -08:00
Awni Hannun
47d2505ea9 Fix attention for large sizes (#2903)
2025-12-13 06:54:30 -08:00
Cheng
bedefed784 Fix ccache getting disabled (#2905)
2025-12-13 13:00:51 +09:00
Melissa Kilby
ccaaa7d6df fix: possible heap-buffer-overflow in RandomBits::eval_cpu (#2877)
2025-12-12 02:11:18 -08:00
Awni Hannun
f3e5ca5414 [CUDA] Add host nodes to subgraph types for graph update (#2901)
2025-12-11 19:13:44 -08:00
Awni Hannun
81dfe5f137 Fix grad in place updates (#2899)
2025-12-11 14:44:58 -08:00
Anastasiia Filippova
012fb220a1 fp quantize (#2892)
2025-12-11 06:11:25 -08:00
Nathan Goldbaum
e1fee0074b Update nanobind pin to most recent version (#2896) 2025-12-11 06:07:36 -08:00
CCYeh
3c8ce9b00e Fix input buffer donation in compile (#2897) 2025-12-11 06:07:03 -08:00
David Koski
937ce79660 do not use simd neon intrinsics on x86 (#2893)
2025-12-10 12:23:28 -08:00
Nathan Goldbaum
208f5441a7 bump minimum required Python version (#2891)
2025-12-09 16:54:38 -08:00
Awni Hannun
b862d842e1 Allow events in sub graph to be updatable (#2886)
2025-12-09 12:34:37 -08:00
Satyam singh
f7a400951a Fix docs: replace mx.random.randn with mx.random.normal (#2890) 2025-12-09 11:46:30 -08:00
49 changed files with 1840 additions and 3315 deletions


@@ -11,7 +11,7 @@ runs:
     shell: bash -l {0}
     run: |
       pip install --upgrade pip
-      pip install cmake setuptools nanobind==2.4.0
+      pip install cmake setuptools nanobind==2.10.2
       pip install -e . -v
   - name: Generate package stubs


@@ -10,23 +10,29 @@ inputs:
     description: 'Version of python to set up'
     required: false
     default: '3.10'
+  use-ccache:
+    description: 'Whether to enable ccache'
+    required: false
+    default: 'true'
 runs:
   using: "composite"
   steps:
-    - name: Use ccache
-      if: ${{ runner.arch == 'x86_64' }}
-      uses: hendrikmuhs/ccache-action@v1.2
-      with:
-        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
-        max-size: 1GB
     - name: Install common dependencies
       shell: bash
       run: |
        sudo apt-get update
        sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
+    - name: Use ccache
+      if: ${{ inputs.use-ccache == 'true' }}
+      uses: hendrikmuhs/ccache-action@v1.2
+      with:
+        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}
+        max-size: 1GB
+        # ccache-action bug: running "apt-get update" fails on large arm runner.
+        update-package-index: false
     - uses: actions/setup-python@v6
       with:
        python-version: ${{ inputs.python-version }}
@@ -36,7 +42,7 @@ runs:
      run: |
        python -m venv .venv
        source .venv/bin/activate
-        pip install setuptools cmake nanobind==2.4.0
+        pip install setuptools cmake nanobind==2.10.2
        echo PATH=$PATH >> $GITHUB_ENV
        # Make cmake search .venv for nanobind
        echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV


@@ -23,14 +23,14 @@ jobs:
           build-backend: ${{ matrix.python-version == '3.10' }}
           arch: "x86_64"
       - name: Upload mlx artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           name: linux-wheels-${{ matrix.python_version }}
           path: wheelhouse/mlx-*.whl
           retention-days: 7
       - name: Upload mlx-cpu artifacts
         if: matrix.python_version == '3.10'
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           name: mlx-cpu
           path: wheelhouse/mlx_cpu-*.whl
@@ -89,7 +89,7 @@ jobs:
         with:
           toolkit: 'cuda-12.9'
       - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           name: mlx-cuda
           path: wheelhouse/mlx_cuda-*.whl


@@ -57,19 +57,20 @@ jobs:
       - uses: ./.github/actions/setup-linux
         with:
           python-version: ${{ matrix.python_version }}
+          use-ccache: false
       - uses: ./.github/actions/build-linux-release
         with:
           build-backend: ${{ matrix.python-version == '3.10' }}
           arch: ${{ matrix.arch }}
       - name: Upload MLX artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
           path: wheelhouse/mlx-*.whl
       - name: Upload CPU artifacts
         if: matrix.python_version == '3.10'
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mlx-cpu-${{ matrix.arch }}
@@ -95,7 +96,7 @@ jobs:
         shell: bash -l {0}
         run: |
           pip install --upgrade pip
-          pip install cmake setuptools nanobind==2.4.0
+          pip install cmake setuptools nanobind==2.10.2
           pip install -e . -v
       - name: Generate package stubs
         shell: bash -l {0}
@@ -113,14 +114,14 @@ jobs:
           macos-target: 15.0
           build-backend: ${{ matrix.python-version == '3.10' }}
       - name: Upload MLX artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mac-wheels-${{ matrix.python-version }}
           path: dist/mlx-*.whl
       - name: Upload Metal artifacts
         if: matrix.python-version == '3.10'
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mlx-metal
@@ -141,12 +142,13 @@ jobs:
       - uses: ./.github/actions/setup-linux
         with:
           toolkit: ${{ matrix.toolkit }}
+          use-ccache: false
       - name: Build Python package
         uses: ./.github/actions/build-cuda-release
         with:
           arch: ${{ matrix.arch }}
       - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mlx-cuda
@@ -162,12 +164,12 @@ jobs:
       name: pypi
       url: https://pypi.org/p/mlx
     steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
        with:
          pattern: linux-wheels-*
          merge-multiple: true
          path: dist
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
        with:
          pattern: mac-wheels-*
          merge-multiple: true
@@ -189,7 +191,7 @@ jobs:
      name: pypi
      url: https://pypi.org/p/mlx-cuda
    steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
        with:
          name: mlx-cuda
          path: dist
@@ -210,7 +212,7 @@ jobs:
      name: pypi
      url: https://pypi.org/p/mlx-cpu
    steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
        with:
          pattern: mlx-cpu-*
          merge-multiple: true
@@ -232,7 +234,7 @@ jobs:
      name: pypi
      url: https://pypi.org/p/mlx-metal
    steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
        with:
          name: mlx-metal
          path: dist


@@ -273,7 +273,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
 if(MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
   find_package(
-    Python 3.8
+    Python 3.10
     COMPONENTS Interpreter Development.Module
     REQUIRED)
   execute_process(

Binary file not shown (before: 16 KiB image)

Binary file not shown (before: 22 KiB image)


@@ -7,29 +7,22 @@ Distributed Communication

 MLX supports distributed communication operations that allow the computational cost
 of training or inference to be shared across many physical machines. At the
-moment we support several different communication backends introduced below.
-
-.. list-table::
-   :widths: 20 80
-   :header-rows: 1
-
-   * - Backend
-     - Description
-   * - :ref:`MPI <mpi_section>`
-     - A full featured and mature distributed communications library.
-   * - :ref:`RING <ring_section>`
-     - Ring all reduce and all gather over TCP sockets. Always available and
-       usually faster than MPI.
-   * - :ref:`JACCL <ring_section>`
-     - Low latency communication with RDMA over thunderbolt. Necessary for
-       things like tensor parallelism.
-   * - :ref:`NCCL <nccl_section>`
-     - The backend of choice for CUDA environments.
+moment we support three different communication backends:
+
+* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
+  full-featured and mature distributed communications library
+* A **ring** backend of our own that uses native TCP sockets. It should be
+  faster for thunderbolt connections, but it also works over Ethernet.
+* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.

 The list of all currently supported operations and their documentation can be
 seen in the :ref:`API docs<distributed>`.

+.. note::
+   Some operations may not be supported or not as fast as they should be.
+   We are adding more and tuning the ones we have as we are figuring out the
+   best way to do distributed computing on Macs using MLX.

 Getting Started
 ---------------
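
Both versions of the page build on the same minimal distributed program; a sketch of it, assuming only the documented ``mx.distributed.init`` and ``mx.distributed.all_sum`` APIs:

   import mlx.core as mx

   world = mx.distributed.init()  # picks the first available backend
   x = mx.distributed.all_sum(mx.ones(10))  # element-wise sum across all ranks
   print(world.rank(), x)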
@@ -92,7 +85,7 @@ Selecting Backend
 ^^^^^^^^^^^^^^^^^

 You can select the backend you want to use when calling :func:`init` by passing
-one of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
+one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
 available backends. If they all fail then a singleton group is created.

 .. note::
@@ -117,8 +110,6 @@ The following examples aim to clarify the backend initialization logic in MLX:

     world_ring = mx.distributed.init(backend="ring")
     world_any = mx.distributed.init()  # same as MPI because it was initialized first!

-.. _training_example:
-
 Training Example
 ----------------
@@ -201,273 +192,16 @@ almost identical to the example above:

         loss = step(model, x, y)
         mx.eval(loss, model.parameters())

-.. _ring_section:
-
-Getting Started with Ring
--------------------------
-
-The ring backend does not depend on any third party library so it is always
-available. It uses TCP sockets so the nodes need to be reachable via a network.
-As the name suggests the nodes are connected in a ring which means that rank 1
-can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
-and so on and so forth. As a result :func:`send` and :func:`recv` with
-arbitrary sender and receiver is not supported in the ring backend.
-
-Defining a Ring
-^^^^^^^^^^^^^^^
-
-The easiest way to define and use a ring is via a JSON hostfile and the
-``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
-defines a hostname to ssh into to run commands on this node and one or more IPs
-that this node will listen to for connections.
-
-For example the hostfile below defines a 4 node ring. ``hostname1`` will be
-rank 0, ``hostname2`` rank 1 etc.
-
-.. code:: json
-
-   [
-       {"ssh": "hostname1", "ips": ["123.123.123.1"]},
-       {"ssh": "hostname2", "ips": ["123.123.123.2"]},
-       {"ssh": "hostname3", "ips": ["123.123.123.3"]},
-       {"ssh": "hostname4", "ips": ["123.123.123.4"]}
-   ]
-
-Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
-node, run the script which will listen for connections in each of the provided
-IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
-connection from ``123.123.123.4`` and so on and so forth.
-
-Thunderbolt Ring
-^^^^^^^^^^^^^^^^
-
-Although the ring backend can have benefits over MPI even for Ethernet, its
-main purpose is to use Thunderbolt rings for higher bandwidth communication.
-Setting up such thunderbolt rings can be done manually, but is a relatively
-tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
-
-To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
-Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
-utility as follows:
-
-.. code:: shell
-
-   mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring
-
-By default the script will attempt to discover the thunderbolt ring and provide
-you with the commands to configure each node as well as the ``hostfile.json``
-to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
-then ``--auto-setup`` can be used to configure them automatically.
-
-If you want to go through the process manually, the steps are as follows:
-
-* Disable the thunderbolt bridge interface
-* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
-  corresponding to that cable in nodes ``i`` and ``i + 1``.
-* Set up a unique subnetwork connecting the two nodes for the corresponding
-  interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
-  and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
-  ``192.168.0.2`` respectively to the two nodes. For more details you can see
-  the commands prepared by the utility script.
-
-.. _jaccl_section:
-
-Getting Started with RDMA over Thunderbolt
-------------------------------------------
-
-Starting from version 26.2 RDMA over thunderbolt is available in macOS and
-enables low-latency communication between Macs with thunderbolt 5. MLX provides
-the JACCL backend that uses this functionality to achieve communication latency
-an order of magnitude lower than the ring backend.
-
-.. note::
-   The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective
-   Communication Library* and it is an obvious pun on Nvidia's NCCL but also a
-   tribute to *Jack Beasley* who led the development of RDMA over Thunderbolt
-   at Apple.
-
-Enabling RDMA
-^^^^^^^^^^^^^
-
-Until the feature matures, enabling RDMA over thunderbolt is slightly more
-involved and **cannot** be done remotely even with sudo. In fact, it has to be
-done in macOS recovery:
-
-1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.
-2. Open the Terminal by going to Utilities -> Terminal.
-3. Run ``rdma_ctl enable``.
-4. Reboot.
-
-To verify that you have successfully enabled Thunderbolt RDMA you can run
-``ibv_devices`` which should produce something like the following for an M3 Ultra.
-
-.. code-block:: bash
-
-   ~ % ibv_devices
-       device                 node GUID
-       ------              ----------------
-       rdma_en2            8096a9d9edbaac05
-       rdma_en3            8196a9d9edbaac05
-       rdma_en5            8396a9d9edbaac05
-       rdma_en4            8296a9d9edbaac05
-       rdma_en6            8496a9d9edbaac05
-       rdma_en7            8596a9d9edbaac05
-
-Defining a Mesh
-^^^^^^^^^^^^^^^
-
-The JACCL backend supports only fully connected topologies. Namely, there needs
-to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
-the following topology visualizations, the left one is valid because there is a
-connection from any node to any other node, while for the one on the right M3
-Ultra 1 is not connected to M3 Ultra 2.
-
-.. raw:: html
-
-   <div style="display: flex; text-align: center; align-items: end; font-size: 80%;">
-     <div>
-       <img src="/_static/distributed/m3-ultra-mesh.png" alt="M3 Ultra thunderbolt mesh" style="width: 55%">
-       <p>Fully connected mesh of four M3 Ultra.</p>
-     </div>
-     <div>
-       <img src="/_static/distributed/m3-ultra-mesh-broken.png" alt="M3 Ultra broken thunderbolt mesh" style="width: 55%">
-       <p>Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).</p>
-     </div>
-   </div>
-
-Similar to the ring backend, the easiest way to use JACCL with MLX is to write
-a JSON hostfile that will be used by ``mlx.launch``. The hostfile needs to contain
-
-- Hostnames to use for launching scripts via ssh
-- An IP for rank 0 that is reachable by all nodes
-- A list of rdma devices that connect each node to each other node
-
-The following JSON defines the valid 4-node mesh from the image above.
-
-.. code-block:: json
-
-   [
-       {
-           "ssh": "m3-ultra-1",
-           "ips": ["123.123.123.1"],
-           "rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]
-       },
-       {
-           "ssh": "m3-ultra-2",
-           "ips": [],
-           "rdma": ["rdma_en5", null, "rdma_en3", "rdma_en4"]
-       },
-       {
-           "ssh": "m3-ultra-3",
-           "ips": [],
-           "rdma": ["rdma_en4", "rdma_en3", null, "rdma_en5"]
-       },
-       {
-           "ssh": "m3-ultra-4",
-           "ips": [],
-           "rdma": ["rdma_en3", "rdma_en4", "rdma_en5", null]
-       }
-   ]
-
-Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
-disabling the thunderbolt bridge is still required as well as setting up
-isolated local networks for each thunderbolt connection.
-
-All of the above can be done instead via ``mlx.distributed_config``. This helper
-script will
-
-- ssh into each node
-- extract the thunderbolt connectivity
-- check for a valid mesh
-- provide the commands to configure each node (or run them if sudo is available)
-- generate the hostfile to be used with ``mlx.launch``
-
-Putting it All Together
-^^^^^^^^^^^^^^^^^^^^^^^^
-
-For example launching a distributed MLX script that uses JACCL is fairly simple
-if the nodes are reachable via ssh and have password-less sudo.
-
-First, connect all the thunderbolt cables. Then we can verify the connections
-by using the ``mlx.distributed_config`` script to visualize them.
-
-.. code-block::
-
-   mlx.distributed_config --verbose \
-       --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
-       --over thunderbolt --dot | dot -Tpng | open -f -a Preview
-
-After making sure that everything looks right we can auto-configure the nodes
-and save the hostfile to ``m3-ultra-jaccl.json`` by running:
-
-.. code-block::
-
-   mlx.distributed_config --verbose \
-       --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
-       --over thunderbolt --backend jaccl \
-       --auto-setup --output m3-ultra-jaccl.json
-
-And now we are ready to run a distributed MLX script such as distributed inference
-of a gigantic model using MLX-LM.
-
-.. code-block::
-
-   mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
-       --env MLX_METAL_FAST_SYNCH=1 -- \  # <--- important
-       /path/to/remote/python -m mlx_lm chat --model mlx-community/DeepSeek-V3.2-8bit --shard
-
-.. note::
-   Defining the environment variable ``MLX_METAL_FAST_SYNCH=1`` enables a
-   different, faster way of synchronizing between the GPU and the CPU. It is
-   not specific to the JACCL backend and can be used in all cases where the CPU
-   and GPU need to collaborate for some computation and is pretty critical for
-   low-latency communication since the communication is done by the CPU.
-
-.. _nccl_section:
-
-Getting Started with NCCL
--------------------------
-
-MLX on CUDA environments ships with the ability to talk to `NCCL
-<https://developer.nvidia.com/nccl>`_ which is a high-performance collective
-communication library that supports both multi-gpu and multi-node setups.
-
-For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all
-it takes to run a distributed job is
-
-.. code-block::
-
-   mlx.launch -n 8 test.py
-
-   # perfect for interactive scripts
-   mlx.launch -n 8 python -m mlx_lm chat --model my-model --shard
-
-You can also use ``mlx.launch`` to ssh to a remote node and launch a script
-with the same ease
-
-.. code-block::
-
-   mlx.launch --hosts my-cuda-node -n 8 test.py
-
-In many cases you may not want to use ``mlx.launch`` with the NCCL backend
-because the cluster scheduler will be the one launching the processes. You can
-:ref:`see which environment variables need to be defined <no_mlx_launch>` in
-order for the MLX NCCL backend to be initialized correctly.
-
-.. _mpi_section:
-
 Getting Started with MPI
 ------------------------

-MLX already comes with the ability to "talk" to `MPI
-<https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ if it is installed
-on the machine. Launching distributed MLX programs that use MPI can be done
-with ``mpirun`` as expected. However, in the following examples we will be
-using ``mlx.launch --backend mpi`` which takes care of some nuisances such as
-setting absolute paths for the ``mpirun`` executable and the ``libmpi.dyld``
-shared library.
+MLX already comes with the ability to "talk" to MPI if it is installed on the
+machine. Launching distributed MLX programs that use MPI can be done with
+``mpirun`` as expected. However, in the following examples we will be using
+``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
+absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
+library.

 The simplest possible usage is the following which, assuming the minimal
 example in the beginning of this page, should result in:
@@ -535,116 +269,78 @@ Force MPI to use the most performant network interface by setting ``--mca
 btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
 to use.

-.. _no_mlx_launch:
+Getting Started with Ring
+-------------------------

-Distributed Without ``mlx.launch``
-----------------------------------
+The ring backend does not depend on any third party library so it is always
+available. It uses TCP sockets so the nodes need to be reachable via a network.
+As the name suggests the nodes are connected in a ring which means that rank 1
+can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
+and so on and so forth. As a result :func:`send` and :func:`recv` with
+arbitrary sender and receiver is not supported in the ring backend.

-None of the implementations of the distributed backends require launching with
-``mlx.launch``. The script simply connects to each host, starts a process per
-rank and sets up the necessary environment variables before delegating to your
-MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
-for more details.
+Defining a Ring
+^^^^^^^^^^^^^^^

-For many use-cases this will be the easiest way to perform distributed
-computations in MLX. However, there may be reasons that you cannot or should
-not use ``mlx.launch``. A common such case is the use of a scheduler that
-starts all the processes for you on machines undetermined at the time of
-scheduling the job.
+The easiest way to define and use a ring is via a JSON hostfile and the
+``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
+defines a hostname to ssh into to run commands on this node and one or more IPs
+that this node will listen to for connections.

-Below we list the environment variables required to use each backend.
+For example the hostfile below defines a 4 node ring. ``hostname1`` will be
+rank 0, ``hostname2`` rank 1 etc.

-Ring
-^^^^^^
+.. code:: json

-**MLX_RANK** should contain a single 0-based integer that defines the rank of
-the process.
+   [
+       {"ssh": "hostname1", "ips": ["123.123.123.1"]},
+       {"ssh": "hostname2", "ips": ["123.123.123.2"]},
+       {"ssh": "hostname3", "ips": ["123.123.123.3"]},
+       {"ssh": "hostname4", "ips": ["123.123.123.4"]}
+   ]

-**MLX_HOSTFILE** should contain the path to a json file that contains IPs and
-ports for each rank to listen to, something like the following:
+Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
+node, run the script which will listen for connections in each of the provided
+IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
+connection from ``123.123.123.4`` and so on and so forth.

-.. code-block:: json
+Thunderbolt Ring
+^^^^^^^^^^^^^^^^

-   [
-       ["123.123.1.1:5000", "123.123.1.2:5000"],
-       ["123.123.2.1:5000", "123.123.2.2:5000"],
-       ["123.123.3.1:5000", "123.123.3.2:5000"],
-       ["123.123.4.1:5000", "123.123.4.2:5000"]
-   ]
+Although the ring backend can have benefits over MPI even for Ethernet, its
+main purpose is to use Thunderbolt rings for higher bandwidth communication.
+Setting up such thunderbolt rings can be done manually, but is a relatively
+tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.

-**MLX_RING_VERBOSE** is optional and if set to 1 it enables some more logging
-from the distributed backend.
+To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
+Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
+utility as follows:

-JACCL
-^^^^^
+.. code:: shell

-**MLX_RANK** should contain a single 0-based integer that defines the rank of
-the process.
+   mlx.distributed_config --verbose --hosts host1,host2,host3,host4

-**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 listens
-on and all the other ranks connect to in order to establish the RDMA
-connections.
+By default the script will attempt to discover the thunderbolt ring and provide
+you with the commands to configure each node as well as the ``hostfile.json``
+to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
+then ``--auto-setup`` can be used to configure them automatically.

-**MLX_IBV_DEVICES** should contain the path to a json file that contains the
-ibverbs device names that connect each node to each other node, something like
-the following:
+To validate your connection without configuring anything
+``mlx.distributed_config`` can also plot the ring using DOT format.

-.. code-block:: json
+.. code:: shell

-   [
-       [null, "rdma_en5", "rdma_en4", "rdma_en3"],
-       ["rdma_en5", null, "rdma_en3", "rdma_en4"],
-       ["rdma_en4", "rdma_en3", null, "rdma_en5"],
-       ["rdma_en3", "rdma_en4", "rdma_en5", null]
-   ]
+   mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
+   dot -Tpng ring.dot >ring.png
+   open ring.png

-NCCL
-^^^^^
+If you want to go through the process manually, the steps are as follows:

-**MLX_RANK** should contain a single 0-based integer that defines the rank of
-the process.
-
-**MLX_WORLD_SIZE** should contain the total number of processes that will be
-launched.
-
-**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
-hosts can connect to in order to establish the NCCL communication.
-
-**CUDA_VISIBLE_DEVICES** should contain the local index of the gpu that
-corresponds to this process.
-
-Of course any `other environment variable
-<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is
-used by NCCL can be set.
-
-.. _tips_and_tricks:
-
-Tips and Tricks
-----------------
-
-This is a small collection of tips to help you better utilize the distributed
-communication capabilities of MLX.
-
-- *Test locally first.*
-  You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a small
-  scale test on a single node first.
-- *Batch your communication.*
-  As described in the :ref:`training example <training_example>`, performing a
-  lot of small communication can hurt performance. Copy the approach of
-  :func:`mlx.nn.average_gradients` to gather many small communications in a
-  single large one.
-- *Visualize the connectivity.*
-  Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
-  visualize the connections and make sure that the cables are connected
-  correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
-- *Use the debugger.*
-  ``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
-  processes and gathers stdout from all processes. This makes using ``pdb`` a
-  breeze.
+* Disable the thunderbolt bridge interface
+* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
+  corresponding to that cable in nodes ``i`` and ``i + 1``.
+* Set up a unique subnetwork connecting the two nodes for the corresponding
+  interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
+  and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
+  ``192.168.0.2`` respectively to the two nodes. For more details you can see
+  the commands prepared by the utility script.
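
The removed tips section points at :func:`mlx.nn.average_gradients` for batching gradient communication; a minimal sketch of where it fits in a training step, assuming the documented ``mlx.nn`` APIs (``loss_fn`` and the surrounding step are hypothetical):

   import mlx.core as mx
   import mlx.nn as nn

   def step(model, optimizer, x, y):
       loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
       # one batched all_sum instead of many small per-parameter reductions
       grads = nn.average_gradients(grads)
       optimizer.update(model, grads)
       return loss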

View File

@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:

 .. code-block:: shell

   >>> a = mx.arange(1000).reshape(10, 10, 10)
-  >>> a[mx.random.randn(10, 10) > 0.0] = 0  # valid: mask covers axes 0 and 1
+  >>> a[mx.random.normal((10, 10)) > 0.0] = 0  # valid: mask covers axes 0 and 1

 The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
 selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
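
A quick shape check of the semantics described above, assuming the NumPy-style boolean indexing the docs describe:

   import mlx.core as mx

   a = mx.arange(1000).reshape(10, 10, 10)
   mask = mx.random.normal((10, 10)) > 0.0
   print(a[mask].shape)  # (number of True entries, 10): one slice a[i, j, :] per hit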


@@ -3,6 +3,6 @@ requires = [
     "setuptools>=42",
     "cmake>=3.25",
     "mlx>=0.18.0",
-    "nanobind==2.4.0",
+    "nanobind==2.10.2",
 ]
 build-backend = "setuptools.build_meta"


@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.25
 mlx>=0.21.0
-nanobind==2.4.0
+nanobind==2.10.2


@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
     // - Donatable
     // - Not a constant
     if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-        in.is_donatable() && is_constant(i)) {
+        in.is_donatable() && !is_constant(i)) {
       outputs[o++].copy_shared_buffer(in);
     }
     // Get representative input flags to properly set non-donated outputs
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
     // - Not a constant
     if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
         in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-        is_constant(i)) {
+        !is_constant(i)) {
       outputs[o].copy_shared_buffer(
           in, outputs[o].strides(), in.flags(), in.data_size());
       o++;
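
The fix flips ``is_constant(i)`` to ``!is_constant(i)``, matching the "Not a constant" comment: an input captured as a compile-time constant must never donate its buffer, since the constant is reused across calls. A toy Python restatement of the corrected predicate (names are hypothetical, mirroring the C++):

   def may_donate(inp, out, is_constant: bool) -> bool:
       # Reuse the input's buffer for an output only when the input is
       # donatable, matches the output's itemsize, and is NOT a constant
       # captured by the compiled graph (constants outlive a single call).
       return inp.donatable and inp.itemsize == out.itemsize and not is_constant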


@@ -291,6 +291,17 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
        num_keys,
        kshape = keys.shape(),
        kstrides = keys.strides()]() mutable {
+    auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
+      if (4 * loc + 4 <= bytes_per_key) {
+        reinterpret_cast<uint32_t*>(cptr)[loc] = v;
+      } else {
+        std::copy(
+            reinterpret_cast<char*>(&v),
+            reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,
+            cptr + 4 * loc);
+      }
+    };
     size_t out_skip = (bytes_per_key + 4 - 1) / 4;
     auto half_size = out_skip / 2;
     bool even = out_skip % 2 == 0;
@@ -310,18 +321,12 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
       if (count.first < half_size) {
         auto rb = random::threefry2x32_hash(key, count);
         ptr[count.first++] = rb.first;
-        if (bytes_per_key % 4 > 0) {
-          std::copy(
-              reinterpret_cast<char*>(&rb.second),
-              reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
-              cptr + 4 * count.second);
-        } else {
-          ptr[count.second] = rb.second;
-        }
+        copy_remaining(cptr, count.second, rb.second);
       }
       if (!even) {
         count.second = 0;
-        ptr[half_size] = random::threefry2x32_hash(key, count).first;
+        copy_remaining(
+            cptr, half_size, random::threefry2x32_hash(key, count).first);
       }
     }
   });
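
The new ``copy_remaining`` helper is what closes the heap-buffer-overflow: it stores a full 4-byte word only when the word fits inside the key's ``bytes_per_key`` output slot, and otherwise copies just the tail bytes. A small Python model of the bounds logic (a hypothetical stand-in for the C++ lambda, with ``buf`` as one key's output slice):

   def copy_remaining(buf: bytearray, loc: int, v: int, bytes_per_key: int) -> None:
       word = v.to_bytes(4, "little")
       if 4 * loc + 4 <= bytes_per_key:
           buf[4 * loc : 4 * loc + 4] = word  # full word fits in the slot
       else:
           n = bytes_per_key - 4 * loc  # tail write: only the remaining bytes
           buf[4 * loc : 4 * loc + n] = word[:n]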


@@ -3,5 +3,9 @@

 #include "mlx/backend/cpu/simd/base_simd.h"

 #ifdef MLX_USE_ACCELERATE
+#if defined(__x86_64__)
+// the accelerate_simd implementation requires NEON -- use the base implementation
+#else
 #include "mlx/backend/cpu/simd/accelerate_simd.h"
+#endif
 #endif


@@ -338,28 +338,43 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
     }
     cudaGraphNodeType type;
     CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
-    if (type == cudaGraphNodeTypeGraph) {
-      // Try to be updatable for a structure like graph -> graph -> kernel
-      cudaGraph_t child;
-      CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
-      auto [subkey, sub_is_updatable] = subgraph_to_key(child);
-      is_updatable &= sub_is_updatable;
-      key += subkey;
-    } else if (type == cudaGraphNodeTypeMemset) {
-      key += "M";
-    } else if (type != cudaGraphNodeTypeKernel) {
-      is_updatable = false;
-    } else {
-      cudaLaunchAttributeValue cluster_dim;
-      CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
-          node, cudaLaunchAttributeClusterDimension, &cluster_dim));
-      // Only allow dim.x to be greater than 1
-      if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
-        is_updatable = false;
-      } else {
-        key += "K";
-        key += std::to_string(cluster_dim.clusterDim.x);
+    switch (type) {
+      case cudaGraphNodeTypeGraph: {
+        // Try to be updatable for a structure like graph -> graph -> kernel
+        cudaGraph_t child;
+        CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
+        auto [subkey, sub_is_updatable] = subgraph_to_key(child);
+        is_updatable &= sub_is_updatable;
+        key += subkey;
+        break;
+      }
+      case cudaGraphNodeTypeHost:
+        key += "H";
+        break;
+      case cudaGraphNodeTypeMemset:
+        key += "M";
+        break;
+      case cudaGraphNodeTypeKernel: {
+        cudaLaunchAttributeValue cluster_dim;
+        CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
+            node, cudaLaunchAttributeClusterDimension, &cluster_dim));
+        // Only allow dim.x to be greater than 1
+        if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
+          is_updatable = false;
+        } else {
+          key += "K";
+          key += std::to_string(cluster_dim.clusterDim.x);
+        }
+        break;
+      }
+      case cudaGraphNodeTypeWaitEvent:
+        key += "W";
+        break;
+      case cudaGraphNodeTypeEventRecord:
+        key += "R";
+        break;
+      default:
+        is_updatable = false;
     }
   }
   key += ")";


@@ -2,7 +2,11 @@

 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
+#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
 #include "mlx/backend/cuda/quantized/quantized.h"
+#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
+#include "mlx/backend/cuda/vector_types.cuh"
 #include "mlx/dtype_utils.h"

 #include <cooperative_groups.h>
@@ -13,17 +17,6 @@

 namespace mlx::core {

 namespace cu {

-template <int bits>
-struct Quantize {
-  __device__ uint8_t operator()(float x) {
-    if constexpr (bits == 8) {
-      return __nv_fp8_e4m3(x).__x;
-    } else {
-      return __nv_fp4_e2m1(x).__x;
-    }
-  }
-};
-
 template <int bits>
 struct Dequantize {
   __device__ float operator()(uint8_t x) {
@@ -37,29 +30,40 @@ struct Dequantize {

 namespace cg = cooperative_groups;

-template <typename T, int group_size, int bits, bool use_mx_scale>
-__global__ void
-fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
+template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
+__global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
+  using Tx2 = Vector2_t<T>;
+  using Tx4 = Vector4_t<T>;
+  uint32_t rbits = 0; // reserved bits for future use
   auto block_size = cg::this_thread_block().dim_threads();
   auto block_idx = cg::this_thread_block().group_index();
   auto idx_in_block = cg::this_thread_block().thread_index();

   auto tidx = block_idx.x * block_size.x + idx_in_block.x;
   auto tidy = block_idx.y * block_size.y + idx_in_block.y;

-  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
-  size_t index = tidx + grid_dim_x * size_t(tidy);
-  if (index >= size) {
+  auto grid_dim_x =
+      cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
+  size_t thread_idx = tidx + grid_dim_x * size_t(tidy);
+  size_t base_idx = thread_idx * group_size;
+  if (base_idx >= size) {
     return;
   }

-  float w_thread = w[index];
-  cg::greater<float> max_op;
-  auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
-  float scale = cg::reduce(warp, abs(w_thread), max_op);
+  auto w_tile = load_vector<group_size, T>(w, thread_idx);
+  float scale = 0.0f;
+  Tx2 amax_2x = Tx2{0.0f, 0.0f};
+#pragma unroll
+  for (int i = 0; i < group_size; i += 2) {
+    auto pair = Tx2{w_tile[i], w_tile[i + 1]};
+    abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
+  }
+  scale = static_cast<float>(
+      max(fabsf(static_cast<float>(amax_2x.x)),
+          fabsf(static_cast<float>(amax_2x.y))));
   scale /= bits == 4 ? 6.0f : 448.0f;

   // Convert to mx scale or nv scale
   using ScaleType =
@@ -68,21 +72,24 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
   uint8_t q_scale = s.__x;
   scale = float(s);

   // Write out the scales
-  size_t gindex = index / group_size;
-  if (index % group_size == 0) {
-    scales[gindex] = q_scale;
-  }
+  scales[thread_idx] = q_scale;
+
+  constexpr int elem_per_byte = bits == 8 ? 1 : 2;
+  AlignedVector<uint8_t, group_size / elem_per_byte> quantized;

-  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
-  if (bits == 4) {
-    uint8_t sval = warp.shfl_down(output, 1);
-    output |= sval << bits;
-  }
-  constexpr int pack_factor = bits == 8 ? 1 : 2;
-  if (index % pack_factor == 0) {
-    out[index / pack_factor] = output;
+#pragma unroll
+  for (int i = 0; i < group_size / 4; i++) {
+    Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
+    if constexpr (bits == 8) {
+      uint32_t quantized_val =
+          scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
+      *reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
+    } else {
+      uint16_t quantized_val =
+          scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
+      *reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
+    }
   }
+  store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);
 }

@@ -142,15 +149,16 @@ void fp_quantize(
   dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     if constexpr (!std::is_same_v<T, double>) {
-      auto kernel = cu::fp_quantize<T, 32, 4, true>;
+      auto kernel = cu::fp_quantize<T, 32, 4, true, false>;
       if (bits == 8) {
-        kernel = cu::fp_quantize<T, 32, 8, true>;
+        kernel = cu::fp_quantize<T, 32, 8, true, false>;
       } else if (group_size == 16) {
-        kernel = cu::fp_quantize<T, 16, 4, false>;
+        kernel = cu::fp_quantize<T, 16, 4, false, false>;
       }
       bool large = w.size() > UINT_MAX;
       auto [num_blocks, block_dims] =
-          get_launch_args(w.size(), w.shape(), w.strides(), large);
+          get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);
       enc.add_kernel_node(
           kernel,
           num_blocks,

@@ -0,0 +1,32 @@
#pragma once
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
// TODO implement fast path
template <typename T>
__device__ __forceinline__ uint32_t
scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t<T> input, const float scale) {
uint32_t out_fp8x4 = 0;
float4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x;
return out_fp8x4;
}
// Placeholder for future fast path implementation
template <typename T, bool USE_SR>
__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
return scale_cvt_Tx4_to_fp8x4_fallback(input, scale);
}
} // namespace mlx::core::cu
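
The fallback path scales four floats and packs them to fp8 e4m3; the kernel above chooses that scale by dividing each group's absolute maximum by 448, the largest finite e4m3 magnitude (the ``scale /= bits == 4 ? 6.0f : 448.0f`` line). A plain-Python check of the normalization:

   group = [0.1, -2.5, 3.7, 0.0]
   amax = max(abs(x) for x in group)
   scale = amax / 448.0  # e4m3's largest finite magnitude is 448
   scaled = [x / scale for x in group]  # all values now fall inside the fp8 range
   assert max(abs(x) for x in scaled) <= 448.0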


@@ -0,0 +1,334 @@
#pragma once
#include <cuda.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using f32x4 = Vector4_t<float>;
template <typename T>
__device__ __forceinline__ uint16_t
scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
// Fallback implementation for architectures that do not support cvt
// instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
uint16_t out_fp4x4 = 0;
fp32x4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
(static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
static_cast<uint16_t>(q0);
return out_fp4x4;
}
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
defined(__CUDA_ARCH_SPECIFIC__)
__device__ __forceinline__ uint16_t
scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t" // first bf16
".reg.b16 x1_bf16; \n\t" // second bf16
".reg.b16 x2_bf16; \n\t" // third bf16
".reg.b16 x3_bf16; \n\t" // fourth bf16
".reg.b32 x0; \n\t" // to hold scaled first
".reg.b32 x1; \n\t" // to hold scaled second
".reg.b32 x2; \n\t" // to hold scaled third
".reg.b32 x3; \n\t" // to hold scaled fourth
".reg.b64 x01; \n\t" // to hold vector mul
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
"cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t" // scale first pair
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t" // scale second pair
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first
// pair
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2 second
// pair
"mov.b16 %0, {q0, q1}; \n\t" // pack to output
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(
scale))); // here cast is needed because an asm operand must have
// scalar type
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
const bf16x4 input_bf16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t"
".reg.b16 x1_bf16; \n\t"
".reg.b16 x2_bf16; \n\t"
".reg.b16 x3_bf16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
"cvt.f32.bf16 x0, x0_bf16; \n\t"
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t
scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
const fp16x4 input_fp16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
const bf16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
const fp16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t
scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
float2 input_fp32x2_0 = make_float2(input.x, input.y);
float2 input_fp32x2_1 = make_float2(input.z, input.w);
if constexpr (USE_SR) {
return scale_cvt_fp32x4_to_fp4x4_rs(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
} else {
return scale_cvt_fp32x4_to_fp4x4_rn(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
}
}
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else if constexpr (std::is_same<T, __half>::value) {
return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else {
return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
}
}
#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
// (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
(__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
#else
static_assert(
!USE_SR,
"Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
#endif
}
} // namespace mlx::core::cu
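For readers skimming the PTX above: every scale_cvt_* helper implements the same two-step recipe, multiply four inputs by a broadcast scale, then round each product to a packed e2m1 (FP4) value, either round-to-nearest (cvt.rn) or stochastically from caller-supplied random bits (cvt.rs). A minimal Python model of the idea follows; the e2m1 value table is a property of the format, but the tie handling and the use of rbits are illustrative assumptions, not the PTX semantics bit-for-bit.

# Representable e2m1 magnitudes; the sign is a separate bit.
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_e2m1_rn(x, scale):
    # Round-to-nearest analogue of cvt.rn.satfinite.e2m1x2.f32.
    # (Hardware breaks ties to even; this sketch takes the smaller magnitude.)
    v = x * scale
    sign = -1.0 if v < 0 else 1.0
    mag = min(abs(v), 6.0)  # satfinite: clamp to the largest finite value
    return sign * min(E2M1, key=lambda r: abs(r - mag))

def quantize_e2m1_rs(x, scale, rbits):
    # Stochastic-rounding analogue of cvt.rs.satfinite.e2m1x4.f32:
    # round up with probability equal to the fractional distance.
    v = x * scale
    sign = -1.0 if v < 0 else 1.0
    mag = min(abs(v), 6.0)
    lo = max(r for r in E2M1 if r <= mag)
    hi = min(r for r in E2M1 if r >= mag)
    if lo == hi:
        return sign * lo
    u = (rbits & 0xFFFFFFFF) / 2**32  # random bits -> uniform in [0, 1)
    return sign * (hi if u < (mag - lo) / (hi - lo) else lo)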

View File

@@ -15,6 +15,22 @@ inline constexpr __device__ short get_bytes_per_pack() {
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T>
__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
if constexpr (
(std::is_same<T, __nv_bfloat162>::value) ||
(std::is_same<T, __half2>::value)) {
T a = x1;
T b = x2;
out = __hmax2(__habs2(a), __habs2(b));
} else if constexpr (std::is_same<T, float2>::value) {
float2 a = x1;
float2 b = x2;
out.x = fmaxf(fabsf(a.x), fabsf(b.x));
out.y = fmaxf(fabsf(a.y), fabsf(b.y));
}
}
} // namespace cu
template <typename F>
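abs_max_x2 above folds two vectors into an elementwise absolute maximum; repeated applications yield the per-group absmax from which a quantization scale is derived. A Python sketch of that reduction, assuming a power-of-two group size (how the kernel turns the absmax into a scale is an assumption here, not shown in this hunk):

def group_absmax(values):
    # Pairwise tree reduction of |x|, mirroring repeated abs_max_x2 calls.
    while len(values) > 1:
        values = [max(abs(a), abs(b))
                  for a, b in zip(values[0::2], values[1::2])]
    return values[0]

# e.g. pick a scale that maps the group's absmax onto e2m1's maximum of 6.0
scale = 6.0 / group_absmax([0.1, -2.5, 0.75, 1.0])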

View File

@@ -3,31 +3,10 @@
#pragma once
#include "mlx/backend/cuda/steel/utils.cuh"
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
// Map types to their vector-of-2 counterparts, e.g. float -> float2, double -> double2.
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
/**
* The basic building block for Ampere mmas. A 16x16 tile distributed across
* the warp.

View File

@@ -0,0 +1,48 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core::cu {
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
template <typename T>
struct Vector4 {
T x, y, z, w;
};
template <typename T>
using Vector4_t = Vector4<T>;
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using fp32x4 = Vector4_t<float>;
} // namespace mlx::core::cu

View File

@@ -347,7 +347,7 @@ template <
MMAFrag_mask_t::load_safe(
mfrag,
mask,
int(mask_params->M_strides[2]),
int64_t(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,

View File

@@ -346,7 +346,7 @@ template <
MSubTile mfrag;
mfrag.load_safe(
mask,
int(mask_params->M_strides[2]),
int64_t(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,

View File

@@ -105,17 +105,20 @@ struct BaseMMAFrag<T, 8, 8> {
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
src += off_x * str_x + off_y * str_y;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
dst[i * kElemCols + j] = static_cast<T>(src[0]);
} else {
dst[i * kElemCols + j] = T(0);
}
src += str_y;
}
src -= kElemCols * str_y;
src += str_x;
}
}
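The load_safe change above is a strength reduction: rather than recomputing (off_x + i) * str_x + (off_y + j) * str_y for every element, the pointer is offset once, stepped by str_y within a row, and rewound between rows. A quick Python check that both schemes enumerate the same addresses (a sketch, not MLX code):

def offsets_old(rows, cols, off_x, off_y, str_x, str_y):
    return [(off_x + i) * str_x + (off_y + j) * str_y
            for i in range(rows) for j in range(cols)]

def offsets_new(rows, cols, off_x, off_y, str_x, str_y):
    out, p = [], off_x * str_x + off_y * str_y
    for i in range(rows):
        for j in range(cols):
            out.append(p)
            p += str_y
        p += str_x - cols * str_y  # undo the row sweep, step to the next row
    return out

assert offsets_old(2, 4, 1, 0, 16, 1) == offsets_new(2, 4, 1, 0, 16, 1)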

View File

@@ -4,11 +4,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
if(MLX_BUILD_CPU AND NOT WIN32)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)

View File

@@ -5,7 +5,6 @@
#include "mlx/backend/cuda/cuda.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/jaccl/jaccl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h"
@@ -103,27 +102,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail
bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
jaccl::is_available();
}
bool is_available(const std::string& bk) {
if (bk == "any") {
return is_available();
}
if (bk == "mpi") {
return mpi::is_available();
}
if (bk == "ring") {
return ring::is_available();
}
if (bk == "nccl") {
return nccl::is_available();
}
if (bk == "jaccl") {
return jaccl::is_available();
}
return false;
return mpi::is_available() || ring::is_available() || nccl::is_available();
}
int Group::rank() const {
@@ -156,8 +135,6 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "jaccl") {
group = jaccl::init(strict);
} else if (bk == "any") {
if (mlx::core::cu::is_available()) {
group = nccl::init(false);
@@ -171,17 +148,13 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = mpi::init(false);
bk_ = "mpi";
}
if (group == nullptr) {
group = jaccl::init(false);
bk_ = "jaccl";
}
if (group == nullptr && strict) {
throw std::runtime_error("[distributed] Couldn't initialize any backend");
}
} else {
std::ostringstream msg;
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
<< "'jaccl' and 'ring' but '" << bk << "' was provided.";
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
<< "and 'ring' but '" << bk << "' was provided.";
throw std::invalid_argument(msg.str());
}

View File

@@ -16,7 +16,6 @@ class GroupImpl;
/* Check if a communication backend is available */
bool is_available();
bool is_available(const std::string& bk);
/**
* A distributed::Group represents a group of independent mlx processes that

View File

@@ -1,8 +0,0 @@
if(MLX_BUILD_CPU
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
target_link_libraries(mlx PRIVATE rdma)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
endif()

File diff suppressed because it is too large

View File

@@ -1,12 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::jaccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::jaccl

View File

@@ -1,20 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/jaccl/jaccl.h"
namespace mlx::core::distributed::jaccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available() {
return false;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) {
throw std::runtime_error("Cannot initialize jaccl distributed backend.");
}
return nullptr;
}
} // namespace mlx::core::distributed::jaccl

View File

@@ -1,38 +0,0 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::distributed::detail {
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
template <typename T>
struct MaxOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output = std::max(*output, *input);
input++;
output++;
}
}
};
template <typename T>
struct MinOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output = std::min(*output, *input);
input++;
output++;
}
}
};
} // namespace mlx::core::distributed::detail

View File

@@ -1,6 +1,9 @@
// Copyright © 2024 Apple Inc.
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <unistd.h>
@@ -19,8 +22,6 @@
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/reduction_ops.h"
#include "mlx/distributed/utils.h"
#include "mlx/threadpool.h"
#ifndef SOL_TCP
@@ -93,7 +94,6 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
constexpr const size_t ALL_SUM_BUFFERS = 2;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
constexpr const char* RING_TAG = "[ring]";
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
@@ -296,6 +296,55 @@ class CommunicationThreads {
std::unordered_map<int, SocketThread> threads_;
};
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* get() const {
return (struct sockaddr*)&addr;
}
};
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port) {
auto colon = ip_port.find(":");
if (colon == std::string::npos) {
std::ostringstream msg;
msg << "Can't parse address " << ip_port;
throw std::runtime_error(msg.str());
}
std::string ip(ip_port.begin(), ip_port.begin() + colon);
std::string port(ip_port.begin() + colon + 1, ip_port.end());
return parse_address(ip, port);
}
/**
* Load all addresses from the json hostfile. The hostfile is a list of
* addresses in order of rank. For each rank there can be many addresses so
@@ -308,15 +357,15 @@ class CommunicationThreads {
* ["ip3:5000", "ip3:5001"],
* ]
*/
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<detail::address_t>> nodes;
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<address_t>> nodes;
std::ifstream f(hostfile);
json hosts = json::parse(f);
for (auto& h : hosts) {
std::vector<detail::address_t> host;
std::vector<address_t> host;
for (auto& ips : h) {
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
host.push_back(parse_address(ips.get<std::string>()));
}
nodes.push_back(std::move(host));
}
@@ -328,15 +377,73 @@ std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
* Create a socket and accept one connection for each of the provided
* addresses.
*/
std::vector<int> accept_connections(
const std::vector<detail::address_t>& addresses) {
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
detail::TCPSocket socket(RING_TAG);
socket.listen(RING_TAG, address);
sockets.push_back(socket.accept(RING_TAG).detach());
// Create the socket to wait for connections from the peers
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock, address.get(), address.len);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Wait for connections
success = listen(sock, 0);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
int peer_socket = accept(sock, nullptr, nullptr);
if (peer_socket < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Close the listening socket
shutdown(sock, 2);
close(sock);
sockets.push_back(peer_socket);
}
return sockets;
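The accept path above, restated with Python's socket module for readers who don't live in BSD sockets (a behavioral sketch of the C++, not MLX code; SO_REUSEPORT availability is platform-dependent):

import socket

def accept_one(host, port):
    # Listen, accept a single peer, then drop the listening socket,
    # mirroring accept_connections above.
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # fast relaunch
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    srv.bind((host, port))
    srv.listen(0)
    peer, _ = srv.accept()
    srv.close()
    return peer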
@@ -347,42 +454,93 @@ std::vector<int> accept_connections(
* provided addresses.
*/
std::vector<int> make_connections(
const std::vector<detail::address_t>& addresses,
const std::vector<address_t>& addresses,
bool verbose) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
sockets.push_back(detail::TCPSocket::connect(
RING_TAG,
address,
CONN_ATTEMPTS,
CONN_WAIT,
[verbose](int attempt, int wait) {
log_info(
verbose,
"Attempt",
attempt,
"waiting",
wait,
"ms (error:",
errno,
")");
})
.detach());
int sock;
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
// backoff. TODO: Do we need that?
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
if (attempt > 0) {
int wait = (1 << (attempt - 1)) * CONN_WAIT;
log_info(
verbose,
"Attempt",
attempt,
"wait",
wait,
"ms (error:",
errno,
")");
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
success = connect(sock, address.get(), address.len);
if (success == 0) {
break;
}
// Close the socket from this failed attempt so retries don't leak fds
close(sock);
}
if (success < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockets.push_back(sock);
}
return sockets;
}
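And the connect path, with the same CONN_ATTEMPTS/CONN_WAIT exponential backoff (wait, then 2*wait, 4*wait, ... milliseconds between retries), as a Python sketch:

import socket
import time

def connect_with_backoff(host, port, attempts=5, wait_ms=1000):
    for attempt in range(attempts):
        if attempt > 0:
            time.sleep((1 << (attempt - 1)) * wait_ms / 1000)
        try:
            return socket.create_connection((host, port))
        except OSError:
            if attempt == attempts - 1:
                raise  # out of retries; surface the last error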
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
template <typename T>
struct MaxOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output = std::max(*output, *input);
input++;
output++;
}
}
};
template <typename T>
struct MinOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output = std::min(*output, *input);
input++;
output++;
}
}
};
} // namespace
class RingGroup : public GroupImpl {
public:
RingGroup(
int rank,
std::vector<std::vector<detail::address_t>> nodes,
bool verbose)
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
: rank_(rank), verbose_(verbose), pool_(0) {
if (rank_ > 0 && rank_ >= nodes.size()) {
throw std::runtime_error(
@@ -475,17 +633,17 @@ class RingGroup : public GroupImpl {
void all_sum(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
}
void all_max(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
}
void all_min(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {

View File

@@ -1,204 +0,0 @@
// Copyright © 2025 Apple Inc.
#include <netdb.h>
#include <unistd.h>
#include <cstring>
#include <sstream>
#include <thread>
#include "mlx/distributed/utils.h"
namespace mlx::core::distributed::detail {
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res;
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port) {
auto colon = ip_port.find(":");
if (colon == std::string::npos) {
std::ostringstream msg;
msg << "Can't parse address " << ip_port;
throw std::runtime_error(msg.str());
}
std::string ip(ip_port.begin(), ip_port.begin() + colon);
std::string port(ip_port.begin() + colon + 1, ip_port.end());
return parse_address(ip, port);
}
TCPSocket::TCPSocket(const char* tag) {
sock_ = socket(AF_INET, SOCK_STREAM, 0);
if (sock_ < 0) {
std::ostringstream msg;
msg << tag << " Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
TCPSocket::TCPSocket(TCPSocket&& s) {
sock_ = s.sock_;
s.sock_ = -1;
}
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
if (this != &s) {
sock_ = s.sock_;
s.sock_ = -1;
}
return *this;
}
TCPSocket::TCPSocket(int s) : sock_(s) {}
TCPSocket::~TCPSocket() {
if (sock_ > 0) {
shutdown(sock_, 2);
close(sock_);
}
}
int TCPSocket::detach() {
int s = sock_;
sock_ = -1;
return s;
}
void TCPSocket::listen(const char* tag, const address_t& addr) {
int success;
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock_, addr.get(), addr.len);
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Prepare waiting for connections
success = ::listen(sock_, 0);
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
TCPSocket TCPSocket::accept(const char* tag) {
int peer = ::accept(sock_, nullptr, nullptr);
if (peer < 0) {
std::ostringstream msg;
msg << tag << " Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
return TCPSocket(peer);
}
void TCPSocket::send(const char* tag, const void* data, size_t len) {
while (len > 0) {
auto n = ::send(sock_, data, len, 0);
if (n <= 0) {
std::ostringstream msg;
msg << tag << " Send failed with errno=" << errno;
throw std::runtime_error(msg.str());
}
len -= n;
data = static_cast<const char*>(data) + n;
}
}
void TCPSocket::recv(const char* tag, void* data, size_t len) {
while (len > 0) {
auto n = ::recv(sock_, data, len, 0);
if (n <= 0) {
std::ostringstream msg;
msg << tag << " Recv failed with errno=" << errno;
throw std::runtime_error(msg.str());
}
len -= n;
data = static_cast<char*>(data) + n;
}
}
TCPSocket TCPSocket::connect(
const char* tag,
const address_t& addr,
int num_retries,
int wait,
std::function<void(int, int)> cb) {
int sock, success;
// Attempt to connect `num_retries` times with exponential backoff.
for (int attempt = 0; attempt < num_retries; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << tag << " Couldn't create socket to connect (error: " << errno
<< ")";
throw std::runtime_error(msg.str());
}
success = ::connect(sock, addr.get(), addr.len);
if (success == 0) {
break;
}
cb(attempt, wait);
if (wait > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
wait <<= 1;
}
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
return TCPSocket(sock);
}
} // namespace mlx::core::distributed::detail

View File

@@ -1,67 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/socket.h>
#include <functional>
#include <string>
namespace mlx::core::distributed::detail {
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* get() const {
return (struct sockaddr*)&addr;
}
};
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port);
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port);
/**
* Small wrapper over a TCP socket to simplify initiating connections.
*/
class TCPSocket {
public:
TCPSocket(const char* tag);
TCPSocket(const TCPSocket&) = delete;
TCPSocket& operator=(const TCPSocket&) = delete;
TCPSocket(TCPSocket&& s);
TCPSocket& operator=(TCPSocket&&);
~TCPSocket();
void listen(const char* tag, const address_t& addr);
TCPSocket accept(const char* tag);
void send(const char* tag, const void* data, size_t len);
void recv(const char* tag, void* data, size_t len);
int detach();
operator int() const {
return sock_;
}
static TCPSocket connect(
const char* tag,
const address_t& addr,
int num_retries = 1,
int wait = 0,
std::function<void(int, int)> cb = nullptr);
private:
TCPSocket(int sock);
int sock_;
};
} // namespace mlx::core::distributed::detail

View File

@@ -880,6 +880,11 @@ std::vector<array> ScaledDotProductAttention::vjp(
std::vector<array> returned_vjps;
for (int arg : argnums) {
if (arg >= 3) {
throw std::invalid_argument(
"[scaled_dot_product_attention] Does not support VJP with respect "
"to mask or attention sinks.");
}
returned_vjps.push_back(std::move(vjps[arg]));
}
return returned_vjps;

View File

@@ -1,7 +1,7 @@
[build-system]
requires = [
"setuptools>=80",
"nanobind==2.4.0",
"nanobind==2.10.2",
"cmake>=3.25",
]
build-backend = "setuptools.build_meta"

View File

@@ -1,95 +0,0 @@
# Copyright © 2025 Apple Inc.
import argparse
import ipaddress
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
@dataclass
class Host:
rank: int
ssh_hostname: str
ips: list[str]
rdma: list[Optional[str]]
class OptionalBoolAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if option_string.startswith("--no-"):
setattr(namespace, self.dest, False)
else:
setattr(namespace, self.dest, True)
def positive_number(x):
x = int(x)
if x <= 0:
raise ValueError("Number should be positive")
return x
def log(verbose, *args, **kwargs):
if not verbose:
return
kwargs["file"] = sys.stderr
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
def log_warning(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
def log_error(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
def parse_hostlist(parser, hostlist, repeats):
hosts = []
for i, h in enumerate(hostlist.split(",")):
if h == "":
raise ValueError("Hostname cannot be empty")
try:
ipaddress.ip_address(h)
ips = [h]
except ValueError:
ips = []
for i in range(repeats):
hosts.append(Host(i, h, ips, []))
return hosts
def parse_hostfile(parser, hostfile):
"""Parse the json hostfile that contains both the hostnames to ssh into and
the ips to communicate over when using the ring backend.
Example:
[
{"ssh": "hostname1", "ips": ["123.123.123.1"], "rdma": [null, "rdma_en2", "rdma_en3"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"], "rdma": ["rdma_en2", null, "rdma_en3"]},
...
{"ssh": "hostnameN", "ips": ["123.123.123.N"], "rdma": ["rdma_en2", "rdma_en3", null]},
]
Args:
hostfile (str): The path to the json file containing the host
information
"""
hostfile = Path(hostfile)
if not hostfile.exists():
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
try:
hosts = []
with open(hostfile) as f:
for i, h in enumerate(json.load(f)):
hosts.append(Host(i, h["ssh"], h.get("ips", []), h.get("rdma", [])))
return hosts
except Exception as e:
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")

View File

@@ -1,570 +0,0 @@
# Copyright © 2025 Apple Inc.
import argparse
import json
import shlex
import sys
import threading
from collections import defaultdict
from dataclasses import dataclass
from subprocess import DEVNULL, run
from typing import Optional
import mlx.core as mx
from .common import (
Host,
OptionalBoolAction,
log,
log_error,
log_warning,
parse_hostfile,
parse_hostlist,
)
@dataclass
class SSHInfo:
can_ssh: bool
has_sudo: bool
def __bool__(self):
return self.can_ssh
@dataclass
class ThunderboltPort:
iface: str
uuid: str
connected_to: Optional[str]
@dataclass
class ThunderboltHost:
name: str
ports: list[ThunderboltPort]
def add_ethernet_ips(hosts, verbose=False):
# Get the ips for each host
for h in hosts:
log(verbose, "Getting the ip from", h.ssh_hostname)
h.ips.append(
run(
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
capture_output=True,
text=True,
).stdout.strip()
)
def check_rdma(hosts, verbose=False):
# Check whether the hosts are capable of RDMA over thunderbolt
warn = False
for h in hosts:
log(verbose, "Checking that", h.ssh_hostname, "supports RDMA")
rdma_devs = (
run(["ssh", h.ssh_hostname, "ibv_devices"], capture_output=True, text=True)
.stdout.strip()
.split()
)
rdma_devs = [d for d in rdma_devs if d.startswith("rdma_")]
if not rdma_devs:
log_warning(h.ssh_hostname, "does not seem to have RDMA enabled")
warn = True
if warn:
log_warning()
log_warning(
"Some of the hosts don't have RDMA enabled or they don't support RDMA."
)
log_warning()
log_warning(
"See https://ml-explore.github.io/mlx/build/html/usage/distributed.html"
)
log_warning("for instructions on how to enable RDMA.")
def can_auto_setup(hosts, sshinfo, auto_setup=False):
has_sudo = all(info.has_sudo for info in sshinfo)
if not has_sudo and auto_setup:
log_warning(
"Automatic setup requested but the following hosts do not have passwordless sudo"
)
for h, i in zip(hosts, sshinfo):
if not i.has_sudo:
log_warning(" - ", h.ssh_hostname)
return has_sudo
class IPConfigurator:
def __init__(self, hosts, tb_hosts, uuid_reverse_index):
assigned = set()
ips = defaultdict(list)
ip0 = 0
ip1 = 0
for src_node, h in enumerate(tb_hosts):
for src_port, p in enumerate(h.ports):
if not p.connected_to:
continue
if p.connected_to not in uuid_reverse_index:
continue
if (src_node, src_port) in assigned:
continue
dst_node, dst_port = uuid_reverse_index[p.connected_to]
ip_src = f"192.168.{ip0}.{ip1 + 1}"
ip_dst = f"192.168.{ip0}.{ip1 + 2}"
iface_src = p.iface
iface_dst = tb_hosts[dst_node].ports[dst_port].iface
ips[src_node, dst_node].append((iface_src, ip_src))
ips[dst_node, src_node].append((iface_dst, ip_dst))
assigned.add((src_node, src_port))
assigned.add((dst_node, dst_port))
ip1 += 4
if ip1 > 255:
ip0 += 1
ip1 = 0
if ip0 > 255:
raise ValueError("Ran out of available local IPs")
self.ips = ips
self.hosts = hosts
self.tb_hosts = tb_hosts
def setup(self, verbose=False, auto_setup=False):
netmask = "255.255.255.252"
for i, (h, th) in enumerate(zip(self.hosts, self.tb_hosts)):
command = ""
command += "sudo ifconfig bridge0 down\n"
for j in range(len(self.hosts)):
if i == j or (i, j) not in self.ips:
continue
for (iface, ip), (_, peer) in zip(self.ips[i, j], self.ips[j, i]):
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
command += f"sudo route change {peer} -interface {iface}\n"
if auto_setup:
print(f"Running auto setup for {h.ssh_hostname}")
command = command.strip().replace("\n", " ; ")
command = ["ssh", h.ssh_hostname, command]
log(verbose, shlex.join(command))
run(command)
else:
msg = f"Setup for {h.ssh_hostname}"
print(msg)
print("=" * len(msg))
print(command)
input("Enter to continue")
print()
def parse_hardware_ports(ports_string):
ports = {}
port_name = None
for l in ports_string.decode("utf-8").split("\n"):
if l.startswith("Hardware Port:"):
port_name = l.strip()[15:]
elif l.startswith("Device:"):
ports[port_name] = l.strip()[8:]
port_name = None
return ports
def extract_connectivity(hosts, verbose):
# Extract the current connectivity from the remote hosts
thunderbolt_connections = []
for h in hosts:
log(verbose, "Getting connectivity from", h.ssh_hostname)
thunderbolt_connections.append(
json.loads(
run(
[
"ssh",
h.ssh_hostname,
"system_profiler",
"SPThunderboltDataType",
"-json",
],
capture_output=True,
).stdout
)
)
interface_maps = []
for h in hosts:
log(verbose, "Getting interface names from", h.ssh_hostname)
interface_maps.append(
parse_hardware_ports(
run(
[
"ssh",
h.ssh_hostname,
"networksetup",
"-listallhardwareports",
],
capture_output=True,
).stdout
)
)
# Parse the connectivity into some simple dataclasses
tb_hosts = []
for c, iface_map in zip(thunderbolt_connections, interface_maps):
name = ""
ports = []
for t in c["SPThunderboltDataType"]:
uuid = t.get("domain_uuid_key")
if uuid is None:
continue
name = t["device_name_key"]
tag = t["receptacle_1_tag"]["receptacle_id_key"]
items = t.get("_items", [])
connected_items = [item for item in items if "domain_uuid_key" in item]
connected_to = (
connected_items[0]["domain_uuid_key"] if connected_items else None
)
iface = iface_map[f"Thunderbolt {tag}"]
ports.append(ThunderboltPort(iface, uuid, connected_to))
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
# Create a reverse index to be able to map uuids to (host, port) quickly
uuid_reverse_index = {}
for i, h in enumerate(tb_hosts):
for j, p in enumerate(h.ports):
uuid_reverse_index[p.uuid] = (i, j)
return tb_hosts, uuid_reverse_index
def make_connectivity_matrix(tb_hosts, uuid_reverse_index):
connectivity = []
for i, h in enumerate(tb_hosts):
c = [0] * len(tb_hosts)
for p in h.ports:
if p.connected_to in uuid_reverse_index:
j, _ = uuid_reverse_index[p.connected_to]
c[j] += 1
connectivity.append(c)
return connectivity
def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index):
# Make ids per node
names = []
for i in range(len(tb_hosts)):
n = ""
j = i
while True:
n += chr(97 + j % 26)
j //= 26
if j == 0:
break
names.append(n)
print("graph G {")
print(" node [shape=rectangle];")
for i, h in enumerate(hosts):
print(f' {names[i]} [label="{h.ssh_hostname}"];')
for i, h in enumerate(tb_hosts):
for p in h.ports:
if not p.connected_to:
continue
dst = uuid_reverse_index[p.connected_to]
if dst[0] < i:
continue
print(f" {names[i]} -- {names[dst[0]]}", end="")
print(f' [label="{p.iface}/{tb_hosts[dst[0]].ports[dst[1]].iface}"]')
print("}")
def extract_rings(connectivity):
rings = []
existing_rings = set()
num_nodes = len(connectivity)
def dfs(start_node, node, path, visited):
path.append(node)
visited.add(node)
for j in range(num_nodes):
if connectivity[node][j] <= 0:
continue
if j == start_node:
yield path[:]
if j not in visited:
yield from dfs(start_node, j, path, visited)
path.pop()
visited.remove(node)
for start in range(num_nodes):
for r in dfs(start, start, [], set()):
cnt = min(connectivity[r[i]][r[(i + 1) % len(r)]] for i in range(len(r)))
rkey = tuple(sorted(r))
if rkey not in existing_rings:
rings.append((r, cnt))
existing_rings.add(rkey)
return sorted(rings, key=lambda x: -len(x[0]))
def check_valid_mesh(hosts, connectivity, strict=True):
num_nodes = len(connectivity)
for i in range(num_nodes):
for j in range(num_nodes):
if i == j:
continue
if connectivity[i][j] <= 0:
if strict:
log_error(
f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
)
log_error()
log_error("Try passing --dot to visualize the connectivity")
sys.exit(1)
else:
return False
return True
def check_ssh_connections(hosts):
results = [None] * len(hosts)
def _check(hostname, i):
info = SSHInfo(False, False)
results[i] = info
# Check for ssh
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"echo",
"success",
],
stdout=DEVNULL,
stderr=DEVNULL,
)
info.can_ssh = result.returncode == 0
if not info.can_ssh:
return
# Check for sudo
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"sudo",
"ls",
],
stdout=DEVNULL,
stderr=DEVNULL,
)
info.has_sudo = result.returncode == 0
threads = [
threading.Thread(target=_check, args=(h.ssh_hostname, i))
for i, h in enumerate(hosts)
]
for t in threads:
t.start()
for t in threads:
t.join()
if not all(results):
log_error("Could not ssh to the following hosts:")
for i, h in enumerate(hosts):
if not results[i]:
log_error(" - ", h.ssh_hostname)
log_error()
log_error("Maybe they are not set-up for password-less ssh?")
sys.exit(1)
return results
def prepare_ethernet_hostfile(args, hosts):
log(args.verbose, f"Preparing an ethernet hostfile")
add_ethernet_ips(hosts, args.verbose)
hostfile = []
for h in hosts:
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def configure_ring(args, hosts, ips, ring, sshinfo):
log(args.verbose, "Prepare a ring hostfile")
ring, count = ring
hostfile = []
for i, node in enumerate(ring):
h = hosts[node]
peer = ring[i - 1]
hostfile.append(
{
"ssh": h.ssh_hostname,
"ips": [ips.ips[node, peer][c][1] for c in range(count)],
"rdma": [],
}
)
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def configure_jaccl(args, hosts, ips, sshinfo):
log(args.verbose, "Prepare a jaccl hostfile")
check_rdma(hosts, args.verbose)
add_ethernet_ips(hosts, args.verbose)
hostfile = []
for i, h in enumerate(hosts):
rdma = []
for j in range(len(hosts)):
if i == j:
rdma.append(None)
else:
rdma.append(f"rdma_{ips.ips[i, j][0][0]}")
hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def prepare_tb_hostfile(args, hosts, sshinfo):
log(args.verbose, f"Preparing for communication over thunderbolt")
tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)
if args.dot:
tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index)
return
ips = IPConfigurator(hosts, tb_hosts, uuid_reverse_index)
connectivity = make_connectivity_matrix(tb_hosts, uuid_reverse_index)
if args.backend is None:
rings = extract_rings(connectivity)
has_mesh = check_valid_mesh(hosts, connectivity, False)
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
if not has_ring and not has_mesh:
log_error("Neither thunderbolt mesh nor ring found.")
log_error("Perhaps run with --dot to generate a plot of the connectivity.")
sys.exit(1)
elif has_ring:
configure_ring(args, hosts, ips, rings[0], sshinfo)
else:
configure_jaccl(args, hosts, ips, sshinfo)
elif args.backend == "ring":
rings = extract_rings(connectivity)
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
if not has_ring:
log_error("Could not find a full ring.")
log_error()
log_error("Try passing --dot to visualize the connectivity")
if len(rings) > 0:
log_error("Rings found:")
for r in rings:
log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}")
sys.exit(1)
configure_ring(args, hosts, ips, rings[0], sshinfo)
elif args.backend == "jaccl":
check_valid_mesh(hosts, connectivity)
configure_jaccl(args, hosts, ips, sshinfo)
def main():
parser = argparse.ArgumentParser(
description="Configure remote machines for use with MLX distributed"
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--over",
choices=["thunderbolt", "ethernet"],
default="thunderbolt",
help="What type of connectivity to configure",
required=True,
)
parser.add_argument(
"--output-hostfile", help="If provided, save the hostfile to this path"
)
parser.add_argument(
"--auto-setup",
"--no-auto-setup",
action=OptionalBoolAction,
nargs=0,
dest="auto_setup",
default=None,
)
parser.add_argument(
"--dot", action="store_true", help="Output the topology in DOT format and exit"
)
parser.add_argument(
"--backend",
choices=["ring", "jaccl"],
default=None,
help="Which distributed backend to configure",
)
args = parser.parse_args()
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, 1)
# Check that we can ssh
log(
args.verbose,
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
)
sshinfo = check_ssh_connections(hosts)
# Prepare a hostfile for communication over ethernet using the ips of the
# provided hostnames
if args.over == "ethernet":
prepare_ethernet_hostfile(args, hosts)
# Configure the macs for communication over thunderbolt, both via RDMA and IP
else:
prepare_tb_hostfile(args, hosts, sshinfo)

View File

@@ -1,557 +0,0 @@
# Copyright © 2025 Apple Inc.
import argparse
import base64
import json
import os
import platform
import shlex
import shutil
import sys
import tempfile
import threading
from collections import Counter
from itertools import chain
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
import mlx.core as mx
from .common import log, log_warning, parse_hostfile, parse_hostlist, positive_number
class CommandProcess:
@property
def process(self):
"""Return the Popen object that refers to the current command."""
raise NotImplementedError()
@property
def exit_status(self):
"""Return a tuple (returncode, killed) for the command. It should be
(None, None) while the command is running normally."""
raise NotImplementedError()
def preprocess_output(self, data: str, is_stdout=False):
"""Preprocess the output of the command so that extra data can be
captured or the format changed on the fly."""
raise NotImplementedError()
def terminate(self):
"""Terminate or return the exit code."""
raise NotImplementedError()
class RemoteProcess(CommandProcess):
def __init__(self, rank, host, python, cwd, files, env, command):
is_local = host == "127.0.0.1"
cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command)
if not is_local:
cmd = f"ssh -tt -o LogLevel=QUIET {host} {shlex.quote(cmd)}"
self._host = host
self._pidfile = None
self._is_local = is_local
self._process = Popen(
cmd,
shell=True,
executable="/bin/bash",
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
self._killed = False
@property
def process(self):
return self._process
@property
def exit_status(self):
return self._process.poll(), self._killed
def preprocess_output(self, data, is_stdout=False):
if self._pidfile is None:
pidfile, *rest = data.split("\n", maxsplit=1)
self._pidfile = pidfile
return rest[0] if rest else ""
return data
def terminate(self):
if self._killed:
return
self._process.terminate()
self._process.wait()
# Kill the remote program if possible
cmd = RemoteProcess.make_kill_script(self._pidfile)
if not self._is_local:
cmd = f"ssh {self._host} {shlex.quote(cmd)}"
c = run(
cmd,
check=True,
shell=True,
executable="/bin/bash",
capture_output=True,
text=True,
)
self._killed = c.stdout.strip() == "1"
@staticmethod
def make_launch_script(rank, cwd, files, env, command):
script = ""
# Disable echo
script = "stty -echo; "
# Write the PID to a file so we can kill the process if needed
script += "pidfile=$(mktemp); "
script += "echo $$ > $pidfile; "
script += 'printf "%s\\n" $pidfile; '
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if [[ -d {repr(d)} ]]; then "
script += f" cd {repr(d)}; "
if cwd is not None:
script += "else "
script += f" echo 'Failed to change directory to' {repr(d)} >2; "
script += "fi; "
# Add the environment variables that were requested
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
if not all(c.isalnum() or c == "_" for c in key):
log_warning(
f"'{e}' is an invalid environment variable so it is ignored"
)
continue
script += f"export {key}={value}; "
# Make the temporary files
for env_name, content in files.items():
script += "fname=$(mktemp); "
script += f"echo {shlex.quote(content)} >$fname; "
script += f"export {env_name}=$fname; "
# Finally add the rank
script += f"export MLX_RANK={rank}; "
# Replace the process with the script
script += f"cmd=({' '.join(map(shlex.quote, command))}); "
script += 'exec "${cmd[@]}"'
return script
@staticmethod
def make_kill_script(pidfile):
script = ""
script += f"pid=$(cat {pidfile}); "
script += "if ps -p $pid >/dev/null; then "
script += " kill $pid; "
script += " echo 1; "
script += "else "
script += " echo 0; "
script += "fi; "
script += f"rm {pidfile}"
return script
def _launch_with_io(command_class, arguments, verbose):
stop = False
exit_codes = [(None, None)] * len(arguments)
def _thread_fn(rank, *args, **kwargs):
stdin_queue = kwargs.pop("stdin_queue")
stdout_queue = kwargs.pop("stdout_queue")
stderr_queue = kwargs.pop("stderr_queue")
command = command_class(rank, *args, **kwargs)
p = command.process
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
os.set_blocking(p.stdin.fileno(), False)
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno()]
stdin_buffer = b""
while p.poll() is None:
try:
stdin_buffer += stdin_queue.get_nowait()
except QueueEmpty:
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
is_stdout = fd == p.stdout.fileno()
msg = os.read(fd, 8192).decode(errors="ignore")
msg = command.preprocess_output(msg, is_stdout)
if is_stdout:
stdout_queue.put(msg.encode())
else:
stderr_queue.put(msg.encode())
for fd in wlist:
if len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
if stop:
command.terminate()
break
exit_codes[rank] = command.exit_status
if exit_codes[rank][1]:
log_warning(f"Node with rank {rank} was killed")
elif exit_codes[rank][0] != 0:
log_warning(f"Node with rank {rank} exited with code {exit_codes[rank][0]}")
else:
log(verbose, f"Node with rank {rank} completed")
stdin_queues = []
stdout_queues = []
stderr_queues = []
threads = []
for i, (args, kwargs) in enumerate(arguments):
stdin_queues.append(Queue())
stdout_queues.append(Queue())
stderr_queues.append(Queue())
t = threading.Thread(
target=_thread_fn,
args=args,
kwargs=kwargs
| {
"stdin_queue": stdin_queues[-1],
"stdout_queue": stdout_queues[-1],
"stderr_queue": stderr_queues[-1],
},
)
t.start()
threads.append(t)
os.set_blocking(sys.stdin.fileno(), False)
os.set_blocking(sys.stdout.fileno(), True)
os.set_blocking(sys.stderr.fileno(), True)
while not stop or any(not q.empty() for q in chain(stdout_queues, stderr_queues)):
# Broadcast user input to the jobs
rlist, _, _ = select([sys.stdin.fileno()], [], [], 0.1)
for fd in rlist:
stdin_buffer = os.read(fd, 8192)
for q in stdin_queues:
q.put(stdin_buffer)
# Gather job output
for q in stdout_queues:
try:
while not q.empty():
sys.stdout.buffer.write(q.get_nowait())
except QueueEmpty:
pass
for q in stderr_queues:
try:
while not q.empty():
sys.stderr.buffer.write(q.get_nowait())
except QueueEmpty:
pass
sys.stdout.buffer.flush()
sys.stderr.buffer.flush()
# If any job already exited with a non-zero status, stop the remaining ones
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
if exit_codes[i][0] != 0:
stop = True
break
else:
break
# Wait for the jobs to finish
for t in threads:
t.join()
# Process any remaining outputs
for q in stdout_queues:
while not q.empty():
sys.stdout.buffer.write(q.get())
for q in stderr_queues:
while not q.empty():
sys.stderr.buffer.write(q.get())
sys.stdout.buffer.flush()
sys.stderr.buffer.flush()
def launch_ring(parser, hosts, args, command):
if any(len(h.ips) == 0 for h in hosts):
parser.error(
"The ring backend requires IPs to be provided instead of hostnames"
)
port = args.starting_port
ring_hosts = []
for h in hosts:
node = []
for ip in h.ips:
for i in range(args.connections_per_ip):
node.append(f"{ip}:{port}")
port += 1
ring_hosts.append(node)
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
files = {"MLX_HOSTFILE": hostfile}
env = args.env
if args.verbose:
env.append("MLX_RING_VERBOSE=1")
cwd = args.cwd
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
)
def launch_nccl(parser, hosts, args, command):
if not hosts[0].ips:
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
master_host = hosts[0].ips[0]
master_port = args.nccl_port
world_size = len(hosts)
env = args.env
cwd = args.cwd
if args.verbose:
env.append("NCCL_DEBUG=INFO")
env.append(f"NCCL_HOST_IP={master_host}")
env.append(f"NCCL_PORT={master_port}")
env.append(f"MLX_WORLD_SIZE={world_size}")
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
(
(
rank,
h.ssh_hostname,
args.python,
cwd,
{},
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
command,
),
{},
)
for rank, h in enumerate(hosts)
],
args.verbose,
)
def launch_jaccl(parser, hosts, args, command):
if not hosts[0].ips:
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
have_rdmas = all(len(h.rdma) == len(hosts) for h in hosts)
have_nulls = all(h.rdma[i] is None for i, h in enumerate(hosts))
if not have_rdmas or not have_nulls:
raise ValueError("Malformed hostfile for jaccl backend")
coordinator = hosts[0].ips[0]
env = args.env
cwd = args.cwd
env.append(f"MLX_JACCL_COORDINATOR={coordinator}:{args.starting_port}")
files = {"MLX_IBV_DEVICES": json.dumps([h.rdma for h in hosts])}
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
)
def get_mpi_libname():
try:
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
ompi_info = ompi_info.stdout.strip().decode()
if platform.system() == "Darwin":
otool_output = run(
["otool", "-L", ompi_info], check=True, capture_output=True
)
else:
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
otool_output = otool_output.stdout.decode()
# StopIteration if not found
libmpi_line = next(
filter(lambda line: "libmpi" in line, otool_output.splitlines())
)
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
except Exception:
return None
def launch_mpi(parser, hosts, args, command):
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
mpirun = mpirun.stdout.strip().decode()
# Compatibility with homebrew and pip installs
mpi_libname = get_mpi_libname()
if mpi_libname is not None:
dyld = Path(mpirun).parent.parent / "lib"
args.env = [
f"DYLD_LIBRARY_PATH={str(dyld)}",
f"MLX_MPI_LIBNAME={mpi_libname}",
] + args.env
log(args.verbose, f"Using '{mpirun}'")
with tempfile.NamedTemporaryFile(mode="w") as f:
hosts = Counter((h.ssh_hostname for h in hosts))
for h, n in hosts.items():
print(f"{h} slots={n}", file=f)
f.flush()
cmd = [
mpirun,
"--output",
":raw", # do not line buffer output
"--hostfile",
f.name,
*(["-cwd", args.cwd] if args.cwd else []),
*sum((["-x", e] for e in args.env), []),
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
"--",
*command,
]
log(args.verbose, "Running", " ".join(cmd))
try:
run(cmd)
except KeyboardInterrupt:
pass
def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument(
"--print-python",
action="store_true",
help="Print the path to the current python executable and exit",
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument(
"--repeat-hosts",
"-n",
type=positive_number,
default=1,
help="Repeat each host a given number of times",
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl", "jaccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
parser.add_argument(
"--env",
action="append",
default=[],
help="Set environment variables for the jobs",
)
parser.add_argument(
"--mpi-arg",
action="append",
default=[],
help="Arguments to pass directly to mpirun",
)
parser.add_argument(
"--connections-per-ip",
default=1,
type=int,
help="How many connections per ip to use for the ring backend",
)
parser.add_argument(
"--starting-port",
"-p",
type=int,
default=32323,
help="For the ring backend listen on this port increasing by 1 per rank and IP",
)
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
parser.add_argument(
"--no-verify-script",
action="store_false",
dest="verify_script",
help="Do not verify that the script exists",
)
parser.add_argument(
"--python", default=sys.executable, help="Use this python on the remote hosts"
)
args, rest = parser.parse_known_args()
if args.print_python:
print(args.python)
return
if len(rest) == 0:
parser.error("No script is provided")
if rest[0] == "--":
rest.pop(0)
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
# Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists() and script.is_file():
rest[0:1] = [args.python, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
elif args.verify_script:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
if args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if args.backend == "jaccl":
launch_jaccl(parser, hosts, args, rest)

View File

@@ -0,0 +1,909 @@
# Copyright © 2025 Apple Inc.
import argparse
import base64
import ipaddress
import json
import os
import platform
import shlex
import shutil
import sys
import tempfile
import threading
import time
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
from typing import Optional
import mlx.core as mx
@dataclass
class Host:
rank: int
ssh_hostname: str
ips: list[str]
@dataclass
class ThunderboltPort:
iface: str
uuid: str
connected_to: Optional[str]
@dataclass
class ThunderboltHost:
name: str
ports: list[ThunderboltPort]
def parse_hardware_ports(ports_string):
ports = {}
port_name = None
for l in ports_string.decode("utf-8").split("\n"):
if l.startswith("Hardware Port:"):
port_name = l.strip()[15:]
elif l.startswith("Device:"):
ports[port_name] = l.strip()[8:]
port_name = None
return ports
def get_num_nvidia_gpus():
result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
return len(result.stdout.strip().split("\n"))
def extract_rings(hosts, index):
def usable_port(i, j, used_ports):
return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
def dfs(start_node, node, path, visited, used_ports):
path.append(node)
visited.add(node)
for j, p in enumerate(hosts[node].ports):
if not usable_port(node, j, used_ports):
continue
next_node, _ = index[p.connected_to]
if next_node == start_node:
yield path[:]
if next_node not in visited:
yield from dfs(start_node, next_node, path, visited, used_ports)
path.pop()
visited.remove(node)
# Concretize maps the found cycle to real thunderbolt ports. It also adds
# those ports to the used set so next cycles can't use them again.
def concretize(cycle, used_ports):
concrete_path = []
for n1, n2 in zip(cycle, cycle[1:] + cycle[:1]):
for j, p in enumerate(hosts[n1].ports):
if not usable_port(n1, j, used_ports):
continue
n2_hat, nj = index[p.connected_to]
if n2 == n2_hat:
concrete_path.append(((n1, j), (n2, nj)))
used_ports.add((n1, j))
used_ports.add((n2, nj))
break
if concrete_path[-1][0][0] != n1:
raise RuntimeError("Couldn't concretize the cycle")
return concrete_path
# Normalize tries to ensure that the cycles all have the same direction so
# we can use them together. We do this by picking the direction in which
# lower-rank hosts connect to higher-rank hosts.
def normalize(path):
small_to_large = sum(1 for p in path if p[0][0] < p[1][0])
if small_to_large > len(path) - small_to_large:
return path
else:
return [(p[1], p[0]) for p in path]
rings = []
used_ports = set()
for start_node in range(len(hosts)):
while True:
ring = []
for r in dfs(start_node, start_node, [], set(), used_ports):
if len(r) > len(ring):
ring = r
# Break early since we won't find a bigger ring no matter what
if len(ring) == len(hosts):
break
if not ring:
break
try:
rings.append(normalize(concretize(ring, used_ports)))
except RuntimeError:
if len(rings) > 0:
return rings
raise
return rings
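# A minimal sketch of extract_rings on a hypothetical two-host setup joined
# by two Thunderbolt cables (port UUIDs u0..u3); `index` maps a port UUID to
# its (host, port) pair, exactly as built in prepare_tb_ring below:
#
#   tb_hosts = [
#       ThunderboltHost("m0", [ThunderboltPort("en1", "u0", "u2"),
#                              ThunderboltPort("en2", "u1", "u3")]),
#       ThunderboltHost("m1", [ThunderboltPort("en1", "u2", "u0"),
#                              ThunderboltPort("en2", "u3", "u1")]),
#   ]
#   index = {"u0": (0, 0), "u1": (0, 1), "u2": (1, 0), "u3": (1, 1)}
#   extract_rings(tb_hosts, index)
#   # -> [[((1, 0), (0, 0)), ((0, 1), (1, 1))]]  (one ring using both cables)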
def positive_number(x):
x = int(x)
if x <= 0:
raise ValueError("Number should be positive")
return x
def log(verbose, *args, **kwargs):
if not verbose:
return
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
def log_warning(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
def log_error(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
def parse_hostfile(parser, hostfile):
"""Parse the json hostfile that contains both the hostnames to ssh into and
the ips to communicate over when using the ring backend.
Example:
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
...
{"ssh": "hostnameN", "ips": ["123.123.123.N"]},
]
Args:
hostfile (str): The path to the json file containing the host
information
"""
hostfile = Path(hostfile)
if not hostfile.exists():
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
try:
hosts = []
with open(hostfile) as f:
for i, h in enumerate(json.load(f)):
hosts.append(Host(i, h["ssh"], h.get("ips", [])))
return hosts
except Exception as e:
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
def parse_hostlist(parser, hostlist, repeats):
hosts = []
for h in hostlist.split(","):
if h == "":
raise ValueError("Hostname cannot be empty")
try:
ipaddress.ip_address(h)
ips = [h]
except ValueError:
ips = []
for _ in range(repeats):
hosts.append(Host(len(hosts), h, ips))
return hosts
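# For example (hypothetical addresses), `--hosts 10.0.0.1,10.0.0.2 -n 2`
# expands to four ranks:
#
#   >>> parse_hostlist(parser, "10.0.0.1,10.0.0.2", 2)
#   [Host(rank=0, ssh_hostname='10.0.0.1', ips=['10.0.0.1']),
#    Host(rank=1, ssh_hostname='10.0.0.1', ips=['10.0.0.1']),
#    Host(rank=2, ssh_hostname='10.0.0.2', ips=['10.0.0.2']),
#    Host(rank=3, ssh_hostname='10.0.0.2', ips=['10.0.0.2'])]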
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
# Imports that are used throughout
script = ""
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Write the PID to a file so we can kill the process if needed
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
if cwd is not None:
script += "else:\n"
script += (
f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
)
script += f" sys.exit(1)\n"
# Add the environment variables that were given to us
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
if not all(c.isalnum() or c == "_" for c in key):
log_warning(f"'{e}' is an invalid environment variable so it is ignored")
continue
script += f"env[{repr(key)}] = {repr(value)}\n"
# Add the environment variables to enable the ring distributed backend
if hostfile != "":
script += "_, hostfile = tempfile.mkstemp()\n"
script += "with open(hostfile, 'w') as f:\n"
script += f" f.write({repr(hostfile)})\n"
if verbose:
script += "env['MLX_RING_VERBOSE'] = '1'\n"
script += "env['MLX_HOSTFILE'] = hostfile\n"
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
# Replace the process with the script
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
return script
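# In short, the generated monitor script, once shipped to a remote host:
#   1. writes its own PID to a tempfile and prints that path (node_thread
#      below reads it back so stragglers can be killed over ssh),
#   2. sets the working directory and environment variables, including
#      MLX_HOSTFILE/MLX_RANK when the ring backend is in use,
#   3. replaces itself with the actual job via os.execve.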
def launch_ring(parser, hosts, args, command):
stop = False
exit_codes = [None] * len(hosts)
def node_thread(rank, host, hostfile, input_queue):
is_local = host == "127.0.0.1"
script = make_monitor_script(
rank, hostfile, args.cwd, args.env, command, args.verbose
)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
if not is_local:
cmd = f"ssh {host} '{cmd}'"
p = Popen(
cmd,
shell=True,
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
os.set_blocking(p.stdin.fileno(), False)
# Repeat the stdout and stderr to the local machine
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()]
pidfile = ""
stdin_buffer = b""
stdout_buffer = b""
stderr_buffer = b""
while p.poll() is None:
try:
stdin_buffer += input_queue.get_nowait()
except QueueEmpty:
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
msg = os.read(fd, 8192).decode(errors="ignore")
# Fetch the PID file first if we haven't already
if pidfile == "":
pidfile, *msg = msg.split("\n", maxsplit=1)
msg = msg[0] if msg else ""
is_stdout = fd == p.stdout.fileno()
if is_stdout:
stdout_buffer += msg.encode()
else:
stderr_buffer += msg.encode()
for fd in wlist:
if fd == p.stdin.fileno() and len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0:
n = os.write(fd, stdout_buffer)
stdout_buffer = stdout_buffer[n:]
elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0:
n = os.write(fd, stderr_buffer)
stderr_buffer = stderr_buffer[n:]
if stop:
p.terminate()
break
p.wait()
exit_codes[rank] = p.returncode
# Kill the remote program if possible
cmd = ""
cmd += f"pid=$(cat {pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {pidfile}"
if not is_local:
cmd = f"ssh {host} '{cmd}'"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
if c.stdout.strip() == "1":
log_warning(f"Node with rank {rank} was killed")
elif p.returncode != 0:
log_warning(f"Node with rank {rank} exited with code {p.returncode}")
else:
log(args.verbose, f"Node with rank {rank} completed")
if any(len(h.ips) == 0 for h in hosts):
parser.error(
"The ring backend requires IPs to be provided instead of hostnames"
)
port = args.starting_port
ring_hosts = []
for h in hosts:
node = []
for ip in h.ips:
for i in range(args.connections_per_ip):
node.append(f"{ip}:{port}")
port += 1
ring_hosts.append(node)
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
log(args.verbose, "Running", shlex.join(command))
input_queues = []
threads = []
for i, h in enumerate(hosts):
if i + 1 == len(hosts):
time.sleep(1.0)
input_queues.append(Queue())
t = threading.Thread(
target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
)
t.start()
threads.append(t)
os.set_blocking(sys.stdin.fileno(), False)
while not stop:
rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
for fd in rlist:
stdin_buffer = os.read(fd, 8192)
for q in input_queues:
q.put(stdin_buffer)
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
if exit_codes[i] != 0:
stop = True
break
else:
break
for t in threads:
t.join()
def get_mpi_libname():
try:
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
ompi_info = ompi_info.stdout.strip().decode()
if platform.system() == "Darwin":
otool_output = run(
["otool", "-L", ompi_info], check=True, capture_output=True
)
else:
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
otool_output = otool_output.stdout.decode()
# StopIteration if not found
libmpi_line = next(
filter(lambda line: "libmpi" in line, otool_output.splitlines())
)
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
except Exception:
return None
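# On macOS with a Homebrew Open MPI this typically returns something like
# "libmpi.40.dylib" (illustrative); None means no usable MPI was found.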
def launch_mpi(parser, hosts, args, command):
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
mpirun = mpirun.stdout.strip().decode()
# Compatibility with homebrew and pip installs
mpi_libname = get_mpi_libname()
if mpi_libname is not None:
dyld = Path(mpirun).parent.parent / "lib"
args.env = [
f"DYLD_LIBRARY_PATH={str(dyld)}",
f"MLX_MPI_LIBNAME={mpi_libname}",
] + args.env
log(args.verbose, f"Using '{mpirun}'")
with tempfile.NamedTemporaryFile(mode="w") as f:
hosts = Counter((h.ssh_hostname for h in hosts))
for h, n in hosts.items():
print(f"{h} slots={n}", file=f)
f.flush()
cmd = [
mpirun,
"--output",
":raw", # do not line buffer output
"--hostfile",
f.name,
*(["-cwd", args.cwd] if args.cwd else []),
*sum((["-x", e] for e in args.env), []),
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
"--",
*command,
]
log(args.verbose, "Running", " ".join(cmd))
try:
run(cmd)
except KeyboardInterrupt:
pass
def launch_nccl(parser, hosts, args, command):
master_host = hosts[0].ips[0]
if master_host != "127.0.0.1":
raise ValueError("The NCCL backend only supports localhost for now.")
master_port = args.nccl_port
world_size = len(hosts)
base_env = os.environ.copy()
base_env.update(
{
"NCCL_DEBUG": base_env.get(
"NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
),
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
"NCCL_HOST_IP": master_host,
"NCCL_PORT": str(master_port),
"MLX_WORLD_SIZE": str(world_size),
}
)
procs = []
num_gpus = get_num_nvidia_gpus()
if num_gpus == 0:
raise RuntimeError("Cannot run NCCL backend with no GPUs.")
if world_size > num_gpus:
raise RuntimeError("NCCL requires a separate GPU per process.")
try:
for rank in range(world_size):
env = base_env.copy()
env["MLX_RANK"] = str(rank)
env["CUDA_VISIBLE_DEVICES"] = str(rank)
p = Popen(command, env=env)
procs.append(p)
for p in procs:
ret = p.wait()
if ret != 0:
raise RuntimeError(f"Rank process exited with {ret}")
except (RuntimeError, KeyboardInterrupt):
for p in procs:
if p.poll() is None:
try:
p.kill()
except Exception:
pass
raise
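# E.g. (hypothetical) `mlx.launch --backend nccl -n 4 -- python train.py`
# spawns four local ranks, each pinned to its own GPU via
# CUDA_VISIBLE_DEVICES.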
def check_ssh_connections(hosts):
results = [False] * len(hosts)
def _check(hostname, i):
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"echo",
"success",
],
stdout=PIPE,
stderr=PIPE,
)
results[i] = result.returncode == 0
threads = [
threading.Thread(target=_check, args=(h.ssh_hostname, i))
for i, h in enumerate(hosts)
]
for t in threads:
t.start()
for t in threads:
t.join()
if not all(results):
log_error("Could not ssh to the following hosts:")
for i, h in enumerate(hosts):
if not results[i]:
log_error(" - ", h.ssh_hostname)
log_error()
log_error("Maybe they are not set-up for password-less ssh?")
sys.exit(1)
def prepare_tb_ring(args, hosts):
log(
args.verbose,
f"Preparing a thunderbolt ring for {', '.join(h.ssh_hostname for h in hosts)}",
)
# Check that we can ssh
check_ssh_connections(hosts)
if args.auto_setup and args.verbose:
log_warning(
"--auto-setup is requested which requires password-less sudo",
"on the remote hosts",
)
# Extract the current connectivity from the remote hosts
thunderbolt_connections = []
for h in hosts:
log(args.verbose, "Getting connectivity from", h.ssh_hostname)
thunderbolt_connections.append(
json.loads(
run(
[
"ssh",
h.ssh_hostname,
"system_profiler",
"SPThunderboltDataType",
"-json",
],
capture_output=True,
).stdout
)
)
interface_maps = []
for h in hosts:
log(args.verbose, "Getting interface names from", h.ssh_hostname)
interface_maps.append(
parse_hardware_ports(
run(
[
"ssh",
h.ssh_hostname,
"networksetup",
"-listallhardwareports",
],
capture_output=True,
).stdout
)
)
# Parse the connectivity into some simple dataclasses
tb_hosts = []
for c, iface_map in zip(thunderbolt_connections, interface_maps):
name = ""
ports = []
for t in c["SPThunderboltDataType"]:
uuid = t.get("domain_uuid_key")
if uuid is None:
continue
name = t["device_name_key"]
tag = t["receptacle_1_tag"]["receptacle_id_key"]
items = t.get("_items", [])
connected_items = [item for item in items if "domain_uuid_key" in item]
connected_to = (
connected_items[0]["domain_uuid_key"] if connected_items else None
)
iface = iface_map[f"Thunderbolt {tag}"]
ports.append(ThunderboltPort(iface, uuid, connected_to))
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
# Create a reverse index to be able to map uuids to (host, port) quickly
uuid_reverse_index = {}
for i, h in enumerate(tb_hosts):
for j, p in enumerate(h.ports):
uuid_reverse_index[p.uuid] = (i, j)
# Find the rings by simply walking and marking visited (host, port) tuples
# and keeping the largest rings greedily.
log(args.verbose, "Extracting rings from the parsed connectivity")
rings = extract_rings(tb_hosts, uuid_reverse_index)
# Just output a DOT graphical representation of the found rings
if args.dot:
names = []
for i in range(len(tb_hosts)):
n = ""
j = i
while True:
n += chr(97 + j % 26)
j //= 26
if j == 0:
break
names.append(n)
print("graph G {")
print(" node [shape=rectangle];")
for i, h in enumerate(hosts):
print(f' {names[i]} [label="{h.ssh_hostname}"];')
for r in rings:
for (i, _), (j, _) in r:
print(f" {names[i]} -- {names[j]};")
print("}")
return
# Assign IPs to each interface such that directly cabled interfaces can
# communicate: every cable gets its own /30 subnet (netmask
# 255.255.255.252), i.e. 192.168.0.1/.2 for the first link, .5/.6 next, etc.
ips = {}
pairs = {}
expecting = set()
ip0 = 0
ip1 = 0
netmask = "255.255.255.252"
for r in rings:
for a, b in r:
ips[a] = f"192.168.{ip0}.{ip1 + 1}"
ips[b] = f"192.168.{ip0}.{ip1 + 2}"
pairs[a] = b
pairs[b] = a
expecting.add(b)
ip1 += 4
if ip1 > 255:
ip0 += 1
ip1 = 0
if ip0 > 255:
raise ValueError("Ran out of available local IPs for the ring")
# Extract the host order from the first ring
hostmap = dict((r[0][0], r[1][0]) for r in rings[0])
first_host = min(hostmap.keys())
order = [first_host]
while hostmap[order[-1]] != first_host:
order.append(hostmap[order[-1]])
# Create the hostfile
hostfile = []
for i in order:
h = hosts[i]
host = {
"ssh": h.ssh_hostname,
"ips": [
ips[i, j]
for j, p in enumerate(tb_hosts[i].ports)
if (i, j) in expecting
],
}
hostfile.append(host)
if not args.hostfile_only:
for i, h in enumerate(hosts):
command = ""
command += "sudo ifconfig bridge0 down\n"
for j, p in enumerate(tb_hosts[i].ports):
if (i, j) not in ips:
continue
iface = p.iface
ip = ips[i, j]
peer = ips[pairs[i, j]]
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
command += f"sudo route change {peer} -interface {iface}\n"
if args.auto_setup:
print(f"Running auto setup for {h.ssh_hostname}")
command = command.strip().replace("\n", " && ")
command = ["ssh", h.ssh_hostname, command]
log(args.verbose, shlex.join(command))
run(command)
else:
msg = f"Setup for {h.ssh_hostname}"
print(msg)
print("=" * len(msg))
print(command)
input("Enter to continue")
print()
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def prepare_hostfile(args, hosts):
log(
args.verbose,
f"Preparing an ethernet hostfile for {', '.join(h.ssh_hostname for h in hosts)}",
)
# Check that we can ssh
check_ssh_connections(hosts)
# Get the ips for each host
for h in hosts:
log(args.verbose, "Getting the ip from", h.ssh_hostname)
h.ips.append(
run(
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
capture_output=True,
text=True,
).stdout.strip()
)
hostfile = []
for h in hosts:
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def distributed_config():
parser = argparse.ArgumentParser(
description="Configure remote machines for use with MLX distributed"
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to configure",
)
parser.add_argument(
"--over",
choices=["thunderbolt", "ethernet"],
default="thunderbolt",
help="What type of connectivity to configure",
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--dot", action="store_true", help="Output the topology in DOT format and exit"
)
parser.add_argument(
"--hostfile-only", action="store_true", help="If set only compute the hostfile"
)
parser.add_argument(
"--output-hostfile", help="If provided, save the hostfile to this path"
)
parser.add_argument(
"--auto-setup",
action="store_true",
help="If set we will attempt to automatically configure the machines via ssh",
)
args = parser.parse_args()
if args.backend == "mpi" and args.over == "thunderbolt":
raise ValueError(
(
"The configuration of MPI over thunderbolt is "
"not supported yet by mlx.distributed_config"
)
)
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, 1)
if args.over == "thunderbolt":
prepare_tb_ring(args, hosts)
else:
prepare_hostfile(args, hosts)
def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument(
"--print-python",
action="store_true",
help="Print the path to the current python executable and exit",
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument(
"--repeat-hosts",
"-n",
type=positive_number,
default=1,
help="Repeat each host a given number of times",
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
parser.add_argument(
"--env",
action="append",
default=[],
help="Set environment variables for the jobs",
)
parser.add_argument(
"--mpi-arg",
action="append",
default=[],
help="Arguments to pass directly to mpirun",
)
parser.add_argument(
"--connections-per-ip",
default=1,
type=int,
help="How many connections per ip to use for the ring backend",
)
parser.add_argument(
"--starting-port",
"-p",
type=int,
default=5000,
help="For the ring backend listen on this port increasing by 1 per rank and IP",
)
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
args, rest = parser.parse_known_args()
if args.print_python:
print(sys.executable)
return
if len(rest) == 0:
parser.error("No script is provided")
if rest[0] == "--":
rest.pop(0)
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
# Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists():
rest[0:1] = [sys.executable, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
else:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
if args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if __name__ == "__main__":
main()
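Putting the two entry points together (hypothetical hostnames): `mlx.distributed_config --over thunderbolt --auto-setup --hosts m1,m2,m3 --output-hostfile hosts.json` brings up the Thunderbolt ring and writes the hostfile, after which `mlx.launch --backend ring --hostfile hosts.json -- train.py` runs the job. Both console scripts are declared in setup.py's entry_points.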

View File

@@ -52,25 +52,9 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"is_available",
[](const std::string& backend) {
return mx::distributed::is_available(backend);
},
"backend"_a = "any",
nb::sig("def is_available(backend: str = 'any') -> bool"),
&mx::distributed::is_available,
R"pbdoc(
Check if a communication backend is available.
Note, this function returns whether MLX has the capability of
instantiating that distributed backend not whether it is possible to
create a communication group. For that purpose one should use
``init(strict=True)``.
Args:
backend (str, optional): The name of the backend to check for availability.
It takes the same values as :func:`init()`. Default: ``"any"``.
Returns:
bool: Whether the distributed backend is available.
)pbdoc");
m.def(
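A quick usage sketch of the bindings documented above (`rank()` and `size()` are the standard `Group` accessors; `strict=True` makes `init` raise instead of falling back to a singleton group):

import mlx.core as mx

if mx.distributed.is_available("ring"):
    group = mx.distributed.init(strict=True, backend="ring")
    print(group.rank(), group.size())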
@@ -95,10 +79,10 @@ void init_distributed(nb::module_& parent_module) {
in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False``
backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
set to ``any`` all available backends are tried and the first one
that succeeds becomes the global group which will be returned in
subsequent calls. Default: ``any``
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
available backends are tried and the first one that succeeds
becomes the global group which will be returned in subsequent
calls. Default: ``any``
Returns:
Group: The group representing all the launched processes.

View File

@@ -89,7 +89,8 @@ static PyType_Spec gc_func_spec = {
/* .name = */ "mlx.gc_func",
/* .basicsize = */ (int)sizeof(gc_func),
/* .itemsize = */ 0,
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_HAVE_VECTORCALL,
/* .slots = */ gc_func_slots};
static PyTypeObject* gc_func_tp = nullptr;

View File

@@ -16,8 +16,7 @@ struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
NB_TYPE_CASTER(
List,
const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name +
const_name(", ...]"))
const_name("tuple[") + make_caster<Type>::Name + const_name(", ...]"))
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
size_t size;

View File

@@ -124,37 +124,53 @@ auto py_value_and_grad(
// Collect the arrays
std::vector<mx::array> arrays;
std::vector<nb::object> array_objects;
auto flatten_with_objects = [&arrays, &array_objects](
auto tree, bool strict) {
tree_visit(tree, [&](nb::handle obj) {
if (nb::isinstance<mx::array>(obj)) {
arrays.push_back(nb::cast<mx::array>(obj));
array_objects.push_back(nb::borrow<nb::object>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
};
std::vector<int> counts(1, 0);
std::vector<int> gradient_indices;
for (int i = 0, j = 0; i < args.size(); ++i) {
bool needs_grad = (j < argnums.size() && argnums[j] == i);
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
auto pre_size = arrays.size();
flatten_with_objects(args[i], /* strict = */ needs_grad);
if (needs_grad) {
auto old_size = gradient_indices.size();
gradient_indices.resize(old_size + argsi.size());
auto delta_size = arrays.size() - pre_size;
gradient_indices.resize(old_size + delta_size);
std::iota(
gradient_indices.begin() + old_size,
gradient_indices.end(),
arrays.size());
pre_size);
j++;
counts.push_back(argsi.size());
counts.push_back(delta_size);
}
arrays.insert(arrays.end(), argsi.begin(), argsi.end());
}
for (auto item : kwargs) {
bool needs_grad =
(argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
auto pre_size = arrays.size();
flatten_with_objects(item.second, /* strict = */ needs_grad);
if (needs_grad) {
auto old_size = gradient_indices.size();
gradient_indices.resize(old_size + argsk.size());
auto delta_size = arrays.size() - pre_size;
gradient_indices.resize(old_size + delta_size);
std::iota(
gradient_indices.begin() + old_size,
gradient_indices.end(),
arrays.size());
counts.push_back(argsk.size());
pre_size);
counts.push_back(delta_size);
}
arrays.insert(arrays.end(), argsk.begin(), argsk.end());
}
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
@@ -163,7 +179,7 @@ auto py_value_and_grad(
nb::object py_value_out;
auto value_and_grads = mx::value_and_grad(
[&fun,
&arrays,
&array_objects,
&args,
&kwargs,
&py_value_out,
@@ -183,8 +199,9 @@ auto py_value_and_grad(
tree_visit_update(tree, [&](nb::handle node) {
auto replace_arr = nb::cast<mx::array>(node);
if (replace_arr.id() == a[index].id()) {
return nb::cast(arrays[index++]);
return array_objects[index++];
} else {
index++;
return nb::cast(replace_arr);
}
});

View File

@@ -780,9 +780,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
return arrs[0]
arrs = [mx.array(1.0)]
init_id = id(arrs[0])
arr = arrs[0]
mx.grad(fun)(arrs)
self.assertEqual(init_id, id(arrs[0]))
self.assertEqual(id(arr), id(arrs[0]))
def fun(arrs):
arrs[1] = sum(arrs)
return arrs[1]
arrs = [mx.array(1.0), mx.array(1.0), mx.array(1.0)]
a_0, a_1, a_2 = arrs
mx.grad(fun)(arrs)
self.assertEqual(id(a_0), id(arrs[0]))
self.assertNotEqual(id(a_1), id(arrs[1]))
self.assertEqual(id(a_2), id(arrs[2]))
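# a_1 was rebound inside `fun`, so the transform hands back a new object for
# arrs[1]; untouched inputs keep their original Python objects (this is the
# identity the array_objects bookkeeping above preserves).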
def test_grad_with_inplace_update(self):
def loss_fn(model):

View File

@@ -4,12 +4,12 @@ import gc
import inspect
import io
import math
import unittest
from functools import partial, wraps
from io import StringIO
import mlx.core as mx
import mlx_tests
import numpy as np
class TestCompile(mlx_tests.MLXTestCase):
@@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase):
loss, grads = step(emb, w, x)
mx.eval(loss, grads)
def test_compile_donates_input_buffer(self):
mx.set_default_device(mx.cpu)
def fun(x):
return mx.sin(x) + 1
compiled_fn = mx.compile(fun)
input = mx.arange(16, dtype=mx.float32)
mx.eval(input)
in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
out = compiled_fn(input)
del input # Ensure the reference is dropped
mx.eval(out)
self.assertEqual(
np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()

View File

@@ -744,7 +744,6 @@ class TestVmap(mlx_tests.MLXTestCase):
return Vector([t[0] + 10, t[1] * 10])
x = State(mx.array(1), mx.array(2))
print(f"{transform(x)=}")
vmap_transform = mx.vmap(transform)
vmap_transform_tuple = mx.vmap(transform_tuple)

View File

@@ -255,7 +255,7 @@ if __name__ == "__main__":
extras = {
"dev": [
"nanobind==2.4.0",
"nanobind==2.10.2",
"numpy",
"pre-commit",
"setuptools>=80",
@@ -265,8 +265,8 @@ if __name__ == "__main__":
}
entry_points = {
"console_scripts": [
"mlx.launch = mlx._distributed_utils.launch:main",
"mlx.distributed_config = mlx._distributed_utils.config:main",
"mlx.launch = mlx.distributed_run:main",
"mlx.distributed_config = mlx.distributed_run:distributed_config",
]
}
install_requires = []