Compare commits

..

8 Commits

Author SHA1 Message Date
Angelos Katharopoulos
4987e7615a Improve the cutlass gemm 2025-08-25 18:18:19 -07:00
Angelos Katharopoulos
e1303f6160 Reset cutlass gemm to working state again 2025-08-21 01:29:43 -07:00
Angelos Katharopoulos
cf5eef095d tmp 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
395d582719 Add a cutlass gemm 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
05583bcd10 More pipelining for the sm_80 gemm 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
6fce01593a Improve gemm 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
97afe40b7b Remove duplicate register tile 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
f70c62d69c Simple gemm example 2025-08-20 23:51:25 -07:00
38 changed files with 809 additions and 1116 deletions

View File

@@ -222,7 +222,6 @@ jobs:
sudo apt-get update sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12 sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf - curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64 rm -rf ccache-4.11.3-linux-x86_64
@@ -405,7 +404,6 @@ jobs:
sudo dpkg -i cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12 sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libnccl2 libnccl-dev
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip sudo apt-get install zip
pip install auditwheel pip install auditwheel

View File

@@ -25,11 +25,6 @@ MLX was developed with contributions from the following individuals:
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" /> <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a> </a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software # Third-Party Software
MLX leverages several third-party software, listed here together with MLX leverages several third-party software, listed here together with

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
dpkg -i cuda-keyring_1.1-1_all.deb dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y apt-get update -y
apt-get -y install cuda-toolkit-12-9 apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag When building either the Python or C++ APIs make sure to pass the cmake flag

View File

@@ -234,7 +234,6 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
template <typename MaskT, typename T1, typename T2, int N> template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) { Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
static_assert(std::is_same_v<MaskT, bool>);
if constexpr (sizeof(T1) == 1) { if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value)); return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) { } else if constexpr (sizeof(T1) == 2) {
@@ -252,13 +251,9 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
return asd::pow(base.value, exp.value); return asd::pow(base.value, exp.value);
} else { } else {
Simd<T, N> res = 1; Simd<T, N> res = 1;
// Raising an integer to a negative power is undefined while (any(exp)) {
if (any(exp < 0)) { res = select(exp & 1, res * base, res);
return 0; base = select(exp, base * base, base);
}
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > 0, base * base, base);
exp = exp >> 1; exp = exp >> 1;
} }
return res; return res;

View File

