Compare commits

..

13 Commits

Author SHA1 Message Date
Angelos Katharopoulos
cd4b12ce1b Refactoring launcher 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
425043ccca Change the name to a fun pun 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
95d92af8a0 Add headers for gcc 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
bfdddd644b Expose per-backend availability in C++ and python 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
1216afdc91 Add a no_ibv 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
04e94d78bb Add empty sum_scatter 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
60d4e8b2a8 Add send/recv 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
c5745fddd2 Make sure that there is space for work completions 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
e937a8033f Add working reduce and semi-working all gather 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
4dfe02d7c6 Fix ring 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
5c2cff9329 Fix side channel initialization for more than 2 peers 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
325dab9559 All gather 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
67e454ab0a Initial working all reduce 2025-12-08 15:50:05 -08:00
46 changed files with 2347 additions and 858 deletions

View File

@@ -11,7 +11,7 @@ runs:
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools nanobind==2.10.2
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs

View File

@@ -10,29 +10,23 @@ inputs:
description: 'Version of python to set up'
required: false
default: '3.10'
use-ccache:
description: 'Whether to enable ccache'
required: false
default: 'true'
runs:
using: "composite"
steps:
- name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
max-size: 1GB
- name: Install common dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
- name: Use ccache
if: ${{ inputs.use-ccache == 'true' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}
max-size: 1GB
# ccache-action bug: running "apt-get update" fails on large arm runner.
update-package-index: false
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
@@ -42,7 +36,7 @@ runs:
run: |
python -m venv .venv
source .venv/bin/activate
pip install setuptools cmake nanobind==2.10.2
pip install setuptools cmake nanobind==2.4.0
echo PATH=$PATH >> $GITHUB_ENV
# Make cmake search .venv for nanobind
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV

View File

@@ -23,14 +23,14 @@ jobs:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: "x86_64"
- name: Upload mlx artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
@@ -89,7 +89,7 @@ jobs:
with:
toolkit: 'cuda-12.9'
- name: Upload artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl

View File

@@ -57,20 +57,19 @@ jobs:
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
use-ccache: false
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: ${{ matrix.arch }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v5
with:
overwrite: true
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cpu-${{ matrix.arch }}
@@ -96,7 +95,7 @@ jobs:
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools nanobind==2.10.2
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash -l {0}
@@ -114,14 +113,14 @@ jobs:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-metal
@@ -142,13 +141,12 @@ jobs:
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
use-ccache: false
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cuda
@@ -164,12 +162,12 @@ jobs:
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v7
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiple: true
path: dist
- uses: actions/download-artifact@v7
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiple: true
@@ -191,7 +189,7 @@ jobs:
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v7
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: dist
@@ -212,7 +210,7 @@ jobs:
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v7
- uses: actions/download-artifact@v6
with:
pattern: mlx-cpu-*
merge-multiple: true
@@ -234,7 +232,7 @@ jobs:
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v7
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: dist

View File

@@ -119,6 +119,10 @@ if(MLX_BUILD_METAL)
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -273,7 +277,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(
Python 3.10
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(

View File

@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:
.. code-block:: shell
>>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.

View File

@@ -3,6 +3,6 @@ requires = [
"setuptools>=42",
"cmake>=3.25",
"mlx>=0.18.0",
"nanobind==2.10.2",
"nanobind==2.4.0",
]
build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.25
mlx>=0.21.0
nanobind==2.10.2
nanobind==2.4.0

View File

@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
// - Donatable
// - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() && !is_constant(i)) {
in.is_donatable() && is_constant(i)) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
// - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
!is_constant(i)) {
is_constant(i)) {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;

View File

@@ -291,17 +291,6 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
num_keys,
kshape = keys.shape(),
kstrides = keys.strides()]() mutable {
auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
if (4 * loc + 4 <= bytes_per_key) {
reinterpret_cast<uint32_t*>(cptr)[loc] = v;
} else {
std::copy(
reinterpret_cast<char*>(&v),
reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,
cptr + 4 * loc);
}
};
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
@@ -321,12 +310,18 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
copy_remaining(cptr, count.second, rb.second);
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
}
if (!even) {
count.second = 0;
copy_remaining(
cptr, half_size, random::threefry2x32_hash(key, count).first);
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
});

View File

@@ -3,9 +3,5 @@
#include "mlx/backend/cpu/simd/base_simd.h"
#ifdef MLX_USE_ACCELERATE
#if defined(__x86_64__)
// the accelerate_simd implementation require neon -- use base implementation
#else
#include "mlx/backend/cpu/simd/accelerate_simd.h"
#endif
#endif

View File

@@ -338,43 +338,28 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
}
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
switch (type) {
case cudaGraphNodeTypeGraph: {
// Try to be updatable for a structure like graph -> graph -> kernel
cudaGraph_t child;
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
is_updatable &= sub_is_updatable;
key += subkey;
break;
}
case cudaGraphNodeTypeHost:
key += "H";
break;
case cudaGraphNodeTypeMemset:
key += "M";
break;
case cudaGraphNodeTypeKernel: {
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only allow dim.x to be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
is_updatable = false;
} else {
key += "K";
key += std::to_string(cluster_dim.clusterDim.x);
}
break;
}
case cudaGraphNodeTypeWaitEvent:
key += "W";
break;
case cudaGraphNodeTypeEventRecord:
key += "R";
break;
default:
if (type == cudaGraphNodeTypeGraph) {
// Try to be updatable for a structure like graph -> graph -> kernel
cudaGraph_t child;
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
is_updatable &= sub_is_updatable;
key += subkey;
} else if (type == cudaGraphNodeTypeMemset) {
key += "M";
} else if (type != cudaGraphNodeTypeKernel) {
is_updatable = false;
} else {
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only allow dim.x to be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
is_updatable = false;
} else {
key += "K";
key += std::to_string(cluster_dim.clusterDim.x);
}
}
}
key += ")";