@@ -22,11 +22,12 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cutlass_gemm.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -89,6 +90,9 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
target_compile_options(mlx target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>") PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Keep ptx around for inspection
target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>")
# Enable calling host constexpr functions from device. This is needed because # Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only. # the constexpr version of isnan is host only.
target_compile_options( target_compile_options(
@@ -174,3 +178,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
# Install CCCL headers for JIT. # Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl) DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
# Fetch and make available cutlass
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG v4.1.0)
FetchContent_Populate(cutlass)
target_include_directories(
mlx PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)

View File

@@ -30,15 +30,8 @@ SmallSizePool::SmallSizePool() {
next_free_ = buffer_; next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size)); CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = 0;
#else
int loc = 0;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc)); cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
auto curr = next_free_; auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) { for (size_t i = 1; i < num_blocks; ++i) {

View File

@@ -269,13 +269,7 @@ void CommandEncoder::commit() {
if (node_count_ > 0) { if (node_count_ > 0) {
if (!from_nodes_.empty()) { if (!from_nodes_.empty()) {
CHECK_CUDA_ERROR(cudaGraphAddDependencies( CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_, graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
from_nodes_.data(),
to_nodes_.data(),
#if CUDART_VERSION >= 13000
nullptr, // edgeData
#endif // CUDART_VERSION >= 13000
from_nodes_.size()));
} }
graph_key_ += "."; graph_key_ += ".";

View File

@@ -204,12 +204,6 @@ struct Power {
__device__ T operator()(T base, T exp) { __device__ T operator()(T base, T exp) {
if constexpr (cuda::std::is_integral_v<T>) { if constexpr (cuda::std::is_integral_v<T>) {
T res = 1; T res = 1;
// Raising an integer to a negative power is undefined
if constexpr (cuda::std::is_signed_v<T>) {
if (exp < 0) {
return 0;
}
}
while (exp) { while (exp) {
if (exp & 1) { if (exp & 1) {
res *= base; res *= base;

View File

@@ -1,56 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core::distributed {
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto set_input_output =
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s);
return {out, out};
} else if (in.is_donatable()) {
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(allocator::malloc(out.nbytes()));
return {in, out};
}
};
auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream());
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), input, output, s);
break;
case Max:
distributed::detail::all_max(group(), input, output, s);
break;
case Min:
distributed::detail::all_min(group(), input, output, s);
break;
default:
throw std::runtime_error(
"Only all reduce sum, max, and min are supported.");
}
}
} // namespace mlx::core::distributed

View File

@@ -0,0 +1,396 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cute/tensor.hpp>
#include <cutlass/arch/arch.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <iostream>
namespace mlx::core::cu {
namespace {
using namespace cute;
using bf16 = cute::bfloat16_t;
template <typename Kernel>
void configure_matmul(Kernel kernel, int smem_size) {
static bool initialized = false;
if (!initialized) {
initialized = true;
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
}
template <bool transpose, typename Tiler>
constexpr int get_feature_size(Tiler smem) {
int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
return (feature_size >= 64) ? 64 : feature_size;
}
constexpr int constexpr_log2(int x) {
return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
}
template <int feature_size, int itemsize, int copy_bits>
constexpr int get_swizzle_bits() {
constexpr int swizzle_bits =
constexpr_log2(feature_size * itemsize / copy_bits);
return (swizzle_bits > 3) ? 3 : swizzle_bits;
}
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_smem_layout(Tiler smem) {
constexpr int feature_size = get_feature_size<transpose>(smem);
constexpr int swizzle_bits =
get_swizzle_bits<feature_size, itemsize, copy_bits>();
using F = Int<feature_size>;
using BaseLayout = std::conditional_t<
transpose,
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
auto swizzled =
make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
return tile_to_shape(swizzled, smem);
}
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_result_smem_layout(Tiler smem) {
constexpr int feature_size = get_feature_size<transpose>(smem);
constexpr int swizzle_bits =
get_swizzle_bits<feature_size, itemsize, copy_bits>();
using F = Int<feature_size>;
using BaseLayout = std::conditional_t<
transpose,
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
auto swizzled = make_composed_layout(
Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});
return tile_to_shape(swizzled, smem);
}
template <
int num_threads,
int itemsize,
bool transpose,
int copy_bits,
typename Copier,
typename Tiler>
constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
constexpr int num_elements = copy_bits / itemsize;
constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
constexpr int copies_per_feature = feature_size / num_elements;
using E = Int<num_elements>;
using C = Int<copies_per_feature>;
using R = Int<num_threads / copies_per_feature>;
using ThreadLayout = std::conditional_t<
transpose,
Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
using ValueLayout = std::conditional_t<
transpose,
Layout<cute::Shape<E, _1>>,
Layout<cute::Shape<_1, E>>>;
return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
}
template <int rasterization_factor>
__device__ inline int2 raster_tile(int x, int y) {
return {
x / rasterization_factor,
(x % rasterization_factor) + y * rasterization_factor};
}
template <
typename T,
typename SLayoutA,
typename SLayoutB,
typename SLayoutC,
typename CopyA,
typename CopyB,
typename CopyC,
typename MMA,
int rasterization_factor>
__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
const T* __restrict__ A,
const T* __restrict__ B,
T* __restrict__ C,
SLayoutA SA,
SLayoutB SB,
SLayoutC SC,
CopyA copy_a,
CopyB copy_b,
CopyC copy_c,
MMA mma,
int M,
int N,
int K) {
constexpr auto BM = size<0>(SA);
constexpr auto BN = size<0>(SB);
constexpr auto BK = size<1>(SA);
constexpr auto PIPE = size<2>(SA);
const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
const int blocks_m = ceil_div(M, BM);
const int blocks_n = ceil_div(N, BN);
// Exit early if the tile is OOB
if (tile.x >= blocks_m || tile.y >= blocks_n) {
return;
}
// Make the full tensors
Tensor full_A =
make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
Tensor full_B =
make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
Tensor full_C =
make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
// Partition the tensors into tiles and select the ones for this threadblock
Tensor local_A =
local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
Tensor local_B =
local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
Tensor local_C =
local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
// Make shared memory tensors
extern __shared__ char shared_memory[];
T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
T* shared_B_ptr =
reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
// Get the copies that correspond to this thread
auto thread_copy_a = copy_a.get_slice(threadIdx.x);
Tensor local_A_src = thread_copy_a.partition_S(local_A);
Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
auto thread_copy_b = copy_b.get_slice(threadIdx.x);
Tensor local_B_src = thread_copy_a.partition_S(local_B);
Tensor local_B_dst = thread_copy_a.partition_D(shared_B);
auto thread_copy_c = copy_c.get_slice(threadIdx.x);
Tensor local_C_src = thread_copy_c.partition_S(shared_C);
Tensor local_C_dst = thread_copy_c.partition_D(local_C);
// Start fetches
int k_tile_count = size<2>(local_A);
int k_tile_next = 0;
CUTE_UNROLL
for (int k = 0; k < PIPE - 1; k++) {
copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
cp_async_fence();
k_tile_count--;
k_tile_next += (k_tile_count > 0);
}
// Get the MMA that corresponds to this thread and allocate registers
auto thread_mma = mma.get_slice(threadIdx.x);
Tensor mma_shared_A = thread_mma.partition_A(shared_A);
Tensor mma_shared_B = thread_mma.partition_B(shared_B);
Tensor mma_shared_C = thread_mma.partition_C(shared_C);
Tensor mma_global_C = thread_mma.partition_C(local_C);
Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
clear(mma_frag_C);
// Make shared to register copies
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
constexpr auto RPIPE = size<2>(mma_shared_A);
int smem_read = 0;
int smem_write = PIPE - 1;
Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
// Start the register pipeline
if constexpr (RPIPE > 1) {
cp_async_wait<PIPE - 2>();
__syncthreads();
copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
}
CUTE_NO_UNROLL
while (k_tile_count > -(PIPE - 1)) {
CUTE_UNROLL
for (int k_block = 0; k_block < RPIPE; k_block++) {
if (k_block == RPIPE - 1) {
mma_A_src_p = mma_A_src(_, _, _, smem_read);
mma_B_src_p = mma_B_src(_, _, _, smem_read);
cp_async_wait<PIPE - 2>();
__syncthreads();
}
// Load the next register tile
auto k_block_next = (k_block + 1) % RPIPE;
copy(
s2r_copy_a,
mma_A_src_p(_, _, k_block_next),
mma_A_dst(_, _, k_block_next));
copy(
s2r_copy_b,
mma_B_src_p(_, _, k_block_next),
mma_B_dst(_, _, k_block_next));
if (k_block == 0) {
copy(
copy_a,
local_A_src(_, _, _, k_tile_next),
local_A_dst(_, _, _, smem_write));
copy(
copy_b,
local_B_src(_, _, _, k_tile_next),
local_B_dst(_, _, _, smem_write));
cp_async_fence();
k_tile_count--;
k_tile_next += (k_tile_count > 0);
smem_write = smem_read;
smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
}
gemm(
mma,
mma_frag_A(_, _, k_block),
mma_frag_B(_, _, k_block),
mma_frag_C);
}
}
copy(mma_frag_C, mma_shared_C);
__syncthreads();
copy(copy_c, local_C_src, local_C_dst);
// if (threadIdx.x == 0) {
// print("fC: "); print(mma_frag_C); print("\n");
// print("sC: "); print(mma_shared_C); print("\n");
// print("dC: "); print(local_C_dst); print("\n");
//
// print(s2r_atom_a); print("\n");
// }
}
} // namespace
void cutlass_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc) {
enc.set_input_array(a);
enc.set_input_array(b);
enc.set_output_array(out);
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
using namespace cute;
// Tile definitions
auto BM = Int<128>{};
auto BN = Int<128>{};
auto BK = Int<64>{};
auto BP = Int<3>{};
auto GM = Int<8>{};
// Thread definitions
using TM = Int<2>;
using TN = Int<2>;
using TK = Int<1>;
constexpr int num_threads = TM::value * TN::value * 32;
auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
auto async_copy_op =
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
async_copy_op, make_shape(BM, BK));
auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
async_copy_op, make_shape(BN, BK));
auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
sync_copy_op, make_shape(BM, BN));
auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
auto tiled_mma = make_tiled_mma(
mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
auto kernel = matmul_kernel<
bf16,
decltype(SA),
decltype(SB),
decltype(SC),
decltype(tiled_copy_a),
decltype(tiled_copy_b),
decltype(tiled_copy_c),
decltype(tiled_mma),
GM.value>;
configure_matmul(kernel, smem_size);
dim3 block(size(tiled_mma));
dim3 grid(
size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
enc.add_kernel_node(
kernel,
grid,
block,
smem_size,
a.data<bf16>(),
b.data<bf16>(),
out.data<bf16>(),
SA,
SB,
SC,
tiled_copy_a,
tiled_copy_b,
tiled_copy_c,
tiled_mma,
M,
N,
K);
} else {
throw std::runtime_error("Only bfloat16 supported");
}
});
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
void cutlass_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc);
}

View File

@@ -0,0 +1,69 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/dtype_utils.h"
#include <iostream>
namespace mlx::core::cu {
namespace {
template <typename Kernel>
static void configure_smem(Kernel kernel, int SM) {
static bool done = false;
if (done) {
return;
}
std::cout << "configuring" << std::endl;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM);
cudaFuncSetAttribute(
kernel,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxShared);
done = true;
}
} // namespace
void simple_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc) {
enc.set_input_array(a);
enc.set_input_array(b);
enc.set_output_array(out);
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
constexpr int PIPE = 3;
constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK);
constexpr int WM = 2;
constexpr int WN = 4;
auto kernel = ab_t_aligned<DataType, BM, BN, BK, WM, WN, PIPE>;
configure_smem(kernel, SM);
dim3 grid(N / BN, M / BM);
enc.add_kernel_node(
kernel,
grid,
WM * WN * WARP_SIZE,
SM,
a.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
N,
K);
});
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
void simple_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc);
}

View File

@@ -3,7 +3,9 @@
#include "mlx/backend/common/matmul.h" #include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/gemms/cutlass_gemm.h"
#include "mlx/backend/cuda/gemms/gemv.h" #include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/cuda/gemms/simple_gemm.h"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -11,8 +13,14 @@
#include <numeric> #include <numeric>
namespace mlx::core { namespace mlx::core {
namespace { namespace {
int get_test_gemm() {
static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0);
return t;
}
std::tuple<bool, int64_t, array> std::tuple<bool, int64_t, array>
check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) { check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2]; auto stx = arr.strides()[arr.ndim() - 2];
@@ -95,6 +103,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
b_transposed && batch_count == 1 && get_test_gemm() == 1) {
cu::simple_gemm(a, b, out, M, N, K, encoder);
return;
}
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
b_transposed && batch_count == 1 && get_test_gemm() == 2) {
cu::cutlass_gemm(a, b, out, M, N, K, encoder);
return;
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt
CublasGemm gemm( CublasGemm gemm(

View File

@@ -42,6 +42,7 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh) NO_GPU_MULTI(Eigh)
namespace distributed { namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather) NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send) NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv) NO_GPU_MULTI(Recv)

View File

@@ -4,95 +4,189 @@
namespace mlx::core::cu { namespace mlx::core::cu {
template <typename T, int BM, int BN, int BK, int WM, int WN>
__device__ inline void gemm_ab_t(
RegisterTile<float, BM / WM, BN / WN>& C,
SharedTile<T, BM, BK>& As,
SharedTile<T, BN, BK>& Bs,
RegisterTileLoader<SharedTile<T, BM, BK>>& rloader_a,
RegisterTileLoader<SharedTile<T, BN, BK>>& rloader_b) {
RegisterTile<T, BM / WM, 16> A[2];
RegisterTile<T, BN / WN, 16> B[2];
rloader_a.load(A[0], As.base_addr(), 0);
rloader_b.load(B[0], Bs.base_addr(), 0);
MLX_UNROLL
for (int k = 1; k < BK / 16; k++) {
rloader_a.load(A[k & 1], As.base_addr(), k);
rloader_b.load(B[k & 1], Bs.base_addr(), k);
mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
}
mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
}
/** /**
* An example gemm written with the utils. * An example gemm written with the utils.
* *
* Computes A @ B.T when A and B are all aligned with the block sizes. * Computes A @ B.T when A and B are all aligned with the block sizes.
*/ */
template <typename T, int BM, int BN, int BK> // template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { //__global__ __launch_bounds__(WM * WN * WARP_SIZE, 1)
constexpr int WARPS_M = 2; // void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
constexpr int WARPS_N = 2; // constexpr int NUM_WARPS = WM * WN;
constexpr int NUM_WARPS = WARPS_M * WARPS_N; // constexpr int WARP_STEP_M = BM / WM;
constexpr int WARP_STEP_M = BM / WARPS_M; // constexpr int WARP_STEP_N = BN / WN;
constexpr int WARP_STEP_N = BN / WARPS_N; //
// // Precompute some offsets for each thread
// const int warpid = threadIdx.x / 32;
// const int laneid = threadIdx.x % 32;
// const int wm = warpid / WN;
// const int wn = warpid % WN;
// const int offset_m = wm * WARP_STEP_M;
// const int offset_n = wn * WARP_STEP_N;
//
// // Allocate shared memory
// extern __shared__ char shmem[];
// SharedTile<T, BM, BK>(&as)[PIPE] =
// *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
// SharedTile<T, BN, BK>(&bs)[PIPE] =
// *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
//
// // Move the global pointers to the tile
// a += blockIdx.y * BM * K;
// b += blockIdx.x * BN * K;
// y += blockIdx.y * BM * N + blockIdx.x * BN;
//
// // Make the loaders to/from SMEM
// SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>> sloader_a(a, K);
// SharedTileLoader<NUM_WARPS, SharedTile<T, BN, BK>> sloader_b(b, K);
// RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
// RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
//
// // Start the SM pipeline
// MLX_UNROLL
// for (int i = 0; i < PIPE - 1; i++) {
// sloader_a.load_async(as[i].base_addr());
// sloader_b.load_async(bs[i].base_addr());
// cp_async_commit();
// sloader_a.next();
// sloader_b.next();
// }
//
// // Allocate and zero the MMA accumulator
// RegisterTile<float, BM / WM, BN / WN> C;
// C.fill(0);
//
// // Matmul loop
// int num_blocks = K / BK;
// int sread = 0;
// int swrite = PIPE - 1;
// for (int i = 0; i < num_blocks; i++) {
// cp_async_wait<PIPE - 1>();
//
// gemm_ab_t<T, BM, BN, BK, WM, WN>(
// C, as[sread], bs[sread], rloader_a, rloader_b);
//
// sloader_a.load_async(as[swrite].base_addr());
// sloader_b.load_async(bs[swrite].base_addr());
// cp_async_commit();
// sloader_a.next(i + PIPE < num_blocks);
// sloader_b.next(i + PIPE < num_blocks);
//
// swrite = sread;
// sread = (sread + 1) % PIPE;
// }
//
// C.store_global(y, N, offset_m, offset_n);
// }
template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
__global__ __launch_bounds__(
WM* WN* WARP_SIZE,
1) void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
constexpr int NUM_WARPS = WM * WN;
constexpr int WARP_STEP_M = BM / WM;
constexpr int WARP_STEP_N = BN / WN;
// Precompute some offsets for each thread // Precompute some offsets for each thread
const int warpid = threadIdx.x / 32; const int warpid = threadIdx.x / 32;
const int laneid = threadIdx.x % 32; const int laneid = threadIdx.x % 32;
const int wm = warpid / WARPS_N; const int wm = warpid / WN;
const int wn = warpid % WARPS_N; const int wn = warpid % WN;
const int offset_m = wm * WARP_STEP_M; const int offset_m = wm * WARP_STEP_M;
const int offset_n = wn * WARP_STEP_N; const int offset_n = wn * WARP_STEP_N;
// Allocate shared memory // Allocate shared memory
extern __shared__ char shmem[]; extern __shared__ char shmem[];
SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]); SharedTile<T, BM, BK>(&as)[PIPE] =
SharedTile<T, BN, BK>(&bs)[2] = *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
*(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]); SharedTile<T, BN, BK>(&bs)[PIPE] =
*(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
// Allocate registers for the MMA
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
RegisterTile<T, BM / WARPS_M, 16> A;
RegisterTile<T, BN / WARPS_N, 16> B;
// Move the global pointers to the tile // Move the global pointers to the tile
a += blockIdx.y * BM * K; a += blockIdx.y * BM * K;
b += blockIdx.x * BN * K; b += blockIdx.x * BN * K;
y += blockIdx.y * BM * N + blockIdx.x * BN; y += blockIdx.y * BM * N + blockIdx.x * BN;
// Zero the accumulators // Make the loaders to/from SMEM
C.fill(0); using sloader = SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>>;
constexpr int SSTEP = sloader::STEP_ROWS * sizeof(T) * BK;
const int srow = threadIdx.x / sloader::NUM_LOADS_PER_ROW;
const int scol =
(threadIdx.x % sloader::NUM_LOADS_PER_ROW) * sloader::ELEMENTS_PER_LOAD;
a += srow * K + scol;
b += srow * K + scol;
uint32_t sm_offsets[PIPE][2];
MLX_UNROLL
for (int s = 0; s < PIPE; s++) {
sm_offsets[s][0] = as[s].loc(as[s].base_addr(), srow, scol);
sm_offsets[s][1] = bs[s].loc(bs[s].base_addr(), srow, scol);
}
RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
// Start the SM pipeline // Start the SM pipeline
load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K); MLX_UNROLL
load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K); for (int s = 0; s < PIPE - 1; s++) {
cp_async_commit();
int tic = 0;
for (int k_block = BK; k_block < K; k_block += BK) {
load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
cp_async_commit();
cp_async_wait<1>();
__syncthreads();
MLX_UNROLL MLX_UNROLL
for (int k = 0; k < BK / 16; k++) { for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
A.load( cp_async<16>(sm_offsets[s][0] + l * SSTEP, a);
as[tic], cp_async<16>(sm_offsets[s][1] + l * SSTEP, b);
as[tic].base_addr(), a += sloader::STEP_ROWS * K;
offset_m + laneid % 16, b += sloader::STEP_ROWS * K;
k * 16 + laneid / 16 * 8);
B.load(
bs[tic],
bs[tic].base_addr(),
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
} }
cp_async_commit();
tic ^= 1;
} }
// Empty the pipeline // Allocate and zero the MMA accumulator
cp_async_wait_all(); RegisterTile<float, BM / WM, BN / WN> C;
__syncthreads(); C.fill(0);
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
as[tic],
as[tic].base_addr(),
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
bs[tic],
bs[tic].base_addr(),
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B); // Matmul loop
int num_blocks = K / BK;
int sread = 0;
int swrite = PIPE - 1;
for (int i = 0; i < num_blocks; i++) {
cp_async_wait<PIPE - 1>();
gemm_ab_t<T, BM, BN, BK, WM, WN>(
C, as[sread], bs[sread], rloader_a, rloader_b);
if (false) {
MLX_UNROLL
for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
cp_async<16>(sm_offsets[swrite][0] + l * SSTEP, a);
cp_async<16>(sm_offsets[swrite][1] + l * SSTEP, b);
a += sloader::STEP_ROWS * K;
b += sloader::STEP_ROWS * K;
}
}
cp_async_commit();
swrite = sread;
sread = (sread + 1) % PIPE;
} }
C.store_global(y, N, offset_m, offset_n); C.store_global(y, N, offset_m, offset_n);