View File

@@ -2,11 +2,7 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/backend/cuda/vector_types.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
@@ -17,6 +13,17 @@
namespace mlx::core {
namespace cu {
template <int bits>
struct Quantize {
__device__ uint8_t operator()(float x) {
if constexpr (bits == 8) {
return __nv_fp8_e4m3(x).__x;
} else {
return __nv_fp4_e2m1(x).__x;
}
}
};
template <int bits>
struct Dequantize {
__device__ float operator()(uint8_t x) {
@@ -30,40 +37,29 @@ struct Dequantize {
namespace cg = cooperative_groups;
template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
__global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
using Tx2 = Vector2_t<T>;
using Tx4 = Vector4_t<T>;
uint32_t rbits = 0; // reserved bits for future use
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
size_t thread_idx = tidx + grid_dim_x * size_t(tidy);
size_t base_idx = thread_idx * group_size;
if (base_idx >= size) {
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
size_t index = tidx + grid_dim_x * size_t(tidy);
if (index >= size) {
return;
}
auto w_tile = load_vector<group_size, T>(w, thread_idx);
float scale = 0.0f;
float w_thread = w[index];
Tx2 amax_2x = Tx2{0.0f, 0.0f};
#pragma unroll
for (int i = 0; i < group_size; i += 2) {
auto pair = Tx2{w_tile[i], w_tile[i + 1]};
abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
}
scale = static_cast<float>(
max(fabsf(static_cast<float>(amax_2x.x)),
fabsf(static_cast<float>(amax_2x.y))));
cg::greater<float> max_op;
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
float scale = cg::reduce(warp, abs(w_thread), max_op);
scale /= bits == 4 ? 6.0f : 448.0f;
// Convert to mx scale or nv scale
using ScaleType =
@@ -72,24 +68,21 @@ __global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
uint8_t q_scale = s.__x;
scale = float(s);
scales[thread_idx] = q_scale;
constexpr int elem_per_byte = bits == 8 ? 1 : 2;
AlignedVector<uint8_t, group_size / elem_per_byte> quantized;
#pragma unroll
for (int i = 0; i < group_size / 4; i++) {
Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
if constexpr (bits == 8) {
uint32_t quantized_val =
scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
*reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
} else {
uint16_t quantized_val =
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
*reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
}
// Write out the scales
size_t gindex = index / group_size;
if (index % group_size == 0) {
scales[gindex] = q_scale;
}
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
if (bits == 4) {
uint8_t sval = warp.shfl_down(output, 1);
output |= sval << bits;
}
constexpr int pack_factor = bits == 8 ? 1 : 2;
if (index % pack_factor == 0) {
out[index / pack_factor] = output;
}
store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);
}
template <typename T, int group_size, int bits, bool use_mx_scale>
@@ -149,16 +142,15 @@ void fp_quantize(
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_quantize<T, 32, 4, true, false>;
auto kernel = cu::fp_quantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_quantize<T, 32, 8, true, false>;
kernel = cu::fp_quantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_quantize<T, 16, 4, false, false>;
kernel = cu::fp_quantize<T, 16, 4, false>;
}
bool large = w.size() > UINT_MAX;
auto [num_blocks, block_dims] =
get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);
get_launch_args(w.size(), w.shape(), w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,

View File

@@ -1,32 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
// TODO implement fast path
template <typename T>
__device__ __forceinline__ uint32_t
scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t<T> input, const float scale) {
uint32_t out_fp8x4 = 0;
float4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x;
return out_fp8x4;
}
// Place holder for future fast path implementation
template <typename T, bool USE_SR>
__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
return scale_cvt_Tx4_to_fp8x4_fallback(input, scale);
}
} // namespace mlx::core::cu

View File

@@ -1,334 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using f32x4 = Vector4_t<float>;
template <typename T>
__device__ __forceinline__ uint16_t
scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
// Fallback implementation for architectures that do not support cvt
// instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
uint16_t out_fp4x4 = 0;
fp32x4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
(static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
static_cast<uint16_t>(q0);
return out_fp4x4;
}
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
defined(__CUDA_ARCH_SPECIFIC__)
__device__ __forceinline__ uint16_t
scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t" // first bf16
".reg.b16 x1_bf16; \n\t" // second bf16
".reg.b16 x2_bf16; \n\t" // third bf16
".reg.b16 x3_bf16; \n\t" // fourth bf16
".reg.b32 x0; \n\t" // to hold scaled first
".reg.b32 x1; \n\t" // to hold scaled second
".reg.b32 x2; \n\t" // to hold scaled third
".reg.b32 x3; \n\t" // to hold scaled fourth
".reg.b64 x01; \n\t" // to hold vector mul
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
"cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t" // scale first pair
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t" // scale second pair
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first
// pair
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2 second
// pair
"mov.b16 %0, {q0, q1}; \n\t" // pack to output
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(
scale))); // here cast is needed becuase an asm operand must have
// scalar type
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
const bf16x4 input_bf16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t"
".reg.b16 x1_bf16; \n\t"
".reg.b16 x2_bf16; \n\t"
".reg.b16 x3_bf16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
"cvt.f32.bf16 x0, x0_bf16; \n\t"
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t
scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
const fp16x4 input_fp16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
const bf16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
const fp16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t
scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
float2 input_fp32x2_0 = make_float2(input.x, input.y);
float2 input_fp32x2_1 = make_float2(input.z, input.w);
if constexpr (USE_SR) {
return scale_cvt_fp32x4_to_fp4x4_rs(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
} else {
return scale_cvt_fp32x4_to_fp4x4_rn(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
}
}
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else if constexpr (std::is_same<T, __half>::value) {
return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else {
return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
}
}
#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
// (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
(__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
#else
static_assert(
!USE_SR,
"Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
#endif
}
} // namespace mlx::core::cu

View File

@@ -15,22 +15,6 @@ inline constexpr __device__ short get_bytes_per_pack() {
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T>
__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
if constexpr (
(std::is_same<T, __nv_bfloat162>::value) ||
(std::is_same<T, __half2>::value)) {
T a = x1;
T b = x2;
out = __hmax2(__habs2(a), __habs2(b));
} else if constexpr (std::is_same<T, float2>::value) {
float2 a = x1;
float2 b = x2;
out.x = fmaxf(fabsf(a.x), fabsf(b.x));
out.y = fmaxf(fabsf(a.y), fabsf(b.y));
}
}
} // namespace cu
template <typename F>