View File

@@ -223,59 +223,10 @@ struct RegisterTile {
} }
}; };
/**
* A simple container of multiple Tile16x16.
*
* Provides utility functions for loading and manipulating collections of basic
* tiles.
*/
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
Tile16x16<T> data[TILES_X * TILES_Y];
__device__ inline void fill(T v) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].fill(v);
}
}
}
template <typename Tile>
__device__ inline void
load(Tile& tile, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].load(
tile.loc(base_address, row + i * 16, col + j * 16));
}
}
}
template <typename U>
__device__ inline void store_global(U* x, int N, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global(
x + (row + i * 16) * N + col + j * 16, N);
}
}
}
};
template <typename T, int ROWS_, int COLS_> template <typename T, int ROWS_, int COLS_>
struct SharedTile { struct SharedTile {
using value_type = T;
static constexpr int ROWS = ROWS_; static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_; static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16; static constexpr int TILES_X = COLS / 16;
@@ -317,23 +268,26 @@ struct SharedTile {
} }
} }
// Return the location of the element at (row, col) using the swizzle. __device__ static inline uint32_t offset(int row, int col) {
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
if constexpr (swizzle_bytes > 0) { if constexpr (swizzle_bytes > 0) {
static constexpr int swizzle_repeat = swizzle_bytes * 8; static constexpr int swizzle_repeat = swizzle_bytes * 8;
static constexpr int subtile_cols = swizzle_bytes / sizeof(T); static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
const int outer_idx = col / subtile_cols; const int outer_idx = col / subtile_cols;
const uint32_t addr = ptr + const uint32_t addr = sizeof(T) *
sizeof(T) * (outer_idx * ROWS * subtile_cols + row * subtile_cols +
(outer_idx * ROWS * subtile_cols + row * subtile_cols + col % subtile_cols);
col % subtile_cols);
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
return (addr ^ swizzle); return (addr ^ swizzle);
} else { } else {
return ptr + sizeof(T) * (row * COLS + col); return sizeof(T) * (row * COLS + col);
} }
} }
// Return the location of the element at (row, col) using the swizzle.
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
return ptr + offset(row, col);
}
// Convenience functions to edit elements going through the swizzle. // Convenience functions to edit elements going through the swizzle.
__device__ inline T& operator()(int row, int col) { __device__ inline T& operator()(int row, int col) {
return *ptr(data, row, col); return *ptr(data, row, col);
@@ -364,6 +318,76 @@ struct SharedTile {
} }
}; };
template <int NUM_WARPS, typename Tile>
struct SharedTileLoader {
using T = typename Tile::value_type;
static constexpr int NUM_THREADS = NUM_WARPS * 32;
static constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
static constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
static constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
static constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
static constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const T* x_;
int N_;
uint32_t offset_;
__device__ SharedTileLoader(const T* x, int N) : x_(x), N_(N) {
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x_ += row * N + col * ELEMENTS_PER_LOAD;
offset_ = Tile::offset(row, col * ELEMENTS_PER_LOAD);
}
__device__ inline void load_async(uint32_t base_address) {
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
cp_async<16>(
base_address + offset_ + i * STEP_ROWS * sizeof(T) * Tile::COLS,
x_ + i * STEP_ROWS * N_);
}
}
__device__ inline void next() {
x_ += Tile::COLS;
}
};
template <typename Tile>
struct RegisterTileLoader {
using T = typename Tile::value_type;
uint32_t offset_[Tile::COLS / 16];
__device__ RegisterTileLoader(int offset_row, int laneid) {
const int row = offset_row + laneid & 15;
const int col = (laneid >> 4) << 3;
MLX_UNROLL
for (int i = 0; i < Tile::COLS / 16; i++) {
offset_[i] = Tile::offset(row, col + i * 16);
}
}
template <typename T, int ROWS, int COLS>
__device__ inline void
load(RegisterTile<T, ROWS, COLS>& x, uint32_t base_address, int col) {
constexpr int TILES_Y = RegisterTile<T, ROWS, COLS>::TILES_Y;
constexpr int TILES_X = RegisterTile<T, ROWS, COLS>::TILES_X;
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
x.data[i * TILES_X + j].load(
base_address + offset_[j + col] + i * 16 * Tile::COLS * sizeof(T));
}
}
}
};
/** /**
* Load the tile from global memory by loading 16 bytes at a time and storing * Load the tile from global memory by loading 16 bytes at a time and storing
* them immediately. * them immediately.

View File

@@ -21,15 +21,15 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
#if defined(MLX_CUDA_SM_80_ENABLED) #if defined(MLX_CUDA_SM_80_ENABLED)
if constexpr (N == 16) { if constexpr (N == 16) {
asm volatile( asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address), "cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int4*>(x))); "l"(reinterpret_cast<const int4*>(x)));
} else if constexpr (N == 8) { } else if constexpr (N == 8) {
asm volatile( asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address), "cp.async.cg.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int2*>(x))); "l"(reinterpret_cast<const int2*>(x)));
} else if constexpr (N == 4) { } else if constexpr (N == 4) {
asm volatile( asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address), "cp.async.cg.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int*>(x))); "l"(reinterpret_cast<const int*>(x)));
} }
#endif #endif

View File

@@ -223,11 +223,6 @@ struct Power {
template <typename T> template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) { metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1; T res = 1;
// Undefined to raise integer to negative power
if (exp < 0) {
return 0;
}
while (exp) { while (exp) {
if (exp & 1) { if (exp & 1) {
res *= base; res *= base;

View File

@@ -6,4 +6,3 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)

View File

@@ -2,21 +2,15 @@
#include <unordered_map> #include <unordered_map>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h" #include "mlx/distributed/ring/ring.h"
namespace mlx::core::distributed { namespace mlx::core::distributed {
namespace detail { namespace detail {
Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
return group.raw_group()->communication_stream(s);
}
void all_sum(Group group, const array& input, array& output, Stream stream) { void all_sum(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_sum(input, output, stream); group.raw_group()->all_sum(input, output, stream);
} }
@@ -43,10 +37,6 @@ void recv(Group group, array& out, int src, Stream stream) {
class EmptyGroup : public GroupImpl { class EmptyGroup : public GroupImpl {
public: public:
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s);
}
int rank() override { int rank() override {
return 0; return 0;
} }
@@ -90,7 +80,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail } // namespace detail
bool is_available() { bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available(); return mpi::is_available() || ring::is_available();
} }
int Group::rank() const { int Group::rank() const {
@@ -115,23 +105,15 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
} }
// Create the requested communication group // Create the requested communication group
std::shared_ptr<detail::GroupImpl> group{nullptr}; std::shared_ptr<detail::GroupImpl> group;
std::string bk_ = bk; std::string bk_ = bk;
if (bk == "mpi") { if (bk == "mpi") {
group = mpi::init(strict); group = mpi::init(strict);
} else if (bk == "ring") { } else if (bk == "ring") {
group = ring::init(strict); group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "any") { } else if (bk == "any") {
if (mlx::core::cu::is_available()) { group = ring::init(false);
group = nccl::init(false); bk_ = "ring";
bk_ = "nccl";
}
if (group == nullptr) {
group = ring::init(false);
bk_ = "ring";
}
if (group == nullptr) { if (group == nullptr) {
group = mpi::init(false); group = mpi::init(false);
bk_ = "mpi"; bk_ = "mpi";

View File

@@ -5,7 +5,6 @@
#include <memory> #include <memory>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/utils.h"
namespace mlx::core::distributed { namespace mlx::core::distributed {

View File

@@ -13,15 +13,10 @@ class GroupImpl {
public: public:
virtual ~GroupImpl() {} virtual ~GroupImpl() {}
// Choose the stream this communication group can operate on
virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
// Group operations
virtual int rank() = 0; virtual int rank() = 0;
virtual int size() = 0; virtual int size() = 0;
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0; virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
// Actual communication operations
virtual void all_sum(const array& input, array& output, Stream stream) = 0; virtual void all_sum(const array& input, array& output, Stream stream) = 0;
virtual void all_gather(const array& input, array& output, Stream stream) = 0; virtual void all_gather(const array& input, array& output, Stream stream) = 0;
virtual void send(const array& input, int dst, Stream stream) = 0; virtual void send(const array& input, int dst, Stream stream) = 0;
@@ -30,9 +25,6 @@ class GroupImpl {
virtual void all_min(const array& input, array& output, Stream stream) = 0; virtual void all_min(const array& input, array& output, Stream stream) = 0;
}; };
/* Define the MLX stream that the communication should happen in. */
Stream communication_stream(Group group, StreamOrDevice s = {});
/* Perform an all reduce sum operation */ /* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output, Stream stream); void all_sum(Group group, const array& input, array& output, Stream stream);

View File

@@ -349,10 +349,6 @@ class MPIGroup : public GroupImpl {
} }
} }
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::cpu);
}
int rank() override { int rank() override {
if (rank_ < 0) { if (rank_ < 0) {
mpi().rank(comm_, &rank_); mpi().rank(comm_, &rank_);

View File

@@ -1,8 +0,0 @@
if(MLX_BUILD_CUDA)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
find_package(NCCL REQUIRED)
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
endif()

View File

@@ -1,372 +0,0 @@
#include <arpa/inet.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <type_traits>
#include "mlx/backend/cuda/device.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"
namespace mlx::core::distributed::nccl {
#define CHECK_CUDA(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
fprintf( \
stderr, \
"CUDA error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString(e)); \
exit(1); \
} \
} while (0)
#define CHECK_NCCL(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
fprintf( \
stderr, \
"NCCL error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
ncclGetErrorString(r)); \
exit(1); \
} \
} while (0)
#define MLX_NCCL_TYPE_LIST(X) \
X(int8_t, ncclChar) \
X(uint8_t, ncclUint8) \
X(int32_t, ncclInt) \
X(uint32_t, ncclUint32) \
X(int64_t, ncclInt64) \
X(uint64_t, ncclUint64) \
X(float16_t, ncclHalf) \
X(bfloat16_t, ncclBfloat16) \
X(float, ncclFloat) \
X(double, ncclDouble)
template <class>
struct nccl_map {
static constexpr bool ok = false; // default: unsupported
};
#define MLX_DEF_NCCL_MAP(T, E) \
template <> \
struct nccl_map<T> { \
static constexpr bool ok = true; \
static constexpr ncclDataType_t value = E; \
};
MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)
#undef MLX_DEF_NCCL_MAP
namespace detail {
template <typename F>
void dispatch_dtype(const array& arr, F&& f) {
dispatch_all_types(arr.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
if constexpr (nccl_map<T>::ok) {
f(type_tag, nccl_map<T>::value);
} else {
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
}
});
}
inline void sendAll(int sock, const void* buf, size_t len) {
const char* ptr = reinterpret_cast<const char*>(buf);
while (len > 0) {
ssize_t sent = send(sock, ptr, len, 0);
if (sent <= 0) {
perror("send");
exit(1);
}
ptr += sent;
len -= sent;
}
}
inline void recvAll(int sock, void* buf, size_t len) {
char* ptr = reinterpret_cast<char*>(buf);
while (len > 0) {
ssize_t rec = recv(sock, ptr, len, 0);
if (rec <= 0) {
perror("recv");
exit(1);
}
ptr += rec;
len -= rec;
}
}
inline void bootstrap_unique_id(
ncclUniqueId& id,
int rank,
int size,
const std::string& initMethod) {
// Parse the init method to extract the host and port
if (initMethod.rfind("tcp://", 0) != 0)
throw;
auto hostport = initMethod.substr(6);
auto colon = hostport.find(':');
std::string host = hostport.substr(0, colon);
int port = std::stoi(hostport.substr(colon + 1));
if (rank == 0) {
// create a unique id on the rank 0
CHECK_NCCL(ncclGetUniqueId(&id));
// create a socket to send the unique id to all other ranks
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockaddr_in serv = {};
serv.sin_family = AF_INET;
serv.sin_addr.s_addr = htonl(INADDR_ANY);
serv.sin_port = htons(port);
int reuse = 1;
// Without this, if rank-0 crashes or restarts process quickly,
// the OS might refuse to let binding to the same port, so reuse
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
std::ostringstream msg;
msg << "[nccl] setsockopt() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
std::ostringstream msg;
msg << "[nccl] bind() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
if (listen(sock, size - 1) < 0) {
std::ostringstream msg;
msg << "[nccl] listen() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
for (int peer = 1; peer < size; ++peer) {
int conn = accept(sock, nullptr, nullptr);
if (conn < 0) {
std::ostringstream msg;
msg << "[nccl] accept() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
sendAll(conn, &id, sizeof(id));
close(conn);
}
close(sock);
} else {
// Here just wanted to make show that rank 0 has enough time to bind
// so we will retry to connect until max attempts
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[nccl] socket() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
hostent* he = gethostbyname(host.c_str());
if (!he) {
throw std::runtime_error("[nccl] lookup failed for host: " + host);
}
sockaddr_in serv = {};
serv.sin_family = AF_INET;
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
serv.sin_port = htons(port);
const int max_retries = 30;
int attempt = 0;
bool connected = false;
bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
for (attempt = 0; attempt < max_retries; ++attempt) {
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
0) {
connected = true;
if (do_log) {
std::cout << "[Rank " << rank
<< "] Connected successfully on attempt " << attempt + 1
<< std::endl;
break;
}
}
if (errno != ECONNREFUSED) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
if (!connected) {
std::ostringstream msg;
msg << "[Rank " << rank << "] connect() failed after " << attempt
<< " retries: " << strerror(errno);
close(sock);
throw std::runtime_error(msg.str());
}
recvAll(sock, &id, sizeof(id));
close(sock);
}
}
} // namespace detail
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
class NCCLGroup : public GroupImpl {
public:
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
: rank_(worldRank),
size_(worldSize),
comm_(nullptr),
initMethod_(initMethod) {
if (initialized_)
return;
int ndev;
CHECK_CUDA(cudaGetDeviceCount(&ndev));
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
initialized_ = true;
}
~NCCLGroup() {
ncclCommDestroy(comm_);
ncclGroupEnd();
initialized_ = false;
}
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::gpu);
}
int rank() override {
return rank_;
}
int size() override {
return size_;
}
void all_sum(const array& input, array& output, Stream stream) override {
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
});
}
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[nccl] Group split not supported.");
}
void all_gather(const array& input, array& output, Stream stream) override {
throw std::runtime_error(
"[nccl] All gather not supported in NCCL backend.");
}
void send(const array& input, int dst, Stream stream) override {
throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
}
void recv(array& output, int src, Stream stream) override {
throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
}
void all_max(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
}
void all_min(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
}
template <typename T>
void all_reduce_impl(
const array& input,
array& output,
Stream stream,
ncclDataType_t dt,
ncclRedOp_t op) {
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclAllReduce(
input.data<T>(),
output.data<T>(),
input.size(),
dt,
op,
comm_,
encoder.stream()));
}
int rank_, size_;
std::string initMethod_;
ncclUniqueId uniqueId_;
ncclComm_t comm_;
bool initialized_ = false;
};
bool is_available() {
return true;
}
namespace detail {
std::string get_env_var_or_throw(const char* env_var_name, bool strict) {
const char* value = std::getenv(env_var_name);
if (value == nullptr && strict) {
std::ostringstream msg;
msg << "[nccl] Required environment variable '" << env_var_name
<< "' is not set. "
<< "Please set it before initializing the distributed backend.";
throw std::runtime_error(msg.str());
}
if (value == nullptr) {
return "";
}
return std::string(value);
}
} // namespace detail
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP", strict);
std::string port = detail::get_env_var_or_throw("NCCL_PORT", strict);
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK", strict);
std::string n_nodes_str =
detail::get_env_var_or_throw("MLX_WORLD_SIZE", strict);
if (!strict &&
(host.empty() || port.empty() || rank_str.empty() ||
n_nodes_str.empty())) {
return nullptr;
}
int rank = std::stoi(rank_str);
int n_nodes = std::stoi(n_nodes_str);
std::string init_method = "tcp://" + host + ":" + port;
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
}
} // namespace mlx::core::distributed::nccl

View File

@@ -1,12 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::nccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::nccl

View File

@@ -1,20 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/nccl/nccl.h"
namespace mlx::core::distributed::nccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available() {
return false;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) {
throw std::runtime_error("Cannot initialize nccl distributed backend.");
}
return nullptr;
}
} // namespace mlx::core::distributed::nccl

View File

@@ -2,9 +2,6 @@
#include <sstream> #include <sstream>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/ops.h" #include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h" #include "mlx/distributed/primitives.h"
@@ -31,12 +28,11 @@ array all_sum(
if (group.size() == 1) { if (group.size() == 1) {
return x; return x;
} }
auto stream = detail::communication_stream(group, s);
return array( return array(
x.shape(), x.shape(),
x.dtype(), x.dtype(),
std::make_shared<AllReduce>(stream, group, AllReduce::Sum), std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Sum),
{x}); {x});
} }
@@ -49,12 +45,11 @@ array all_max(
if (group.size() == 1) { if (group.size() == 1) {
return x; return x;
} }
auto stream = detail::communication_stream(group, s);
return array( return array(
x.shape(), x.shape(),
x.dtype(), x.dtype(),
std::make_shared<AllReduce>(stream, group, AllReduce::Max), std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Max),
{x}); {x});
} }
@@ -67,12 +62,11 @@ array all_min(
if (group.size() == 1) { if (group.size() == 1) {
return x; return x;
} }
auto stream = detail::communication_stream(group, s);
return array( return array(
x.shape(), x.shape(),
x.dtype(), x.dtype(),
std::make_shared<AllReduce>(stream, group, AllReduce::Min), std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Min),
{x}); {x});
} }
@@ -85,7 +79,6 @@ array all_gather(
if (group.size() == 1) { if (group.size() == 1) {
return x; return x;
} }
auto stream = detail::communication_stream(group, s);
auto result_shape = x.shape(); auto result_shape = x.shape();
if (result_shape.size() == 0) { if (result_shape.size() == 0) {
@@ -96,7 +89,7 @@ array all_gather(
return array( return array(
std::move(result_shape), std::move(result_shape),
x.dtype(), x.dtype(),
std::make_shared<AllGather>(stream, group), std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
{x}); {x});
} }
@@ -110,7 +103,6 @@ array send(
if (group.size() == 1) { if (group.size() == 1) {
throw std::invalid_argument("Cannot send to a singleton group"); throw std::invalid_argument("Cannot send to a singleton group");
} }
auto stream = detail::communication_stream(group, s);
if (dst < 0 || dst >= group.size()) { if (dst < 0 || dst >= group.size()) {
std::ostringstream msg; std::ostringstream msg;
@@ -120,7 +112,10 @@ array send(
} }
return array( return array(
x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x}); x.shape(),
x.dtype(),
std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
{x});
} }
array recv( array recv(
@@ -134,7 +129,6 @@ array recv(
if (group.size() == 1) { if (group.size() == 1) {
throw std::invalid_argument("Cannot recv from a singleton group"); throw std::invalid_argument("Cannot recv from a singleton group");
} }
auto stream = detail::communication_stream(group, s);
if (src < 0 || src >= group.size()) { if (src < 0 || src >= group.size()) {
std::ostringstream msg; std::ostringstream msg;
@@ -145,7 +139,7 @@ array recv(
return array( return array(
std::move(shape), std::move(shape),
std::move(dtype), std::move(dtype),
std::make_shared<Recv>(stream, group, src), std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
std::vector<array>{}); std::vector<array>{});
} }

View File

@@ -619,10 +619,6 @@ class RingGroup : public GroupImpl {
} }
} }
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::cpu);
}
int rank() override { int rank() override {
return rank_; return rank_;
} }

View File

@@ -20,8 +20,6 @@ from select import select
from subprocess import PIPE, Popen, run from subprocess import PIPE, Popen, run
from typing import Optional from typing import Optional
import mlx.core as mx
@dataclass @dataclass
class Host: class Host:
@@ -55,11 +53,6 @@ def parse_hardware_ports(ports_string):
return ports return ports
def get_num_nvidia_gpus():
result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
return len(result.stdout.strip().split("\n"))
def extract_rings(hosts, index): def extract_rings(hosts, index):
def usable_port(i, j, used_ports): def usable_port(i, j, used_ports):
return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
@@ -422,57 +415,6 @@ def launch_mpi(parser, hosts, args, command):
pass pass
def launch_nccl(parser, hosts, args, command):
master_host = hosts[0].ips[0]
if master_host != "127.0.0.1":
raise ValueError("The NCCL backend only supports localhost for now.")
master_port = args.nccl_port
world_size = len(hosts)
base_env = os.environ.copy()
base_env.update(
{
"NCCL_DEBUG": base_env.get(
"NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
),
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
"NCCL_HOST_IP": master_host,
"NCCL_PORT": str(master_port),
"MLX_WORLD_SIZE": str(world_size),
}
)
procs = []
num_gpus = get_num_nvidia_gpus()
if num_gpus == 0:
raise RuntimeError("Cannot run NCCL backend with no GPUs.")
if args.repeat_hosts > num_gpus:
raise RuntimeError("NCCL requires a separate GPU per process.")
try:
for rank in range(world_size):
env = base_env.copy()
mlx_rank = str(rank % args.repeat_hosts)
env["MLX_RANK"] = mlx_rank
env["CUDA_VISIBLE_DEVICES"] = mlx_rank
p = Popen(command, env=env)
procs.append(p)
for p in procs:
ret = p.wait()
if ret != 0:
raise RuntimeError(f"Rank process exited with {ret}")
except (RuntimeError, KeyboardInterrupt) as err:
for p in procs:
if p.poll() is None:
try:
p.kill()
except Exception:
pass
raise
def check_ssh_connections(hosts): def check_ssh_connections(hosts):
results = [False] * len(hosts) results = [False] * len(hosts)
@@ -723,8 +665,8 @@ def distributed_config():
) )
parser.add_argument( parser.add_argument(
"--backend", "--backend",
choices=["ring", "mpi", "nccl"], choices=["ring", "mpi"],
default="nccl" if mx.cuda.is_available() else "ring", default="ring",
help="Which distributed backend to configure", help="Which distributed backend to configure",
) )
parser.add_argument( parser.add_argument(
@@ -795,8 +737,8 @@ def main():
parser.add_argument("--hostfile", help="The file containing the hosts") parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument( parser.add_argument(
"--backend", "--backend",
choices=["ring", "mpi", "nccl"], choices=["ring", "mpi"],
default="nccl" if mx.cuda.is_available() else "ring", default="ring",
help="Which distributed backend to launch", help="Which distributed backend to launch",
) )
parser.add_argument( parser.add_argument(
@@ -827,14 +769,9 @@ def main():
parser.add_argument( parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one" "--cwd", help="Set the working directory on each node to the provided one"
) )
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
args, rest = parser.parse_known_args() args, rest = parser.parse_known_args()
if rest[0] == "--":
rest.pop(0)
if args.print_python: if args.print_python:
print(sys.executable) print(sys.executable)
@@ -862,10 +799,8 @@ def main():
# Launch # Launch
if args.backend == "ring": if args.backend == "ring":
launch_ring(parser, hosts, args, rest) launch_ring(parser, hosts, args, rest)
if args.backend == "mpi": elif args.backend == "mpi":
launch_mpi(parser, hosts, args, rest) launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -76,7 +76,6 @@ def average_gradients(
group: Optional[mx.distributed.Group] = None, group: Optional[mx.distributed.Group] = None,
all_reduce_size: int = 32 * 1024**2, all_reduce_size: int = 32 * 1024**2,
communication_type: Optional[mx.Dtype] = None, communication_type: Optional[mx.Dtype] = None,
stream: mx.Stream = mx.cpu,
): ):
"""Average the gradients across the distributed processes in the passed group. """Average the gradients across the distributed processes in the passed group.
@@ -95,7 +94,6 @@ def average_gradients(
communication_type (Optional[mlx.core.Dtype]): If provided cast to this communication_type (Optional[mlx.core.Dtype]): If provided cast to this
type before performing the communication. Typically cast to a type before performing the communication. Typically cast to a
smaller float to reduce the communication size. Default: ``None``. smaller float to reduce the communication size. Default: ``None``.
stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
""" """
group = group or mx.distributed.init() group = group or mx.distributed.init()
N = group.size() N = group.size()
@@ -106,7 +104,7 @@ def average_gradients(
def _average(x): def _average(x):
dt = x.dtype dt = x.dtype
x = x.astype(communication_type) if communication_type is not None else x x = x.astype(communication_type) if communication_type is not None else x
return mx.distributed.all_sum(x, stream=stream).astype(dt) / N return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
if all_reduce_size <= 0: if all_reduce_size <= 0:
return tree_map(_average, gradients) return tree_map(_average, gradients)

View File

@@ -6,7 +6,6 @@ auditwheel repair dist/* \
--exclude libnvrtc* \ --exclude libnvrtc* \
--exclude libcuda* \ --exclude libcuda* \
--exclude libcudnn* \ --exclude libcudnn* \
--exclude libnccl* \
-w wheel_tmp -w wheel_tmp
@@ -18,7 +17,7 @@ rm "${repaired_wheel}"
mlx_so="mlx/lib/libmlx.so" mlx_so="mlx/lib/libmlx.so"
rpath=$(patchelf --print-rpath "${mlx_so}") rpath=$(patchelf --print-rpath "${mlx_so}")
base="\$ORIGIN/../../nvidia" base="\$ORIGIN/../../nvidia"
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib:${base}/nccl/lib rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so" patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
python ../python/scripts/repair_record.py ${mlx_so} python ../python/scripts/repair_record.py ${mlx_so}

View File

@@ -79,7 +79,7 @@ void init_distributed(nb::module_& parent_module) {
in case ``mx.distributed.is_available()`` returns False otherwise in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False`` it throws a runtime error. Default: ``False``
backend (str, optional): Which distributed backend to initialize. backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all
available backends are tried and the first one that succeeds available backends are tried and the first one that succeeds
becomes the global group which will be returned in subsequent becomes the global group which will be returned in subsequent
calls. Default: ``any`` calls. Default: ``any``

View File

@@ -1,284 +0,0 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients
class TestNCCLDistributed(mlx_tests.MLXTestCase):
@classmethod
def setUpClass(cls):
world = mx.distributed.init(strict=True, backend="nccl")
rank = world.rank()
mx.set_default_device(mx.Device(mx.gpu, rank % 8))
def test_all_reduce(self):
world = mx.distributed.init()
dtypes = [
(mx.int8, 0),
(mx.uint8, 0),
(mx.int32, 0),
(mx.uint32, 0),
(mx.float32, 1e-6),
(mx.float16, 5e-3),
(mx.bfloat16, 1e-1),
]
sizes = [
(7,),
(10,),
(1024,),
(1024, 1024),
]
key = mx.random.key(0)
for dt, rtol in dtypes:
for sh in sizes:
x = (
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
).astype(dt)
# All sum
y = mx.distributed.all_sum(x[world.rank()])
z = x.sum(0)
maxrelerror = (y - z).abs()
if rtol > 0:
maxrelerror /= z.abs()
maxrelerror = maxrelerror.max()
self.assertLessEqual(maxrelerror, rtol)
def test_average_gradients(self):
original_all_sum = mx.distributed.all_sum
n_calls = 0
xtype = None
def new_all_sum(x, **kwargs):
nonlocal n_calls
nonlocal xtype
n_calls += 1
if xtype is not None:
self.assertEqual(xtype, x.dtype)
return original_all_sum(x, **kwargs)
mx.distributed.all_sum = new_all_sum
try:
grads = [mx.ones(10) for i in range(10)]
new_grads = average_gradients(grads, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 1)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 10)
n_calls = 0
xtype = mx.float16
new_grads = average_gradients(
grads,
all_reduce_size=2 * 50,
communication_type=mx.float16,
stream=mx.gpu,
)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
finally:
mx.distributed.all_sum = original_all_sum
def test_donation(self):
x = mx.random.normal((1024,))
mx.eval(x)
mx.synchronize()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize()
all_sum_only = mx.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize()
all_sum_with_binary = mx.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)
def test_shard_linear(self):
# Seed the prng to have the same inputs and weights generated everywhere
mx.random.seed(0xF0F0F0F0)
# Prepare inputs
world = mx.distributed.init()
part = (
slice(None),
slice(
world.rank() * 1024 // world.size(),
(world.rank() + 1) * 1024 // world.size(),
),
)
x = mx.random.normal((4, 1024))
# Create and shard some linear layers
lin = nn.Linear(1024, 1024, bias=True)
slin1 = shard_linear(lin, "all-to-sharded")
slin2 = shard_linear(lin, "sharded-to-all")
y = lin(x)
y1 = slin1(x)
y2 = slin2(x[part])
self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
# Check the backward works as expected
def dummy_loss(model, x, y):
return (model(x) * y).sum()
mod = nn.Sequential(
nn.Linear(128, 128),
nn.Linear(128, 128),
nn.Linear(128, 128),
nn.Linear(128, 128),
)
smod = nn.Sequential(
shard_linear(mod.layers[0], "all-to-sharded"),
shard_linear(mod.layers[1], "sharded-to-all"),
shard_linear(mod.layers[2], "all-to-sharded"),
shard_linear(mod.layers[3], "sharded-to-all"),
)
grad1 = nn.value_and_grad(mod, dummy_loss)
grad2 = nn.value_and_grad(smod, dummy_loss)
x = mx.random.normal((4, 128))
y = mx.random.normal((4, 128))
l1, g1 = grad1(mod, x, y)
l2, g2 = grad2(smod, x, y)
mx.eval(l1, g1, l2, g2)
part = slice(
world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
)
self.assertTrue(mx.allclose(l1, l2))
self.assertTrue(
mx.allclose(
g1["layers"][0]["weight"][part],
g2["layers"][0]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][2]["weight"][part],
g2["layers"][2]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["weight"][:, part],
g2["layers"][1]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["weight"][:, part],
g2["layers"][3]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][0]["bias"][part],
g2["layers"][0]["bias"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][2]["bias"][part],
g2["layers"][2]["bias"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
)
)
def test_shard_predicate(self):
mx.random.seed(0xF0F0F0F0)
class MyConv(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.aggregate = kwargs.pop("aggregate", False)
self.conv = nn.Conv2d(*args, **kwargs)
def __call__(self, x):
x = self.conv(x)
if self.aggregate:
x = mx.distributed.all_sum(x)
return x
def sharding(path, weight):
parts = path.split(".")
even = int(parts[1]) % 2 == 0
if even:
return 0
else:
return -1 if parts[-1] != "bias" else None
mod = nn.Sequential(
MyConv(3, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3),
MyConv(128, 3, kernel_size=3),
)
smod = nn.Sequential(
MyConv(3, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3, aggregate=True),
MyConv(128, 128, kernel_size=3),
MyConv(128, 3, kernel_size=3, aggregate=True),
)
smod.update(mod.parameters())
shard_inplace(smod, sharding)
x = mx.random.normal((4, 16, 16, 3))
y1 = mod(x)
y2 = smod(x)
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()

View File

@@ -3068,13 +3068,6 @@ class TestOps(mlx_tests.MLXTestCase):
d = mx.where(c, a[1:], b) d = mx.where(c, a[1:], b)
self.assertTrue(mx.all(d == 1.0)) self.assertTrue(mx.all(d == 1.0))
def test_integer_power(self):
x = mx.power(2, mx.array([8, 8, 8, 8, 8, 8, 8, 8]))
self.assertTrue(mx.all(x == 256))
# Doesn't hang
x = mx.power(2, -1)
class TestBroadcast(mlx_tests.MLXTestCase): class TestBroadcast(mlx_tests.MLXTestCase):
def test_broadcast_shapes(self): def test_broadcast_shapes(self):

View File

@@ -297,7 +297,6 @@ if __name__ == "__main__":
"nvidia-cublas-cu12==12.9.*", "nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*", "nvidia-cuda-nvrtc-cu12==12.9.*",
"nvidia-cudnn-cu12==9.*", "nvidia-cudnn-cu12==9.*",
"nvidia-nccl-cu12",
] ]
else: else:
name = "mlx-cpu" name = "mlx-cpu"