View File

@@ -3,10 +3,31 @@
#pragma once
#include "mlx/backend/cuda/steel/utils.cuh"
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
// Map types to their vector of 2 type float -> float2, double -> double2 etc
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
/**
* The basic building block for Ampere mmas. A 16x16 tile distributed across
* the warp.

View File

@@ -1,48 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core::cu {
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
template <typename T>
struct Vector4 {
T x, y, z, w;
};
template <typename T>
using Vector4_t = Vector4<T>;
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using fp32x4 = Vector4_t<float>;
} // namespace mlx::core::cu

View File

@@ -347,7 +347,7 @@ template <
MMAFrag_mask_t::load_safe(
mfrag,
mask,
int64_t(mask_params->M_strides[2]),
int(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,

View File

@@ -346,7 +346,7 @@ template <
MSubTile mfrag;
mfrag.load_safe(
mask,
int64_t(mask_params->M_strides[2]),
int(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,

View File

@@ -105,20 +105,17 @@ struct BaseMMAFrag<T, 8, 8> {
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
src += off_x * str_x + off_y * str_y;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] = static_cast<T>(src[0]);
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
} else {
dst[i * kElemCols + j] = T(0);
}
src += str_y;
}
src -= kElemCols * str_y;
src += str_x;
}
}

View File

@@ -4,6 +4,11 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
if(MLX_BUILD_CPU AND NOT WIN32)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)

View File

@@ -5,6 +5,7 @@
#include "mlx/backend/cuda/cuda.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/jaccl/jaccl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h"
@@ -102,7 +103,27 @@ class EmptyGroup : public GroupImpl {
} // namespace detail
bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available();
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
jaccl::is_available();
}
bool is_available(const std::string& bk) {
if (bk == "any") {
return is_available();
}
if (bk == "mpi") {
return mpi::is_available();
}
if (bk == "ring") {
return ring::is_available();
}
if (bk == "nccl") {
return nccl::is_available();
}
if (bk == "jaccl") {
return jaccl::is_available();
}
return false;
}
int Group::rank() const {
@@ -135,6 +156,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "jaccl") {
group = jaccl::init(strict);
} else if (bk == "any") {
if (mlx::core::cu::is_available()) {
group = nccl::init(false);
@@ -148,13 +171,17 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = mpi::init(false);
bk_ = "mpi";
}
if (group == nullptr) {
group = jaccl::init(false);
bk_ = "jaccl";
}
if (group == nullptr && strict) {
throw std::runtime_error("[distributed] Couldn't initialize any backend");
}
} else {
std::ostringstream msg;
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
<< "and 'ring' but '" << bk << "' was provided.";
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
<< "'jaccl' and 'ring' but '" << bk << "' was provided.";
throw std::invalid_argument(msg.str());
}

View File

@@ -16,6 +16,7 @@ class GroupImpl;
/* Check if a communication backend is available */
bool is_available();
bool is_available(const std::string& bk);
/**
* A distributed::Group represents a group of independent mlx processes that

View File

@@ -0,0 +1,8 @@
if(MLX_BUILD_CPU
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
target_link_libraries(mlx PRIVATE rdma)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
endif()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::jaccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::jaccl

View File

@@ -0,0 +1,20 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/jaccl/jaccl.h"
namespace mlx::core::distributed::jaccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available() {
return false;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) {
throw std::runtime_error("Cannot initialize jaccl distributed backend.");
}
return nullptr;
}
} // namespace mlx::core::distributed::jaccl

View File

@@ -0,0 +1,38 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::distributed::detail {
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
template <typename T>
struct MaxOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output = std::max(*output, *input);
input++;
output++;
}
}
};
template <typename T>
struct MinOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output = std::min(*output, *input);
input++;
output++;
}
}
};
} // namespace mlx::core::distributed::detail

View File

@@ -1,9 +1,6 @@
// Copyright © 2024 Apple Inc.
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <unistd.h>
@@ -22,6 +19,8 @@
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/reduction_ops.h"
#include "mlx/distributed/utils.h"
#include "mlx/threadpool.h"
#ifndef SOL_TCP
@@ -94,6 +93,7 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
constexpr const size_t ALL_SUM_BUFFERS = 2;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
constexpr const char* RING_TAG = "[ring]";
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
@@ -296,55 +296,6 @@ class CommunicationThreads {
std::unordered_map<int, SocketThread> threads_;
};
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* get() const {
return (struct sockaddr*)&addr;
}
};
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port) {
auto colon = ip_port.find(":");
if (colon == std::string::npos) {
std::ostringstream msg;
msg << "Can't parse address " << ip_port;
throw std::runtime_error(msg.str());
}
std::string ip(ip_port.begin(), ip_port.begin() + colon);
std::string port(ip_port.begin() + colon + 1, ip_port.end());
return parse_address(ip, port);
}
/**
* Load all addresses from the json hostfile. The hostfile is a list of
* addresses in order of rank. For each rank there can be many addresses so
@@ -357,15 +308,15 @@ address_t parse_address(const std::string& ip_port) {
* ["ip3:5000", "ip3:5001"],
* ]
*/
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<address_t>> nodes;
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<detail::address_t>> nodes;
std::ifstream f(hostfile);
json hosts = json::parse(f);
for (auto& h : hosts) {
std::vector<address_t> host;
std::vector<detail::address_t> host;
for (auto& ips : h) {
host.push_back(parse_address(ips.get<std::string>()));
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
}
nodes.push_back(std::move(host));
}
@@ -377,73 +328,15 @@ std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
* Create a socket and accept one connection for each of the provided
* addresses.
*/
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
std::vector<int> accept_connections(
const std::vector<detail::address_t>& addresses) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
// Create the socket to wait for connections from the peers
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock, address.get(), address.len);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Wait for connections
success = listen(sock, 0);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
int peer_socket = accept(sock, nullptr, nullptr);
if (peer_socket < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Close the listening socket
shutdown(sock, 2);
close(sock);
sockets.push_back(peer_socket);
detail::TCPSocket socket(RING_TAG);
socket.listen(RING_TAG, address);
sockets.push_back(socket.accept(RING_TAG).detach());
}
return sockets;
@@ -454,93 +347,42 @@ std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
* provided addresses.
*/
std::vector<int> make_connections(
const std::vector<address_t>& addresses,
const std::vector<detail::address_t>& addresses,
bool verbose) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
int sock;
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
// backoff. TODO: Do we need that?
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
if (attempt > 0) {
int wait = (1 << (attempt - 1)) * CONN_WAIT;
log_info(
verbose,
"Attempt",
attempt,
"wait",
wait,
"ms (error:",
errno,
")");
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
success = connect(sock, address.get(), address.len);
if (success == 0) {
break;
}
}
if (success < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockets.push_back(sock);
sockets.push_back(detail::TCPSocket::connect(
RING_TAG,
address,
CONN_ATTEMPTS,
CONN_WAIT,
[verbose](int attempt, int wait) {
log_info(
verbose,
"Attempt",
attempt,
"waiting",
wait,
"ms (error:",
errno,
")");
})
.detach());
}
return sockets;
}
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
template <typename T>
struct MaxOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output = std::max(*output, *input);
input++;
output++;
}
}
};
template <typename T>
struct MinOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output = std::min(*output, *input);
input++;
output++;
}
}
};
} // namespace
class RingGroup : public GroupImpl {
public:
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
RingGroup(
int rank,
std::vector<std::vector<detail::address_t>> nodes,
bool verbose)
: rank_(rank), verbose_(verbose), pool_(0) {
if (rank_ > 0 && rank_ >= nodes.size()) {
throw std::runtime_error(
@@ -633,17 +475,17 @@ class RingGroup : public GroupImpl {
void all_sum(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
}
void all_max(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
}
void all_min(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {

204
mlx/distributed/utils.cpp Normal file
View File

@@ -0,0 +1,204 @@
// Copyright © 2025 Apple Inc.
#include <netdb.h>
#include <unistd.h>
#include <cstring>
#include <sstream>
#include <thread>
#include "mlx/distributed/utils.h"
namespace mlx::core::distributed::detail {
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res;
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port) {
auto colon = ip_port.find(":");
if (colon == std::string::npos) {
std::ostringstream msg;
msg << "Can't parse address " << ip_port;
throw std::runtime_error(msg.str());
}
std::string ip(ip_port.begin(), ip_port.begin() + colon);
std::string port(ip_port.begin() + colon + 1, ip_port.end());
return parse_address(ip, port);
}
TCPSocket::TCPSocket(const char* tag) {
sock_ = socket(AF_INET, SOCK_STREAM, 0);
if (sock_ < 0) {
std::ostringstream msg;
msg << tag << " Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
TCPSocket::TCPSocket(TCPSocket&& s) {
sock_ = s.sock_;
s.sock_ = -1;
}
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
if (this != &s) {
sock_ = s.sock_;
s.sock_ = -1;
}
return *this;
}
TCPSocket::TCPSocket(int s) : sock_(s) {}
TCPSocket::~TCPSocket() {
if (sock_ > 0) {
shutdown(sock_, 2);
close(sock_);
}
}
int TCPSocket::detach() {
int s = sock_;
sock_ = -1;
return s;
}
void TCPSocket::listen(const char* tag, const address_t& addr) {
int success;
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock_, addr.get(), addr.len);
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Prepare waiting for connections
success = ::listen(sock_, 0);
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
TCPSocket TCPSocket::accept(const char* tag) {
int peer = ::accept(sock_, nullptr, nullptr);
if (peer < 0) {
std::ostringstream msg;
msg << tag << " Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
return TCPSocket(peer);
}
void TCPSocket::send(const char* tag, const void* data, size_t len) {
while (len > 0) {
auto n = ::send(sock_, data, len, 0);
if (n <= 0) {
std::ostringstream msg;
msg << tag << " Send failed with errno=" << errno;
throw std::runtime_error(msg.str());
}
len -= n;
data = static_cast<const char*>(data) + n;
}
}
void TCPSocket::recv(const char* tag, void* data, size_t len) {
while (len > 0) {
auto n = ::recv(sock_, data, len, 0);
if (n <= 0) {
std::ostringstream msg;
msg << tag << " Recv failed with errno=" << errno;
throw std::runtime_error(msg.str());
}
len -= n;
data = static_cast<char*>(data) + n;
}
}
TCPSocket TCPSocket::connect(
const char* tag,
const address_t& addr,
int num_retries,
int wait,
std::function<void(int, int)> cb) {
int sock, success;
// Attempt to connect `num_retries` times with exponential backoff.
for (int attempt = 0; attempt < num_retries; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << tag << " Couldn't create socket to connect (error: " << errno
<< ")";
throw std::runtime_error(msg.str());
}
success = ::connect(sock, addr.get(), addr.len);
if (success == 0) {
break;
}
cb(attempt, wait);
if (wait > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
wait <<= 1;
}
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
return TCPSocket(sock);
}
} // namespace mlx::core::distributed::detail

67
mlx/distributed/utils.h Normal file
View File

@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/socket.h>
#include <functional>
#include <string>
namespace mlx::core::distributed::detail {
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* get() const {
return (struct sockaddr*)&addr;
}
};
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port);
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port);
/**
* Small wrapper over a TCP socket to simplify initiating connections.
*/
class TCPSocket {
public:
TCPSocket(const char* tag);
TCPSocket(const TCPSocket&) = delete;
TCPSocket& operator=(const TCPSocket&) = delete;
TCPSocket(TCPSocket&& s);
TCPSocket& operator=(TCPSocket&&);
~TCPSocket();
void listen(const char* tag, const address_t& addr);
TCPSocket accept(const char* tag);
void send(const char* tag, const void* data, size_t len);
void recv(const char* tag, void* data, size_t len);
int detach();
operator int() const {
return sock_;
}
static TCPSocket connect(
const char* tag,
const address_t& addr,
int num_retries = 1,
int wait = 0,
std::function<void(int, int)> cb = nullptr);
private:
TCPSocket(int sock);
int sock_;
};
} // namespace mlx::core::distributed::detail

View File

@@ -880,11 +880,6 @@ std::vector<array> ScaledDotProductAttention::vjp(
std::vector<array> returned_vjps;
for (int arg : argnums) {
if (arg >= 3) {
throw std::invalid_argument(
"[scale_dot_product_attention] Does not support VJP with respect "
" to mask or attention sinks.");
}
returned_vjps.push_back(std::move(vjps[arg]));
}
return returned_vjps;

View File

@@ -1,7 +1,7 @@
[build-system]
requires = [
"setuptools>=80",
"nanobind==2.10.2",
"nanobind==2.4.0",
"cmake>=3.25",
]
build-backend = "setuptools.build_meta"

View File

@@ -0,0 +1,85 @@
# Copyright © 2025 Apple Inc.
import ipaddress
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
@dataclass
class Host:
rank: int
ssh_hostname: str
ips: list[str]
rdma: list[Optional[str]]
def positive_number(x):
x = int(x)
if x <= 0:
raise ValueError("Number should be positive")
return x
def log(verbose, *args, **kwargs):
if not verbose:
return
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
def log_warning(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
def log_error(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
def parse_hostlist(parser, hostlist, repeats):
hosts = []
for i, h in enumerate(hostlist.split(",")):
if h == "":
raise ValueError("Hostname cannot be empty")
try:
ipaddress.ip_address(h)
ips = [h]
except ValueError:
ips = []
for i in range(repeats):
hosts.append(Host(i, h, ips))
return hosts
def parse_hostfile(parser, hostfile):
"""Parse the json hostfile that contains both the hostnames to ssh into and
the ips to communicate over when using the ring backend.
Example:
[
{"ssh": "hostname1", "ips": ["123.123.123.1"], "rdma": [null, "rdma_en2", "rdma_en3"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"], "rdma": ["rdma_en2", null, "rdma_en3"]},
...
{"ssh": "hostnameN", "ips": ["123.123.123.N"], "rdma": ["rdma_en2", "rdma_en3", null]},
]
Args:
hostfile (str): The path to the json file containing the host
information
"""
hostfile = Path(hostfile)
if not hostfile.exists():
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
try:
hosts = []
with open(hostfile) as f:
for i, h in enumerate(json.load(f)):
hosts.append(Host(i, h["ssh"], h.get("ips", []), h.get("rdma", [])))
return hosts
except Exception as e:
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")

View File

View File

@@ -832,7 +832,7 @@ def main():
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl"],
choices=["ring", "mpi", "nccl", "jaccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
@@ -903,6 +903,8 @@ def main():
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if args.backend == "jaccl":
launch_jaccl(parser, hosts, args, rest)
if __name__ == "__main__":

View File

@@ -0,0 +1,540 @@
# Copyright © 2025 Apple Inc.
import argparse
import base64
import json
import os
import shlex
import shutil
import sys
import tempfile
import threading
from collections import Counter
from itertools import chain
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
import mlx.core as mx
from .common import log, log_warning, parse_hostfile, parse_hostlist, positive_number
class CommandProcess:
@property
def process(self):
"""Return the Popen object that refers to the current command."""
raise NotImplementedError()
@property
def exit_status(self):
"""Return a tuple (returncode, killed) for the command. It should be
(None, None) while the command is running normally."""
raise NotImplementedError()
def preprocess_output(self, data: str, is_stdout=False):
"""Preprocess the output of the command so that extra data can be
capture or the format changed on the fly."""
raise NotImplementedError()
def terminate(self):
"""Terminate or return the exit code."""
raise NotImplementedError()
class RemoteProcess(CommandProcess):
def __init__(self, rank, host, cwd, files, env, command):
is_local = host == "127.0.0.1"
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
if not is_local:
cmd = f"ssh {host} '{cmd}'"
self._host = host
self._pidfile = None
self._is_local = is_local
self._process = Popen(
cmd,
shell=True,
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
self._killed = False
@property
def process(self):
return self._process
@property
def exit_status(self):
return self._process.poll(), self._killed
def preprocess_output(self, data, is_stdout=False):
if self._pidfile is None:
pidfile, *rest = data.split("\n", maxsplit=1)
self._pidfile = pidfile
return rest[0] if rest else ""
return data
def terminate(self):
if self._killed:
return
self._process.terminate()
self._process.wait()
# Kill the remote program if possible
cmd = ""
cmd += f"pid=$(cat {self._pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {self._pidfile}"
if not self._is_local:
cmd = f"ssh {self._host} '{cmd}'"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
self._killed = c.stdout.strip() == "1"
@staticmethod
def make_monitor_script(rank, cwd, files, env, command):
# Imports that are used throughout
script = ""
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Write the PID to a file so we can kill the process if needed
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
if cwd is not None:
script += "else:\n"
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
script += f" sys.exit(1)\n"
# Add the environment variables that were requested
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
if not all(c.isalnum() or c == "_" for c in key):
log_warning(
f"'{e}' is an invalid environment variable so it is ignored"
)
continue
script += f"env[{repr(key)}] = {repr(value)}\n"
# Make the temporary files
for env_name, content in files.items():
script += "_, fname = tempfile.mkstemp()\n"
script += "with open(fname, 'w') as f:\n"
script += f" f.write({repr(content)})\n"
script += f"env[{repr(env_name)}] = fname\n"
# Finally add the rank
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
# Replace the process with the script
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
return script
def _launch_with_io(command_class, arguments, verbose):
stop = False
exit_codes = [(None, None)] * len(arguments)
def _thread_fn(rank, *args, **kwargs):
stdin_queue = kwargs.pop("stdin_queue")
stdout_queue = kwargs.pop("stdout_queue")
stderr_queue = kwargs.pop("stderr_queue")
command = command_class(rank, *args, **kwargs)
p = command.process
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
os.set_blocking(p.stdin.fileno(), False)
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno()]
stdin_buffer = b""
while p.poll() is None:
try:
stdin_buffer += stdin_queue.get_nowait()
except QueueEmpty:
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
is_stdout = fd == p.stdout.fileno()
msg = os.read(fd, 8192).decode(errors="ignore")
msg = command.preprocess_output(msg, is_stdout)
if is_stdout:
stdout_queue.put(msg.encode())
else:
stderr_queue.put(msg.encode())
for fd in wlist:
if len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
if stop:
command.terminate()
break
exit_codes[rank] = command.exit_status
if exit_codes[rank][1]:
log_warning(f"Node with rank {rank} was killed")
elif exit_codes[rank][0] != 0:
log_warning(f"Node with rank {rank} exited with code {exit_codes[rank][0]}")
else:
log(verbose, f"Node with rank {rank} completed")
stdin_queues = []
stdout_queues = []
stderr_queues = []
threads = []
for i, (args, kwargs) in enumerate(arguments):
stdin_queues.append(Queue())
stdout_queues.append(Queue())
stderr_queues.append(Queue())
t = threading.Thread(
target=_thread_fn,
args=args,
kwargs=kwargs
| {
"stdin_queue": stdin_queues[-1],
"stdout_queue": stdout_queues[-1],
"stderr_queue": stderr_queues[-1],
},
)
t.start()
threads.append(t)
os.set_blocking(sys.stdin.fileno(), False)
os.set_blocking(sys.stdout.fileno(), True)
os.set_blocking(sys.stderr.fileno(), True)
while not stop or any(not q.empty() for q in chain(stdout_queues, stderr_queues)):
# Broadcast user input to the jobs
rlist, _, _ = select([sys.stdin.fileno()], [], [], 0.1)
for fd in rlist:
stdin_buffer = os.read(fd, 8192)
for q in stdin_queues:
q.put(stdin_buffer)
# Gather job output
for q in stdout_queues:
try:
while not q.empty():
sys.stdout.buffer.write(q.get_nowait())
except QueueEmpty:
pass
for q in stderr_queues:
try:
while not q.empty():
sys.stderr.buffer.write(q.get_nowait())
except QueueEmpty:
pass
sys.stdout.buffer.flush()
sys.stderr.buffer.flush()
# Check if all are running and terminate otherwise
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
if exit_codes[i][0] != 0:
stop = True
break
else:
break
# Wait for the jobs to finish
for t in threads:
t.join()
# Process any remaining outputs
for q in stdout_queues:
while not q.empty():
sys.stdout.buffer.write(q.get())
for q in stderr_queues:
while not q.empty():
sys.stderr.buffer.write(q.get())
sys.stdout.buffer.flush()
sys.stderr.buffer.flush()
def launch_ring(parser, hosts, args, command):
if any(len(h.ips) == 0 for h in hosts):
parser.error(
"The ring backend requires IPs to be provided instead of hostnames"
)
port = args.starting_port
ring_hosts = []
for h in hosts:
node = []
for ip in h.ips:
for i in range(args.connections_per_ip):
node.append(f"{ip}:{port}")
port += 1
ring_hosts.append(node)
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
files = {"MLX_HOSTFILE": hostfile}
env = args.env
if args.verbose:
env.append("MLX_RING_VERBOSE=1")
cwd = args.cwd
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
)
def launch_nccl(parser, hosts, args, command):
if not hosts[0].ips:
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
master_host = hosts[0].ips[0]
master_port = args.nccl_port
world_size = len(hosts)
env = args.env
cwd = args.cwd
if args.verbose:
env.append("NCCL_DEBUG=INFO")
env.append(f"NCCL_HOST_IP={master_host}")
env.append(f"NCCL_PORT={master_port}")
env.append(f"MLX_WORLD_SIZE={world_size}")
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
(
(
rank,
h.ssh_hostname,
cwd,
{},
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
command,
),
{},
)
for rank, h in enumerate(hosts)
],
args.verbose,
)
def launch_jaccl(parser, hosts, args, command):
if not hosts[0].ips:
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
have_rdmas = all(len(h.rdma) == len(hosts) for h in hosts)
have_nulls = all(h.rdma[i] is None for i, h in enumerate(hosts))
if not have_rdmas or not have_nulls:
raise ValueError("Malformed hostfile for jaccl backend")
coordinator = hosts[0].ips[0]
env = args.env
cwd = args.cwd
env.append(f"MLX_JACCL_COORDINATOR={coordinator}:{args.starting_port}")
files = {"MLX_IBV_DEVICES": json.dumps([h.rdma for h in hosts])}
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
)
def get_mpi_libname():
try:
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
ompi_info = ompi_info.stdout.strip().decode()
if platform.system() == "Darwin":
otool_output = run(
["otool", "-L", ompi_info], check=True, capture_output=True
)
else:
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
otool_output = otool_output.stdout.decode()
# StopIteration if not found
libmpi_line = next(
filter(lambda line: "libmpi" in line, otool_output.splitlines())
)
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
except:
return None
def launch_mpi(parser, hosts, args, command):
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
mpirun = mpirun.stdout.strip().decode()
# Compatibility with homebrew and pip installs
mpi_libname = get_mpi_libname()
if mpi_libname is not None:
dyld = Path(mpirun).parent.parent / "lib"
args.env = [
f"DYLD_LIBRARY_PATH={str(dyld)}",
f"MLX_MPI_LIBNAME={mpi_libname}",
] + args.env
log(args.verbose, f"Using '{mpirun}'")
with tempfile.NamedTemporaryFile(mode="w") as f:
hosts = Counter((h.ssh_hostname for h in hosts))
for h, n in hosts.items():
print(f"{h} slots={n}", file=f)
f.flush()
cmd = [
mpirun,
"--output",
":raw", # do not line buffer output
"--hostfile",
f.name,
*(["-cwd", args.cwd] if args.cwd else []),
*sum((["-x", e] for e in args.env), []),
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
"--",
*command,
]
log(args.verbose, "Running", " ".join(cmd))
try:
run(cmd)
except KeyboardInterrupt:
pass
def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument(
"--print-python",
action="store_true",
help="Print the path to the current python executable and exit",
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument(
"--repeat-hosts",
"-n",
type=positive_number,
default=1,
help="Repeat each host a given number of times",
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl", "jaccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
parser.add_argument(
"--env",
action="append",
default=[],
help="Set environment variables for the jobs",
)
parser.add_argument(
"--mpi-arg",
action="append",
default=[],
help="Arguments to pass directly to mpirun",
)
parser.add_argument(
"--connections-per-ip",
default=1,
type=int,
help="How many connections per ip to use for the ring backend",
)
parser.add_argument(
"--starting-port",
"-p",
type=int,
default=32323,
help="For the ring backend listen on this port increasing by 1 per rank and IP",
)
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
args, rest = parser.parse_known_args()
if args.print_python:
print(sys.executable)
return
if len(rest) == 0:
parser.error("No script is provided")
if rest[0] == "--":
rest.pop(0)
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
# Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists() and script.is_file():
rest[0:1] = [sys.executable, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
else:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
if args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if args.backend == "jaccl":
launch_jaccl(parser, hosts, args, rest)

View File

@@ -52,9 +52,25 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"is_available",
&mx::distributed::is_available,
[](const std::string& backend) {
return mx::distributed::is_available(backend);
},
"backend"_a = "any",
nb::sig("def is_available(backend: str = 'any') -> bool"),
R"pbdoc(
Check if a communication backend is available.
Note, this function returns whether MLX has the capability of
instantiating that distributed backend not whether it is possible to
create a communication group. For that purpose one should use
``init(strict=True)``.
Args:
backend (str, optional): The name of the backend to check for availability.
It takes the same values as ``init()``. Default: ``any``.
Returns:
bool: Whether the distributed backend is available.
)pbdoc");
m.def(
@@ -79,10 +95,10 @@ void init_distributed(nb::module_& parent_module) {
in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False``
backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
available backends are tried and the first one that succeeds
becomes the global group which will be returned in subsequent
calls. Default: ``any``
Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
set to ``any`` all available backends are tried and the first one
that succeeds becomes the global group which will be returned in
subsequent calls. Default: ``any``
Returns:
Group: The group representing all the launched processes.

View File

@@ -89,8 +89,7 @@ static PyType_Spec gc_func_spec = {
/* .name = */ "mlx.gc_func",
/* .basicsize = */ (int)sizeof(gc_func),
/* .itemsize = */ 0,
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_HAVE_VECTORCALL,
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
/* .slots = */ gc_func_slots};
static PyTypeObject* gc_func_tp = nullptr;

View File

@@ -16,7 +16,8 @@ struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
NB_TYPE_CASTER(
List,
const_name("tuple[") + make_caster<Type>::Name + const_name(", ...]"))
const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name +
const_name(", ...]"))
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
size_t size;

View File

@@ -124,53 +124,37 @@ auto py_value_and_grad(
// Collect the arrays
std::vector<mx::array> arrays;
std::vector<nb::object> array_objects;
auto flatten_with_objects = [&arrays, &array_objects](
auto tree, bool strict) {
tree_visit(tree, [&](nb::handle obj) {
if (nb::isinstance<mx::array>(obj)) {
arrays.push_back(nb::cast<mx::array>(obj));
array_objects.push_back(nb::borrow<nb::object>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
};
std::vector<int> counts(1, 0);
std::vector<int> gradient_indices;
for (int i = 0, j = 0; i < args.size(); ++i) {
bool needs_grad = (j < argnums.size() && argnums[j] == i);
auto pre_size = arrays.size();
flatten_with_objects(args[i], /* strict = */ needs_grad);
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
if (needs_grad) {
auto old_size = gradient_indices.size();
auto delta_size = arrays.size() - pre_size;
gradient_indices.resize(old_size + delta_size);
gradient_indices.resize(old_size + argsi.size());
std::iota(
gradient_indices.begin() + old_size,
gradient_indices.end(),
pre_size);
arrays.size());
j++;
counts.push_back(delta_size);
counts.push_back(argsi.size());
}
arrays.insert(arrays.end(), argsi.begin(), argsi.end());
}
for (auto item : kwargs) {
bool needs_grad =
(argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
auto pre_size = arrays.size();
flatten_with_objects(item.second, /* strict = */ needs_grad);
auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
if (needs_grad) {
auto old_size = gradient_indices.size();
auto delta_size = arrays.size() - pre_size;
gradient_indices.resize(old_size + delta_size);
gradient_indices.resize(old_size + argsk.size());
std::iota(
gradient_indices.begin() + old_size,
gradient_indices.end(),
pre_size);
counts.push_back(delta_size);
arrays.size());
counts.push_back(argsk.size());
}
arrays.insert(arrays.end(), argsk.begin(), argsk.end());
}
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
@@ -179,7 +163,7 @@ auto py_value_and_grad(
nb::object py_value_out;
auto value_and_grads = mx::value_and_grad(
[&fun,
&array_objects,
&arrays,
&args,
&kwargs,
&py_value_out,
@@ -199,9 +183,8 @@ auto py_value_and_grad(
tree_visit_update(tree, [&](nb::handle node) {
auto replace_arr = nb::cast<mx::array>(node);
if (replace_arr.id() == a[index].id()) {
return array_objects[index++];
return nb::cast(arrays[index++]);
} else {
index++;
return nb::cast(replace_arr);
}
});

View File

@@ -780,21 +780,9 @@ class TestAutograd(mlx_tests.MLXTestCase):
return arrs[0]
arrs = [mx.array(1.0)]
arr = arrs[0]
init_id = id(arrs[0])
mx.grad(fun)(arrs)
self.assertEqual(id(arr), id(arrs[0]))
def fun(arrs):
arrs[1] = sum(arrs)
return arrs[1]
arrs = [mx.array(1.0), mx.array(1.0), mx.array(1.0)]
a_0, a_1, a_2 = arrs
mx.grad(fun)(arrs)
self.assertEqual(id(a_0), id(arrs[0]))
self.assertNotEqual(id(a_1), id(arrs[1]))
self.assertEqual(id(a_2), id(arrs[2]))
self.assertEqual(init_id, id(arrs[0]))
def test_grad_with_inplace_update(self):
def loss_fn(model):

View File

@@ -4,12 +4,12 @@ import gc
import inspect
import io
import math
import unittest
from functools import partial, wraps
from io import StringIO
import mlx.core as mx
import mlx_tests
import numpy as np
class TestCompile(mlx_tests.MLXTestCase):
@@ -1252,26 +1252,6 @@ class TestCompile(mlx_tests.MLXTestCase):
loss, grads = step(emb, w, x)
mx.eval(loss, grads)
def test_compile_donates_input_buffer(self):
mx.set_default_device(mx.cpu)
def fun(x):
return mx.sin(x) + 1
compiled_fn = mx.compile(fun)
input = mx.arange(16, dtype=mx.float32)
mx.eval(input)
in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
out = compiled_fn(input)
del input # Ensure the reference is dropped
mx.eval(out)
self.assertEqual(
np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()

View File

@@ -744,6 +744,7 @@ class TestVmap(mlx_tests.MLXTestCase):
return Vector([t[0] + 10, t[1] * 10])
x = State(mx.array(1), mx.array(2))
print(f"{transform(x)=}")
vmap_transform = mx.vmap(transform)
vmap_transform_tuple = mx.vmap(transform_tuple)

View File

@@ -255,7 +255,7 @@ if __name__ == "__main__":
extras = {
"dev": [
"nanobind==2.10.2",
"nanobind==2.4.0",
"numpy",
"pre-commit",
"setuptools>=80",
@@ -265,8 +265,8 @@ if __name__ == "__main__":
}
entry_points = {
"console_scripts": [
"mlx.launch = mlx.distributed_run:main",
"mlx.distributed_config = mlx.distributed_run:distributed_config",
"mlx.launch = mlx._distributed_utils.launch:main",
# "mlx.distributed_config = mlx.distributed_run:distributed_config",
]
}
install_requires = []