Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Comparing commits: 30561229c7...simple-gem (8 commits)
Commits in this comparison (author and date columns were empty in the source view):

- 4987e7615a
- e1303f6160
- cf5eef095d
- 395d582719
- 05583bcd10
- 6fce01593a
- 97afe40b7b
- f70c62d69c
@@ -222,7 +222,6 @@ jobs:
 sudo apt-get update
 sudo apt-get install libcudnn9-dev-cuda-12
 sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
-sudo apt-get install libnccl2 libnccl-dev
 curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
 sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
 rm -rf ccache-4.11.3-linux-x86_64
@@ -405,7 +404,6 @@ jobs:
 sudo dpkg -i cuda-keyring_1.1-1_all.deb
 sudo apt-get update
 sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
-sudo apt-get install libnccl2 libnccl-dev
 sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
 sudo apt-get install zip
 pip install auditwheel
@@ -25,11 +25,6 @@ MLX was developed with contributions from the following individuals:
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 </a>
 
-# Organizations
-
-MLX has received contributions from the following companies:
-- NVIDIA Corporation & Affiliates
-
 # Third-Party Software
 
 MLX leverages several third-party software, listed here together with
@@ -1,54 +0,0 @@ (FindNCCL.cmake, file deleted)
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.

set(NCCL_ROOT_DIR
    $ENV{NCCL_ROOT_DIR}
    CACHE PATH "Folder contains NVIDIA NCCL")

find_path(
  NCCL_INCLUDE_DIRS
  NAMES nccl.h
  HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
        ${CUDA_TOOLKIT_ROOT_DIR}/include)

if($ENV{USE_STATIC_NCCL})
  message(
    STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
  set(NCCL_LIBNAME "libnccl_static.a")
else()
  set(NCCL_LIBNAME "nccl")
endif()

find_library(
  NCCL_LIBRARIES
  NAMES ${NCCL_LIBNAME}
  HINTS ${NCCL_LIB_DIR}
        ${NCCL_ROOT_DIR}
        ${NCCL_ROOT_DIR}/lib
        ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
        ${NCCL_ROOT_DIR}/lib64
        ${CUDA_TOOLKIT_ROOT_DIR}/lib
        ${CUDA_TOOLKIT_ROOT_DIR}/lib64)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
                                  NCCL_LIBRARIES)

if(NCCL_FOUND)
  set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
  message(
    STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
  file(
    STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
    REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
    LIMIT_COUNT 1)
  if(NCCL_MAJOR_VERSION_DEFINED)
    string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
                         NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
    message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
  endif()
  message(
    STATUS
      "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()
@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
 dpkg -i cuda-keyring_1.1-1_all.deb
 apt-get update -y
 apt-get -y install cuda-toolkit-12-9
-apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
+apt-get install libblas-dev liblapack-dev liblapacke-dev -y
 
 When building either the Python or C++ APIs make sure to pass the cmake flag
@@ -234,7 +234,6 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
 
 template <typename MaskT, typename T1, typename T2, int N>
 Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
-  static_assert(std::is_same_v<MaskT, bool>);
   if constexpr (sizeof(T1) == 1) {
     return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
   } else if constexpr (sizeof(T1) == 2) {
@@ -252,13 +251,9 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
     return asd::pow(base.value, exp.value);
   } else {
     Simd<T, N> res = 1;
-    // Raising an integer to a negative power is undefined
-    if (any(exp < 0)) {
-      return 0;
-    }
-    while (any(exp > 0)) {
-      res = select((exp & 1) != 0, res * base, res);
-      base = select(exp > 0, base * base, base);
+    while (any(exp)) {
+      res = select(exp & 1, res * base, res);
+      base = select(exp, base * base, base);
       exp = exp >> 1;
     }
     return res;
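The pow/Power hunks in this comparison (here, and in the CUDA and Metal operator structs below) make the same change: the guard that returned 0 for negative integer exponents is dropped, leaving the bare square-and-multiply loop. As a scalar reference, the loop the SIMD version vectorizes looks like this (a standalone sketch for non-negative exponents, not MLX code):

#include <cstdio>

// Exponentiation by squaring: O(log exp) multiplies.
int ipow(int base, int exp) {
  int res = 1;
  while (exp) {
    if (exp & 1) {   // multiply in the current power of base
      res *= base;
    }
    base *= base;    // square for the next bit
    exp >>= 1;
  }
  return res;
}

int main() {
  std::printf("%d\n", ipow(3, 5)); // 243
}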
@@ -22,11 +22,12 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cutlass_gemm.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -89,6 +90,9 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
 target_compile_options(mlx
                        PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
 
+# Keep ptx around for inspection
+target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>")
+
 # Enable calling host constexpr functions from device. This is needed because
 # the constexpr version of isnan is host only.
 target_compile_options(
@@ -174,3 +178,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
 # Install CCCL headers for JIT.
 install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
         DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
+
+# Fetch and make available cutlass
+FetchContent_Declare(
+  cutlass
+  GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
+  GIT_TAG v4.1.0)
+FetchContent_Populate(cutlass)
+target_include_directories(
+  mlx PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)
@@ -30,15 +30,8 @@ SmallSizePool::SmallSizePool() {
   next_free_ = buffer_;
 
   CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
-#if CUDART_VERSION >= 13000
-  cudaMemLocation loc;
-  loc.type = cudaMemLocationTypeDevice;
-  loc.id = 0;
-#else
-  int loc = 0;
-#endif // CUDART_VERSION >= 13000
   CHECK_CUDA_ERROR(
-      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc));
+      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
 
   auto curr = next_free_;
   for (size_t i = 1; i < num_blocks; ++i) {
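For context on the hunk above, a summary of the CUDA runtime API involved (an assumption to verify against the CUDA docs for your toolkit, not MLX code):

// Through CUDA 12 the last parameter of cudaMemAdvise is a plain device
// ordinal:
//   cudaError_t cudaMemAdvise(const void* ptr, size_t count,
//                             cudaMemoryAdvise advice, int device);
// CUDA 13 changed that parameter to a cudaMemLocation struct, which is what
// the deleted CUDART_VERSION >= 13000 branch was constructing. Passing the
// literal 0 therefore assumes the pre-13 signature (advise for device 0).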
@@ -269,13 +269,7 @@ void CommandEncoder::commit() {
   if (node_count_ > 0) {
     if (!from_nodes_.empty()) {
       CHECK_CUDA_ERROR(cudaGraphAddDependencies(
-          graph_,
-          from_nodes_.data(),
-          to_nodes_.data(),
-#if CUDART_VERSION >= 13000
-          nullptr, // edgeData
-#endif // CUDART_VERSION >= 13000
-          from_nodes_.size()));
+          graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
     }
 
     graph_key_ += ".";
@@ -204,12 +204,6 @@ struct Power {
   __device__ T operator()(T base, T exp) {
     if constexpr (cuda::std::is_integral_v<T>) {
       T res = 1;
-      // Raising an integer to a negative power is undefined
-      if constexpr (cuda::std::is_signed_v<T>) {
-        if (exp < 0) {
-          return 0;
-        }
-      }
       while (exp) {
         if (exp & 1) {
           res *= base;
@@ -1,56 +0,0 @@ (file deleted; this matches the distributed.cu entry removed from the CUDA CMakeLists above)
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"

#include <cassert>

namespace mlx::core::distributed {
void AllReduce::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto set_input_output =
      [s = stream()](const array& in, array& out) -> std::pair<array, array> {
    if (!in.flags().row_contiguous) {
      copy_gpu(in, out, CopyType::General, s);
      return {out, out};
    } else if (in.is_donatable()) {
      out.copy_shared_buffer(in);
      return {in, out};
    } else {
      out.set_data(allocator::malloc(out.nbytes()));
      return {in, out};
    }
  };

  auto [input, output] = set_input_output(inputs[0], outputs[0]);

  auto& encoder = cu::get_command_encoder(stream());
  encoder.set_input_array(input);
  encoder.set_output_array(output);

  auto capture = encoder.capture_context();
  auto& s = stream();

  switch (reduce_type_) {
    case Sum:
      distributed::detail::all_sum(group(), input, output, s);
      break;
    case Max:
      distributed::detail::all_max(group(), input, output, s);
      break;
    case Min:
      distributed::detail::all_min(group(), input, output, s);
      break;
    default:
      throw std::runtime_error(
          "Only all reduce sum, max, and min are supported.");
  }
}
} // namespace mlx::core::distributed
mlx/backend/cuda/gemms/cutlass_gemm.cu (new file, 396 lines)
@@ -0,0 +1,396 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"

#include <cute/tensor.hpp>
#include <cutlass/arch/arch.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>

#include <iostream>

namespace mlx::core::cu {

namespace {

using namespace cute;
using bf16 = cute::bfloat16_t;

template <typename Kernel>
void configure_matmul(Kernel kernel, int smem_size) {
  static bool initialized = false;
  if (!initialized) {
    initialized = true;
    cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  }
}

template <bool transpose, typename Tiler>
constexpr int get_feature_size(Tiler smem) {
  int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
  return (feature_size >= 64) ? 64 : feature_size;
}

constexpr int constexpr_log2(int x) {
  return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
}

template <int feature_size, int itemsize, int copy_bits>
constexpr int get_swizzle_bits() {
  constexpr int swizzle_bits =
      constexpr_log2(feature_size * itemsize / copy_bits);
  return (swizzle_bits > 3) ? 3 : swizzle_bits;
}

template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_smem_layout(Tiler smem) {
  constexpr int feature_size = get_feature_size<transpose>(smem);
  constexpr int swizzle_bits =
      get_swizzle_bits<feature_size, itemsize, copy_bits>();

  using F = Int<feature_size>;
  using BaseLayout = std::conditional_t<
      transpose,
      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;

  auto swizzled =
      make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});

  return tile_to_shape(swizzled, smem);
}

template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_result_smem_layout(Tiler smem) {
  constexpr int feature_size = get_feature_size<transpose>(smem);
  constexpr int swizzle_bits =
      get_swizzle_bits<feature_size, itemsize, copy_bits>();

  using F = Int<feature_size>;
  using BaseLayout = std::conditional_t<
      transpose,
      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;

  auto swizzled = make_composed_layout(
      Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});

  return tile_to_shape(swizzled, smem);
}

template <
    int num_threads,
    int itemsize,
    bool transpose,
    int copy_bits,
    typename Copier,
    typename Tiler>
constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
  constexpr int num_elements = copy_bits / itemsize;
  constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
  constexpr int copies_per_feature = feature_size / num_elements;

  using E = Int<num_elements>;
  using C = Int<copies_per_feature>;
  using R = Int<num_threads / copies_per_feature>;

  using ThreadLayout = std::conditional_t<
      transpose,
      Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
      Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
  using ValueLayout = std::conditional_t<
      transpose,
      Layout<cute::Shape<E, _1>>,
      Layout<cute::Shape<_1, E>>>;

  return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
}

template <int rasterization_factor>
__device__ inline int2 raster_tile(int x, int y) {
  return {
      x / rasterization_factor,
      (x % rasterization_factor) + y * rasterization_factor};
}

template <
    typename T,
    typename SLayoutA,
    typename SLayoutB,
    typename SLayoutC,
    typename CopyA,
    typename CopyB,
    typename CopyC,
    typename MMA,
    int rasterization_factor>
__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
    const T* __restrict__ A,
    const T* __restrict__ B,
    T* __restrict__ C,
    SLayoutA SA,
    SLayoutB SB,
    SLayoutC SC,
    CopyA copy_a,
    CopyB copy_b,
    CopyC copy_c,
    MMA mma,
    int M,
    int N,
    int K) {
  constexpr auto BM = size<0>(SA);
  constexpr auto BN = size<0>(SB);
  constexpr auto BK = size<1>(SA);
  constexpr auto PIPE = size<2>(SA);

  const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
  const int blocks_m = ceil_div(M, BM);
  const int blocks_n = ceil_div(N, BN);

  // Exit early if the tile is OOB
  if (tile.x >= blocks_m || tile.y >= blocks_n) {
    return;
  }

  // Make the full tensors
  Tensor full_A =
      make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
  Tensor full_B =
      make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
  Tensor full_C =
      make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));

  // Partition the tensors into tiles and select the ones for this threadblock
  Tensor local_A =
      local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
  Tensor local_B =
      local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
  Tensor local_C =
      local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));

  // Make shared memory tensors
  extern __shared__ char shared_memory[];
  T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
  T* shared_B_ptr =
      reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
  T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
  Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
  Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
  Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);

  // Get the copies that correspond to this thread
  auto thread_copy_a = copy_a.get_slice(threadIdx.x);
  Tensor local_A_src = thread_copy_a.partition_S(local_A);
  Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
  auto thread_copy_b = copy_b.get_slice(threadIdx.x);
  Tensor local_B_src = thread_copy_a.partition_S(local_B);
  Tensor local_B_dst = thread_copy_a.partition_D(shared_B);
  auto thread_copy_c = copy_c.get_slice(threadIdx.x);
  Tensor local_C_src = thread_copy_c.partition_S(shared_C);
  Tensor local_C_dst = thread_copy_c.partition_D(local_C);

  // Start fetches
  int k_tile_count = size<2>(local_A);
  int k_tile_next = 0;
  CUTE_UNROLL
  for (int k = 0; k < PIPE - 1; k++) {
    copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
    copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
    cp_async_fence();
    k_tile_count--;
    k_tile_next += (k_tile_count > 0);
  }

  // Get the MMA that corresponds to this thread and allocate registers
  auto thread_mma = mma.get_slice(threadIdx.x);
  Tensor mma_shared_A = thread_mma.partition_A(shared_A);
  Tensor mma_shared_B = thread_mma.partition_B(shared_B);
  Tensor mma_shared_C = thread_mma.partition_C(shared_C);
  Tensor mma_global_C = thread_mma.partition_C(local_C);
  Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
  Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
  Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
  clear(mma_frag_C);

  // Make shared to register copies
  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
  auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
  auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
  auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
  auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
  Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
  Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
  Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
  Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);

  constexpr auto RPIPE = size<2>(mma_shared_A);
  int smem_read = 0;
  int smem_write = PIPE - 1;
  Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
  Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);

  // Start the register pipeline
  if constexpr (RPIPE > 1) {
    cp_async_wait<PIPE - 2>();
    __syncthreads();
    copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
    copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
  }

  CUTE_NO_UNROLL
  while (k_tile_count > -(PIPE - 1)) {
    CUTE_UNROLL
    for (int k_block = 0; k_block < RPIPE; k_block++) {
      if (k_block == RPIPE - 1) {
        mma_A_src_p = mma_A_src(_, _, _, smem_read);
        mma_B_src_p = mma_B_src(_, _, _, smem_read);
        cp_async_wait<PIPE - 2>();
        __syncthreads();
      }

      // Load the next register tile
      auto k_block_next = (k_block + 1) % RPIPE;
      copy(
          s2r_copy_a,
          mma_A_src_p(_, _, k_block_next),
          mma_A_dst(_, _, k_block_next));
      copy(
          s2r_copy_b,
          mma_B_src_p(_, _, k_block_next),
          mma_B_dst(_, _, k_block_next));

      if (k_block == 0) {
        copy(
            copy_a,
            local_A_src(_, _, _, k_tile_next),
            local_A_dst(_, _, _, smem_write));
        copy(
            copy_b,
            local_B_src(_, _, _, k_tile_next),
            local_B_dst(_, _, _, smem_write));
        cp_async_fence();
        k_tile_count--;
        k_tile_next += (k_tile_count > 0);
        smem_write = smem_read;
        smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
      }

      gemm(
          mma,
          mma_frag_A(_, _, k_block),
          mma_frag_B(_, _, k_block),
          mma_frag_C);
    }
  }

  copy(mma_frag_C, mma_shared_C);
  __syncthreads();
  copy(copy_c, local_C_src, local_C_dst);

  // if (threadIdx.x == 0) {
  //   print("fC: "); print(mma_frag_C); print("\n");
  //   print("sC: "); print(mma_shared_C); print("\n");
  //   print("dC: "); print(local_C_dst); print("\n");
  //
  //   print(s2r_atom_a); print("\n");
  // }
}

} // namespace

void cutlass_gemm(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc) {
  enc.set_input_array(a);
  enc.set_input_array(b);
  enc.set_output_array(out);
  dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
      using namespace cute;

      // Tile definitions
      auto BM = Int<128>{};
      auto BN = Int<128>{};
      auto BK = Int<64>{};
      auto BP = Int<3>{};
      auto GM = Int<8>{};

      // Thread definitions
      using TM = Int<2>;
      using TN = Int<2>;
      using TK = Int<1>;
      constexpr int num_threads = TM::value * TN::value * 32;

      auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
      auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
      auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));

      constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);

      auto async_copy_op =
          Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
      auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
          async_copy_op, make_shape(BM, BK));
      auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
          async_copy_op, make_shape(BN, BK));

      auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
      auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
          sync_copy_op, make_shape(BM, BN));

      auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
      auto tiled_mma = make_tiled_mma(
          mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});

      auto kernel = matmul_kernel<
          bf16,
          decltype(SA),
          decltype(SB),
          decltype(SC),
          decltype(tiled_copy_a),
          decltype(tiled_copy_b),
          decltype(tiled_copy_c),
          decltype(tiled_mma),
          GM.value>;
      configure_matmul(kernel, smem_size);

      dim3 block(size(tiled_mma));
      dim3 grid(
          size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));

      enc.add_kernel_node(
          kernel,
          grid,
          block,
          smem_size,
          a.data<bf16>(),
          b.data<bf16>(),
          out.data<bf16>(),
          SA,
          SB,
          SC,
          tiled_copy_a,
          tiled_copy_b,
          tiled_copy_c,
          tiled_mma,
          M,
          N,
          K);
    } else {
      throw std::runtime_error("Only bfloat16 supported");
    }
  });
}

} // namespace mlx::core::cu
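The grid in cutlass_gemm launches ceil_div(M, BM) * GM blocks along x and ceil_div(ceil_div(N, BN), GM) along y, and raster_tile folds that back into tile coordinates so consecutive blocks walk GM-wide column groups for better L2 reuse. A host-side re-implementation for printing the mapping (a standalone sketch, using a rasterization factor of 2 instead of the kernel's 8):

#include <cstdio>

int main() {
  constexpr int GM = 2;          // rasterization factor
  int blocks_m = 4, blocks_n = 4;
  for (int y = 0; y < (blocks_n + GM - 1) / GM; y++) {
    for (int x = 0; x < blocks_m * GM; x++) {
      int tx = x / GM;                  // tile row
      int ty = (x % GM) + y * GM;       // tile column
      if (tx < blocks_m && ty < blocks_n) {
        std::printf("block (%d,%d) -> tile (%d,%d)\n", x, y, tx, ty);
      }
    }
  }
}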
mlx/backend/cuda/gemms/cutlass_gemm.h (new file, 18 lines)
@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"

namespace mlx::core::cu {

void cutlass_gemm(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc);

}
mlx/backend/cuda/gemms/simple_gemm.cu (new file, 69 lines)
@@ -0,0 +1,69 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/dtype_utils.h"

#include <iostream>

namespace mlx::core::cu {

namespace {

template <typename Kernel>
static void configure_smem(Kernel kernel, int SM) {
  static bool done = false;
  if (done) {
    return;
  }
  std::cout << "configuring" << std::endl;
  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM);
  cudaFuncSetAttribute(
      kernel,
      cudaFuncAttributePreferredSharedMemoryCarveout,
      cudaSharedmemCarveoutMaxShared);
  done = true;
}

} // namespace

void simple_gemm(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc) {
  enc.set_input_array(a);
  enc.set_input_array(b);
  enc.set_output_array(out);
  dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    constexpr int BM = 128;
    constexpr int BN = 128;
    constexpr int BK = 32;
    constexpr int PIPE = 3;
    constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK);
    constexpr int WM = 2;
    constexpr int WN = 4;

    auto kernel = ab_t_aligned<DataType, BM, BN, BK, WM, WN, PIPE>;
    configure_smem(kernel, SM);

    dim3 grid(N / BN, M / BM);
    enc.add_kernel_node(
        kernel,
        grid,
        WM * WN * WARP_SIZE,
        SM,
        a.data<DataType>(),
        b.data<DataType>(),
        out.data<DataType>(),
        N,
        K);
  });
}

} // namespace mlx::core::cu
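Working through the dynamic shared-memory request simple_gemm makes for its default tile sizes (a standalone check, not MLX code):

#include <cstdio>

int main() {
  // Constants from simple_gemm above, with DataType = bfloat16 (2 bytes).
  constexpr int BM = 128, BN = 128, BK = 32, PIPE = 3;
  constexpr int SM = PIPE * 2 * (BM * BK + BN * BK);
  // Prints 49152 bytes (48 KB): at the edge of the default per-block limit,
  // which is presumably why configure_smem raises
  // cudaFuncAttributeMaxDynamicSharedMemorySize before launching.
  std::printf("%d bytes (%d KB)\n", SM, SM / 1024);
}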
mlx/backend/cuda/gemms/simple_gemm.h (new file, 18 lines)
@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"

namespace mlx::core::cu {

void simple_gemm(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc);

}
@@ -3,7 +3,9 @@
 #include "mlx/backend/common/matmul.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
+#include "mlx/backend/cuda/gemms/cutlass_gemm.h"
 #include "mlx/backend/cuda/gemms/gemv.h"
+#include "mlx/backend/cuda/gemms/simple_gemm.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/primitives.h"
@@ -11,8 +13,14 @@
 #include <numeric>
 
 namespace mlx::core {
 
 namespace {
 
+int get_test_gemm() {
+  static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0);
+  return t;
+}
+
 std::tuple<bool, int64_t, array>
 check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
   auto stx = arr.strides()[arr.ndim() - 2];
||||||
@@ -95,6 +103,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
|
||||||
|
b_transposed && batch_count == 1 && get_test_gemm() == 1) {
|
||||||
|
cu::simple_gemm(a, b, out, M, N, K, encoder);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
|
||||||
|
b_transposed && batch_count == 1 && get_test_gemm() == 2) {
|
||||||
|
cu::cutlass_gemm(a, b, out, M, N, K, encoder);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Invoke cublasLt
|
// Invoke cublasLt
|
||||||
CublasGemm gemm(
|
CublasGemm gemm(
|
||||||
|
|||||||
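A condensed restatement of the gate added above (a sketch; use_test_gemm is a hypothetical helper, and the real checks are inlined in Matmul::eval_gpu using MLX's env::get_var):

#include <cstdlib>

static int get_test_gemm() {
  const char* v = std::getenv("MLX_ENABLE_TEST_GEMM");
  return v ? std::atoi(v) : 0;
}

bool use_test_gemm(int M, int N, int K, bool a_t, bool b_t,
                   int batch_count, int which) {
  // which == 1 selects cu::simple_gemm, which == 2 selects cu::cutlass_gemm;
  // anything else falls through to the cublasLt path.
  bool aligned = (M % 512 == 0) && (N % 512 == 0) && (K % 512 == 0);
  return aligned && !a_t && b_t && batch_count == 1 &&
      get_test_gemm() == which;
}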
@@ -42,6 +42,7 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace distributed {
+NO_GPU_MULTI(AllReduce)
 NO_GPU_MULTI(AllGather)
 NO_GPU_MULTI(Send)
 NO_GPU_MULTI(Recv)
@@ -4,95 +4,189 @@
 
 namespace mlx::core::cu {
 
+template <typename T, int BM, int BN, int BK, int WM, int WN>
+__device__ inline void gemm_ab_t(
+    RegisterTile<float, BM / WM, BN / WN>& C,
+    SharedTile<T, BM, BK>& As,
+    SharedTile<T, BN, BK>& Bs,
+    RegisterTileLoader<SharedTile<T, BM, BK>>& rloader_a,
+    RegisterTileLoader<SharedTile<T, BN, BK>>& rloader_b) {
+  RegisterTile<T, BM / WM, 16> A[2];
+  RegisterTile<T, BN / WN, 16> B[2];
+
+  rloader_a.load(A[0], As.base_addr(), 0);
+  rloader_b.load(B[0], Bs.base_addr(), 0);
+
+  MLX_UNROLL
+  for (int k = 1; k < BK / 16; k++) {
+    rloader_a.load(A[k & 1], As.base_addr(), k);
+    rloader_b.load(B[k & 1], Bs.base_addr(), k);
+
+    mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
+  }
+  mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
+}
+
 /**
  * An example gemm written with the utils.
  *
  * Computes A @ B.T when A and B are all aligned with the block sizes.
  */
-template <typename T, int BM, int BN, int BK>
-__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
-  constexpr int WARPS_M = 2;
-  constexpr int WARPS_N = 2;
-  constexpr int NUM_WARPS = WARPS_M * WARPS_N;
-  constexpr int WARP_STEP_M = BM / WARPS_M;
-  constexpr int WARP_STEP_N = BN / WARPS_N;
-
-  // Precompute some offsets for each thread
-  const int warpid = threadIdx.x / 32;
-  const int laneid = threadIdx.x % 32;
-  const int wm = warpid / WARPS_N;
-  const int wn = warpid % WARPS_N;
-  const int offset_m = wm * WARP_STEP_M;
-  const int offset_n = wn * WARP_STEP_N;
-
-  // Allocate shared memory
-  extern __shared__ char shmem[];
-  SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
-  SharedTile<T, BN, BK>(&bs)[2] =
-      *(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
-
-  // Allocate registers for the MMA
-  RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
-  RegisterTile<T, BM / WARPS_M, 16> A;
-  RegisterTile<T, BN / WARPS_N, 16> B;
-
-  // Move the global pointers to the tile
-  a += blockIdx.y * BM * K;
-  b += blockIdx.x * BN * K;
-  y += blockIdx.y * BM * N + blockIdx.x * BN;
-
-  // Zero the accumulators
-  C.fill(0);
-
-  // Start the SM pipeline
-  load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
-  load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
-  cp_async_commit();
-
-  int tic = 0;
-  for (int k_block = BK; k_block < K; k_block += BK) {
-    load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
-    load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
-    cp_async_commit();
-    cp_async_wait<1>();
-    __syncthreads();
-
-    MLX_UNROLL
-    for (int k = 0; k < BK / 16; k++) {
-      A.load(
-          as[tic],
-          as[tic].base_addr(),
-          offset_m + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-      B.load(
-          bs[tic],
-          bs[tic].base_addr(),
-          offset_n + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-
-      mma_t(C, A, B);
-    }
-    tic ^= 1;
-  }
-
-  // Empty the pipeline
-  cp_async_wait_all();
-  __syncthreads();
-  MLX_UNROLL
-  for (int k = 0; k < BK / 16; k++) {
-    A.load(
-        as[tic],
-        as[tic].base_addr(),
-        offset_m + laneid % 16,
-        k * 16 + laneid / 16 * 8);
-    B.load(
-        bs[tic],
-        bs[tic].base_addr(),
-        offset_n + laneid % 16,
-        k * 16 + laneid / 16 * 8);
-
-    mma_t(C, A, B);
-  }
-
-  C.store_global(y, N, offset_m, offset_n);
+// template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
+//__global__ __launch_bounds__(WM * WN * WARP_SIZE, 1)
+// void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
+//   constexpr int NUM_WARPS = WM * WN;
+//   constexpr int WARP_STEP_M = BM / WM;
+//   constexpr int WARP_STEP_N = BN / WN;
+//
+//   // Precompute some offsets for each thread
+//   const int warpid = threadIdx.x / 32;
+//   const int laneid = threadIdx.x % 32;
+//   const int wm = warpid / WN;
+//   const int wn = warpid % WN;
+//   const int offset_m = wm * WARP_STEP_M;
+//   const int offset_n = wn * WARP_STEP_N;
+//
+//   // Allocate shared memory
+//   extern __shared__ char shmem[];
+//   SharedTile<T, BM, BK>(&as)[PIPE] =
+//       *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
+//   SharedTile<T, BN, BK>(&bs)[PIPE] =
+//       *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
+//
+//   // Move the global pointers to the tile
+//   a += blockIdx.y * BM * K;
+//   b += blockIdx.x * BN * K;
+//   y += blockIdx.y * BM * N + blockIdx.x * BN;
+//
+//   // Make the loaders to/from SMEM
+//   SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>> sloader_a(a, K);
+//   SharedTileLoader<NUM_WARPS, SharedTile<T, BN, BK>> sloader_b(b, K);
+//   RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
+//   RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
+//
+//   // Start the SM pipeline
+//   MLX_UNROLL
+//   for (int i = 0; i < PIPE - 1; i++) {
+//     sloader_a.load_async(as[i].base_addr());
+//     sloader_b.load_async(bs[i].base_addr());
+//     cp_async_commit();
+//     sloader_a.next();
+//     sloader_b.next();
+//   }
+//
+//   // Allocate and zero the MMA accumulator
+//   RegisterTile<float, BM / WM, BN / WN> C;
+//   C.fill(0);
+//
+//   // Matmul loop
+//   int num_blocks = K / BK;
+//   int sread = 0;
+//   int swrite = PIPE - 1;
+//   for (int i = 0; i < num_blocks; i++) {
+//     cp_async_wait<PIPE - 1>();
+//
+//     gemm_ab_t<T, BM, BN, BK, WM, WN>(
+//         C, as[sread], bs[sread], rloader_a, rloader_b);
+//
+//     sloader_a.load_async(as[swrite].base_addr());
+//     sloader_b.load_async(bs[swrite].base_addr());
+//     cp_async_commit();
+//     sloader_a.next(i + PIPE < num_blocks);
+//     sloader_b.next(i + PIPE < num_blocks);
+//
+//     swrite = sread;
+//     sread = (sread + 1) % PIPE;
+//   }
+//
+//   C.store_global(y, N, offset_m, offset_n);
+// }
+
+template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
+__global__ __launch_bounds__(
+    WM* WN* WARP_SIZE,
+    1) void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
+  constexpr int NUM_WARPS = WM * WN;
+  constexpr int WARP_STEP_M = BM / WM;
+  constexpr int WARP_STEP_N = BN / WN;
+
+  // Precompute some offsets for each thread
+  const int warpid = threadIdx.x / 32;
+  const int laneid = threadIdx.x % 32;
+  const int wm = warpid / WN;
+  const int wn = warpid % WN;
+  const int offset_m = wm * WARP_STEP_M;
+  const int offset_n = wn * WARP_STEP_N;
+
+  // Allocate shared memory
+  extern __shared__ char shmem[];
+  SharedTile<T, BM, BK>(&as)[PIPE] =
+      *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
+  SharedTile<T, BN, BK>(&bs)[PIPE] =
+      *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
+
+  // Move the global pointers to the tile
+  a += blockIdx.y * BM * K;
+  b += blockIdx.x * BN * K;
+  y += blockIdx.y * BM * N + blockIdx.x * BN;
+
+  // Make the loaders to/from SMEM
+  using sloader = SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>>;
+  constexpr int SSTEP = sloader::STEP_ROWS * sizeof(T) * BK;
+  const int srow = threadIdx.x / sloader::NUM_LOADS_PER_ROW;
+  const int scol =
+      (threadIdx.x % sloader::NUM_LOADS_PER_ROW) * sloader::ELEMENTS_PER_LOAD;
+  a += srow * K + scol;
+  b += srow * K + scol;
+  uint32_t sm_offsets[PIPE][2];
+  MLX_UNROLL
+  for (int s = 0; s < PIPE; s++) {
+    sm_offsets[s][0] = as[s].loc(as[s].base_addr(), srow, scol);
+    sm_offsets[s][1] = bs[s].loc(bs[s].base_addr(), srow, scol);
+  }
+  RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
+  RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
+
+  // Start the SM pipeline
+  MLX_UNROLL
+  for (int s = 0; s < PIPE - 1; s++) {
+    MLX_UNROLL
+    for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
+      cp_async<16>(sm_offsets[s][0] + l * SSTEP, a);
+      cp_async<16>(sm_offsets[s][1] + l * SSTEP, b);
+      a += sloader::STEP_ROWS * K;
+      b += sloader::STEP_ROWS * K;
+    }
+    cp_async_commit();
+  }
+
+  // Allocate and zero the MMA accumulator
+  RegisterTile<float, BM / WM, BN / WN> C;
+  C.fill(0);
+
+  // Matmul loop
+  int num_blocks = K / BK;
+  int sread = 0;
+  int swrite = PIPE - 1;
+  for (int i = 0; i < num_blocks; i++) {
+    cp_async_wait<PIPE - 1>();
+
+    gemm_ab_t<T, BM, BN, BK, WM, WN>(
+        C, as[sread], bs[sread], rloader_a, rloader_b);
+
+    if (false) {
+      MLX_UNROLL
+      for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
+        cp_async<16>(sm_offsets[swrite][0] + l * SSTEP, a);
+        cp_async<16>(sm_offsets[swrite][1] + l * SSTEP, b);
+        a += sloader::STEP_ROWS * K;
+        b += sloader::STEP_ROWS * K;
+      }
+    }
+    cp_async_commit();
+
+    swrite = sread;
+    sread = (sread + 1) % PIPE;
+  }
+
+  C.store_global(y, N, offset_m, offset_n);
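The matmul loop above rotates through PIPE shared-memory stages with a read index that advances modulo PIPE and a write index that trails it by PIPE - 1 stages. A standalone illustration of the rotation (not MLX code):

#include <cstdio>

int main() {
  constexpr int PIPE = 3;
  int sread = 0, swrite = PIPE - 1;
  for (int i = 0; i < 6; i++) {
    std::printf("iter %d: read slot %d, write slot %d\n", i, sread, swrite);
    swrite = sread;                 // reuse the slot we just consumed
    sread = (sread + 1) % PIPE;     // advance to the next prefetched slot
  }
}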
@@ -223,59 +223,10 @@ struct RegisterTile {
   }
 };
 
-/**
- * A simple container of multiple Tile16x16.
- *
- * Provides utility functions for loading and manipulating collections of basic
- * tiles.
- */
-template <typename T, int ROWS_, int COLS_>
-struct RegisterTile {
-  static constexpr int ROWS = ROWS_;
-  static constexpr int COLS = COLS_;
-  static constexpr int TILES_X = COLS / 16;
-  static constexpr int TILES_Y = ROWS / 16;
-
-  Tile16x16<T> data[TILES_X * TILES_Y];
-
-  __device__ inline void fill(T v) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].fill(v);
-      }
-    }
-  }
-
-  template <typename Tile>
-  __device__ inline void
-  load(Tile& tile, uint32_t base_address, int row, int col) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].load(
-            tile.loc(base_address, row + i * 16, col + j * 16));
-      }
-    }
-  }
-
-  template <typename U>
-  __device__ inline void store_global(U* x, int N, int row, int col) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].store_global(
-            x + (row + i * 16) * N + col + j * 16, N);
-      }
-    }
-  }
-};
-
 template <typename T, int ROWS_, int COLS_>
 struct SharedTile {
+  using value_type = T;
+
   static constexpr int ROWS = ROWS_;
   static constexpr int COLS = COLS_;
   static constexpr int TILES_X = COLS / 16;
@@ -317,23 +268,26 @@ struct SharedTile {
     }
   }
 
-  // Return the location of the element at (row, col) using the swizzle.
-  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
+  __device__ static inline uint32_t offset(int row, int col) {
     if constexpr (swizzle_bytes > 0) {
       static constexpr int swizzle_repeat = swizzle_bytes * 8;
       static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
       const int outer_idx = col / subtile_cols;
-      const uint32_t addr = ptr +
-          sizeof(T) *
-              (outer_idx * ROWS * subtile_cols + row * subtile_cols +
-               col % subtile_cols);
+      const uint32_t addr = sizeof(T) *
+          (outer_idx * ROWS * subtile_cols + row * subtile_cols +
+           col % subtile_cols);
       const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
       return (addr ^ swizzle);
     } else {
-      return ptr + sizeof(T) * (row * COLS + col);
+      return sizeof(T) * (row * COLS + col);
     }
   }
 
+  // Return the location of the element at (row, col) using the swizzle.
+  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
+    return ptr + offset(row, col);
+  }
+
   // Convenience functions to edit elements going through the swizzle.
   __device__ inline T& operator()(int row, int col) {
     return *ptr(data, row, col);
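A host-side re-implementation of the swizzled offset computed above, for inspection (assumptions: 2-byte elements, ROWS = 128, and swizzle_bytes = 64; the real values are compile-time members of SharedTile):

#include <cstdint>
#include <cstdio>

uint32_t offset(int row, int col) {
  constexpr int ROWS = 128;
  constexpr uint32_t sizeof_T = 2, swizzle_bytes = 64;
  constexpr int swizzle_repeat = swizzle_bytes * 8;
  constexpr int subtile_cols = swizzle_bytes / sizeof_T;
  const int outer_idx = col / subtile_cols;
  const uint32_t addr = sizeof_T *
      (outer_idx * ROWS * subtile_cols + row * subtile_cols +
       col % subtile_cols);
  // XOR a row-dependent pattern into bits 4-6 to spread accesses
  // across shared-memory banks.
  const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
  return addr ^ swizzle;
}

int main() {
  std::printf("offset(0,0)=%u offset(2,0)=%u\n", offset(0, 0), offset(2, 0));
}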
@@ -364,6 +318,76 @@ struct SharedTile {
   }
 };
 
+template <int NUM_WARPS, typename Tile>
+struct SharedTileLoader {
+  using T = typename Tile::value_type;
+
+  static constexpr int NUM_THREADS = NUM_WARPS * 32;
+  static constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
+  static constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
+  static constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+  static constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
+  static constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+
+  const T* x_;
+  int N_;
+  uint32_t offset_;
+
+  __device__ SharedTileLoader(const T* x, int N) : x_(x), N_(N) {
+    const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+    const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+    x_ += row * N + col * ELEMENTS_PER_LOAD;
+    offset_ = Tile::offset(row, col * ELEMENTS_PER_LOAD);
+  }
+
+  __device__ inline void load_async(uint32_t base_address) {
+    MLX_UNROLL
+    for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+      cp_async<16>(
+          base_address + offset_ + i * STEP_ROWS * sizeof(T) * Tile::COLS,
+          x_ + i * STEP_ROWS * N_);
+    }
+  }
+
+  __device__ inline void next() {
+    x_ += Tile::COLS;
+  }
+};
+
+template <typename Tile>
+struct RegisterTileLoader {
+  using T = typename Tile::value_type;
+
+  uint32_t offset_[Tile::COLS / 16];
+
+  __device__ RegisterTileLoader(int offset_row, int laneid) {
+    const int row = offset_row + laneid & 15;
+    const int col = (laneid >> 4) << 3;
+
+    MLX_UNROLL
+    for (int i = 0; i < Tile::COLS / 16; i++) {
+      offset_[i] = Tile::offset(row, col + i * 16);
+    }
+  }
+
+  template <typename T, int ROWS, int COLS>
+  __device__ inline void
+  load(RegisterTile<T, ROWS, COLS>& x, uint32_t base_address, int col) {
+    constexpr int TILES_Y = RegisterTile<T, ROWS, COLS>::TILES_Y;
+    constexpr int TILES_X = RegisterTile<T, ROWS, COLS>::TILES_X;
+
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        x.data[i * TILES_X + j].load(
+            base_address + offset_[j + col] + i * 16 * Tile::COLS * sizeof(T));
+      }
+    }
+  }
+};
+
 /**
  * Load the tile from global memory by loading 16 bytes at a time and storing
  * them immediately.
@@ -21,15 +21,15 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
 #if defined(MLX_CUDA_SM_80_ENABLED)
   if constexpr (N == 16) {
     asm volatile(
-        "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
+        "cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
         "l"(reinterpret_cast<const int4*>(x)));
   } else if constexpr (N == 8) {
     asm volatile(
-        "cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
+        "cp.async.cg.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
         "l"(reinterpret_cast<const int2*>(x)));
   } else if constexpr (N == 4) {
     asm volatile(
-        "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
+        "cp.async.cg.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
         "l"(reinterpret_cast<const int*>(x)));
   }
 #endif
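For context on the .ca to .cg change, a summary per the PTX ISA documentation (an outside reference, not MLX documentation):

// cp.async.ca.shared::cta.global - cache the copy at all levels, incl. L1
// cp.async.cg.shared::cta.global - cache at global level (L2), bypassing L1
// GEMM operand tiles staged into shared memory are read once from global
// memory, so bypassing L1 avoids polluting it with streaming data.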
@@ -223,11 +223,6 @@ struct Power {
   template <typename T>
   metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
     T res = 1;
-    // Undefined to raise integer to negative power
-    if (exp < 0) {
-      return 0;
-    }
-
     while (exp) {
       if (exp & 1) {
         res *= base;
@@ -6,4 +6,3 @@ target_sources(
 
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
@@ -2,21 +2,15 @@
 
 #include <unordered_map>
 
-#include "mlx/backend/cuda/cuda.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/mpi/mpi.h"
-#include "mlx/distributed/nccl/nccl.h"
 #include "mlx/distributed/ring/ring.h"
 
 namespace mlx::core::distributed {
 
 namespace detail {
 
-Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
-  return group.raw_group()->communication_stream(s);
-}
-
 void all_sum(Group group, const array& input, array& output, Stream stream) {
   group.raw_group()->all_sum(input, output, stream);
 }
@@ -43,10 +37,6 @@ void recv(Group group, array& out, int src, Stream stream) {
 
 class EmptyGroup : public GroupImpl {
  public:
-  Stream communication_stream(StreamOrDevice s) override {
-    return to_stream(s);
-  }
-
   int rank() override {
     return 0;
   }
@@ -90,7 +80,7 @@ class EmptyGroup : public GroupImpl {
 } // namespace detail
 
 bool is_available() {
-  return mpi::is_available() || ring::is_available() || nccl::is_available();
+  return mpi::is_available() || ring::is_available();
 }
 
 int Group::rank() const {
@@ -115,23 +105,15 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create the requested communication group
|
// Create the requested communication group
|
||||||
std::shared_ptr<detail::GroupImpl> group{nullptr};
|
std::shared_ptr<detail::GroupImpl> group;
|
||||||
std::string bk_ = bk;
|
std::string bk_ = bk;
|
||||||
if (bk == "mpi") {
|
if (bk == "mpi") {
|
||||||
group = mpi::init(strict);
|
group = mpi::init(strict);
|
||||||
} else if (bk == "ring") {
|
} else if (bk == "ring") {
|
||||||
group = ring::init(strict);
|
group = ring::init(strict);
|
||||||
} else if (bk == "nccl") {
|
|
||||||
group = nccl::init(strict);
|
|
||||||
} else if (bk == "any") {
|
} else if (bk == "any") {
|
||||||
if (mlx::core::cu::is_available()) {
|
group = ring::init(false);
|
||||||
group = nccl::init(false);
|
bk_ = "ring";
|
||||||
bk_ = "nccl";
|
|
||||||
}
|
|
||||||
if (group == nullptr) {
|
|
||||||
group = ring::init(false);
|
|
||||||
bk_ = "ring";
|
|
||||||
}
|
|
||||||
if (group == nullptr) {
|
if (group == nullptr) {
|
||||||
group = mpi::init(false);
|
group = mpi::init(false);
|
||||||
bk_ = "mpi";
|
bk_ = "mpi";
|
||||||
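
Note: after this hunk, the "any" path of init tries the ring backend first and falls back to MPI; the NCCL probe on CUDA machines is gone. From Python the choice surfaces through mx.distributed.init, as the docstring hunk further down reflects. A short usage sketch under that API:

import mlx.core as mx

# Pick a backend explicitly; strict=True raises on failure instead of
# returning a singleton group.
group = mx.distributed.init(strict=False, backend="ring")

# backend="any" (the default) now resolves to ring, then mpi.
world = mx.distributed.init()
print(world.rank(), world.size())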

@@ -5,7 +5,6 @@
 #include <memory>
 
 #include "mlx/array.h"
-#include "mlx/utils.h"
 
 namespace mlx::core::distributed {
 
@@ -13,15 +13,10 @@ class GroupImpl {
  public:
   virtual ~GroupImpl() {}
 
-  // Choose the stream this communication group can operate on
-  virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
-
-  // Group operations
   virtual int rank() = 0;
   virtual int size() = 0;
   virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
 
-  // Actual communication operations
   virtual void all_sum(const array& input, array& output, Stream stream) = 0;
   virtual void all_gather(const array& input, array& output, Stream stream) = 0;
   virtual void send(const array& input, int dst, Stream stream) = 0;
@@ -30,9 +25,6 @@ class GroupImpl {
   virtual void all_min(const array& input, array& output, Stream stream) = 0;
 };
 
-/* Define the MLX stream that the communication should happen in. */
-Stream communication_stream(Group group, StreamOrDevice s = {});
-
 /* Perform an all reduce sum operation */
 void all_sum(Group group, const array& input, array& output, Stream stream);
 

@@ -349,10 +349,6 @@ class MPIGroup : public GroupImpl {
     }
   }
 
-  Stream communication_stream(StreamOrDevice s) override {
-    return to_stream(s, Device::cpu);
-  }
-
   int rank() override {
     if (rank_ < 0) {
       mpi().rank(comm_, &rank_);

@@ -1,8 +0,0 @@
-if(MLX_BUILD_CUDA)
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
-  find_package(NCCL REQUIRED)
-  target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
-  target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
-else()
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
-endif()
@@ -1,372 +0,0 @@
-#include <arpa/inet.h>
-#include <cuda_runtime.h>
-#include <nccl.h>
-#include <netdb.h>
-#include <sys/socket.h>
-#include <unistd.h>
-#include <cstdio>
-#include <cstdlib>
-#include <cstring>
-#include <iostream>
-#include <mutex>
-#include <stdexcept>
-#include <string>
-#include <type_traits>
-
-#include "mlx/backend/cuda/device.h"
-#include "mlx/distributed/distributed.h"
-#include "mlx/distributed/distributed_impl.h"
-#include "mlx/dtype_utils.h"
-#include "mlx/utils.h"
-
-namespace mlx::core::distributed::nccl {
-
-#define CHECK_CUDA(cmd) \
-  do { \
-    cudaError_t e = cmd; \
-    if (e != cudaSuccess) { \
-      fprintf( \
-          stderr, \
-          "CUDA error %s:%d '%s'\n", \
-          __FILE__, \
-          __LINE__, \
-          cudaGetErrorString(e)); \
-      exit(1); \
-    } \
-  } while (0)
-
-#define CHECK_NCCL(cmd) \
-  do { \
-    ncclResult_t r = cmd; \
-    if (r != ncclSuccess) { \
-      fprintf( \
-          stderr, \
-          "NCCL error %s:%d '%s'\n", \
-          __FILE__, \
-          __LINE__, \
-          ncclGetErrorString(r)); \
-      exit(1); \
-    } \
-  } while (0)
-
-#define MLX_NCCL_TYPE_LIST(X) \
-  X(int8_t, ncclChar) \
-  X(uint8_t, ncclUint8) \
-  X(int32_t, ncclInt) \
-  X(uint32_t, ncclUint32) \
-  X(int64_t, ncclInt64) \
-  X(uint64_t, ncclUint64) \
-  X(float16_t, ncclHalf) \
-  X(bfloat16_t, ncclBfloat16) \
-  X(float, ncclFloat) \
-  X(double, ncclDouble)
-
-template <class>
-struct nccl_map {
-  static constexpr bool ok = false; // default: unsupported
-};
-
-#define MLX_DEF_NCCL_MAP(T, E) \
-  template <> \
-  struct nccl_map<T> { \
-    static constexpr bool ok = true; \
-    static constexpr ncclDataType_t value = E; \
-  };
-
-MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)
-#undef MLX_DEF_NCCL_MAP
-
-namespace detail {
-
-template <typename F>
-void dispatch_dtype(const array& arr, F&& f) {
-  dispatch_all_types(arr.dtype(), [&](auto type_tag) {
-    using T = MLX_GET_TYPE(type_tag);
-    if constexpr (nccl_map<T>::ok) {
-      f(type_tag, nccl_map<T>::value);
-    } else {
-      throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
-    }
-  });
-}
-
-inline void sendAll(int sock, const void* buf, size_t len) {
-  const char* ptr = reinterpret_cast<const char*>(buf);
-  while (len > 0) {
-    ssize_t sent = send(sock, ptr, len, 0);
-    if (sent <= 0) {
-      perror("send");
-      exit(1);
-    }
-    ptr += sent;
-    len -= sent;
-  }
-}
-
-inline void recvAll(int sock, void* buf, size_t len) {
-  char* ptr = reinterpret_cast<char*>(buf);
-  while (len > 0) {
-    ssize_t rec = recv(sock, ptr, len, 0);
-    if (rec <= 0) {
-      perror("recv");
-      exit(1);
-    }
-    ptr += rec;
-    len -= rec;
-  }
-}
-
-inline void bootstrap_unique_id(
-    ncclUniqueId& id,
-    int rank,
-    int size,
-    const std::string& initMethod) {
-  // Parse the init method to extract the host and port
-  if (initMethod.rfind("tcp://", 0) != 0)
-    throw;
-  auto hostport = initMethod.substr(6);
-  auto colon = hostport.find(':');
-  std::string host = hostport.substr(0, colon);
-  int port = std::stoi(hostport.substr(colon + 1));
-
-  if (rank == 0) {
-    // create a unique id on the rank 0
-    CHECK_NCCL(ncclGetUniqueId(&id));
-
-    // create a socket to send the unique id to all other ranks
-    int sock = socket(AF_INET, SOCK_STREAM, 0);
-
-    if (sock < 0) {
-      std::ostringstream msg;
-      msg << "[nccl] Couldn't create socket (error: " << errno << ")";
-      throw std::runtime_error(msg.str());
-    }
-
-    sockaddr_in serv = {};
-    serv.sin_family = AF_INET;
-    serv.sin_addr.s_addr = htonl(INADDR_ANY);
-    serv.sin_port = htons(port);
-
-    int reuse = 1;
-    // Without this, if rank-0 crashes or restarts process quickly,
-    // the OS might refuse to let binding to the same port, so reuse
-
-    if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
-      std::ostringstream msg;
-      msg << "[nccl] setsockopt() failed: " << strerror(errno);
-      throw std::runtime_error(msg.str());
-    }
-
-    if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
-      std::ostringstream msg;
-      msg << "[nccl] bind() failed: " << strerror(errno);
-      throw std::runtime_error(msg.str());
-    }
-    if (listen(sock, size - 1) < 0) {
-      std::ostringstream msg;
-      msg << "[nccl] listen() failed: " << strerror(errno);
-      throw std::runtime_error(msg.str());
-    }
-
-    for (int peer = 1; peer < size; ++peer) {
-      int conn = accept(sock, nullptr, nullptr);
-      if (conn < 0) {
-        std::ostringstream msg;
-        msg << "[nccl] accept() failed: " << strerror(errno);
-        throw std::runtime_error(msg.str());
-      }
-      sendAll(conn, &id, sizeof(id));
-      close(conn);
-    }
-    close(sock);
-
-  } else {
-    // Here just wanted to make show that rank 0 has enough time to bind
-    // so we will retry to connect until max attempts
-
-    int sock = socket(AF_INET, SOCK_STREAM, 0);
-    if (sock < 0) {
-      std::ostringstream msg;
-      msg << "[nccl] socket() failed: " << strerror(errno);
-      throw std::runtime_error(msg.str());
-    }
-
-    hostent* he = gethostbyname(host.c_str());
-    if (!he) {
-      throw std::runtime_error("[nccl] lookup failed for host: " + host);
-    }
-    sockaddr_in serv = {};
-    serv.sin_family = AF_INET;
-    memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
-    serv.sin_port = htons(port);
-
-    const int max_retries = 30;
-    int attempt = 0;
-    bool connected = false;
-
-    bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
-    for (attempt = 0; attempt < max_retries; ++attempt) {
-      if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
-          0) {
-        connected = true;
-        if (do_log) {
-          std::cout << "[Rank " << rank
-                    << "] Connected successfully on attempt " << attempt + 1
-                    << std::endl;
-          break;
-        }
-      }
-      if (errno != ECONNREFUSED) {
-        break;
-      }
-      std::this_thread::sleep_for(std::chrono::milliseconds(500));
-    }
-
-    if (!connected) {
-      std::ostringstream msg;
-      msg << "[Rank " << rank << "] connect() failed after " << attempt
-          << " retries: " << strerror(errno);
-      close(sock);
-      throw std::runtime_error(msg.str());
-    }
-    recvAll(sock, &id, sizeof(id));
-    close(sock);
-  }
-}
-
-} // namespace detail
-
-using GroupImpl = mlx::core::distributed::detail::GroupImpl;
-class NCCLGroup : public GroupImpl {
- public:
-  NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
-      : rank_(worldRank),
-        size_(worldSize),
-        comm_(nullptr),
-        initMethod_(initMethod) {
-    if (initialized_)
-      return;
-    int ndev;
-    CHECK_CUDA(cudaGetDeviceCount(&ndev));
-    CHECK_CUDA(cudaSetDevice(rank_ % ndev));
-    detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
-    CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
-    initialized_ = true;
-  }
-
-  ~NCCLGroup() {
-    ncclCommDestroy(comm_);
-    ncclGroupEnd();
-    initialized_ = false;
-  }
-
-  Stream communication_stream(StreamOrDevice s) override {
-    return to_stream(s, Device::gpu);
-  }
-
-  int rank() override {
-    return rank_;
-  }
-
-  int size() override {
-    return size_;
-  }
-
-  void all_sum(const array& input, array& output, Stream stream) override {
-    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
-      using T = typename decltype(type_tag)::type;
-      all_reduce_impl<T>(input, output, stream, dt, ncclSum);
-    });
-  }
-
-  virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
-    throw std::runtime_error("[nccl] Group split not supported.");
-  }
-
-  void all_gather(const array& input, array& output, Stream stream) override {
-    throw std::runtime_error(
-        "[nccl] All gather not supported in NCCL backend.");
-  }
-
-  void send(const array& input, int dst, Stream stream) override {
-    throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
-  }
-
-  void recv(array& output, int src, Stream stream) override {
-    throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
-  }
-
-  void all_max(const array& input, array& output, Stream stream) override {
-    throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
-  }
-
-  void all_min(const array& input, array& output, Stream stream) override {
-    throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
-  }
-
-  template <typename T>
-  void all_reduce_impl(
-      const array& input,
-      array& output,
-      Stream stream,
-      ncclDataType_t dt,
-      ncclRedOp_t op) {
-    auto& encoder = cu::get_command_encoder(stream);
-
-    CHECK_NCCL(ncclAllReduce(
-        input.data<T>(),
-        output.data<T>(),
-        input.size(),
-        dt,
-        op,
-        comm_,
-        encoder.stream()));
-  }
-
-  int rank_, size_;
-  std::string initMethod_;
-  ncclUniqueId uniqueId_;
-  ncclComm_t comm_;
-  bool initialized_ = false;
-};
-
-bool is_available() {
-  return true;
-}
-
-namespace detail {
-std::string get_env_var_or_throw(const char* env_var_name, bool strict) {
-  const char* value = std::getenv(env_var_name);
-  if (value == nullptr && strict) {
-    std::ostringstream msg;
-    msg << "[nccl] Required environment variable '" << env_var_name
-        << "' is not set. "
-        << "Please set it before initializing the distributed backend.";
-    throw std::runtime_error(msg.str());
-  }
-  if (value == nullptr) {
-    return "";
-  }
-  return std::string(value);
-}
-} // namespace detail
-
-std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
-  std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP", strict);
-  std::string port = detail::get_env_var_or_throw("NCCL_PORT", strict);
-  std::string rank_str = detail::get_env_var_or_throw("MLX_RANK", strict);
-  std::string n_nodes_str =
-      detail::get_env_var_or_throw("MLX_WORLD_SIZE", strict);
-  if (!strict &&
-      (host.empty() || port.empty() || rank_str.empty() ||
-       n_nodes_str.empty())) {
-    return nullptr;
-  }
-
-  int rank = std::stoi(rank_str);
-  int n_nodes = std::stoi(n_nodes_str);
-  std::string init_method = "tcp://" + host + ":" + port;
-
-  return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
-}
-} // namespace mlx::core::distributed::nccl
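
Note: the deleted nccl.cpp above bootstraps the ncclUniqueId over a bare TCP socket: rank 0 binds NCCL_HOST_IP:NCCL_PORT, generates the id, and streams it to each peer, while the other ranks retry connect() until rank 0 is listening and then read the id back. A condensed Python sketch of that handshake; socket.sendall and the recv loop stand in for the sendAll/recvAll helpers, and bootstrap_unique_id here only mirrors the removed code, it is not the shipped implementation:

import socket
import time

ID_BYTES = 128  # NCCL_UNIQUE_ID_BYTES in NCCL's public header

def bootstrap_unique_id(rank, size, host, port, unique_id=None):
    if rank == 0:
        # Rank 0 owns the id and serves it to every peer over TCP.
        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srv.bind(("", port))  # INADDR_ANY, as in the removed C++
        srv.listen(size - 1)
        for _ in range(size - 1):
            conn, _ = srv.accept()
            conn.sendall(unique_id)  # sendall loops like sendAll did
            conn.close()
        srv.close()
        return unique_id
    # Peers retry until rank 0 is listening (the C++ used 30 x 500 ms).
    for _ in range(30):
        try:
            sock = socket.create_connection((host, port))
            break
        except ConnectionRefusedError:
            time.sleep(0.5)
    else:
        raise RuntimeError(f"[rank {rank}] could not reach rank 0")
    buf = b""
    while len(buf) < ID_BYTES:  # recv until the whole id has arrived
        chunk = sock.recv(ID_BYTES - len(buf))
        if not chunk:
            raise RuntimeError("connection closed before the id arrived")
        buf += chunk
    sock.close()
    return buf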

@@ -1,12 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-#include "mlx/distributed/distributed.h"
-
-namespace mlx::core::distributed::nccl {
-
-using GroupImpl = mlx::core::distributed::detail::GroupImpl;
-
-bool is_available();
-std::shared_ptr<GroupImpl> init(bool strict = false);
-
-} // namespace mlx::core::distributed::nccl
@@ -1,20 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-#include "mlx/distributed/nccl/nccl.h"
-
-namespace mlx::core::distributed::nccl {
-
-using GroupImpl = mlx::core::distributed::detail::GroupImpl;
-
-bool is_available() {
-  return false;
-}
-
-std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
-  if (strict) {
-    throw std::runtime_error("Cannot initialize nccl distributed backend.");
-  }
-  return nullptr;
-}
-
-} // namespace mlx::core::distributed::nccl
@@ -2,9 +2,6 @@
 
 #include <sstream>
 
-#include "mlx/backend/cuda/cuda.h"
-#include "mlx/backend/metal/metal.h"
-#include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 
@@ -31,12 +28,11 @@ array all_sum(
   if (group.size() == 1) {
     return x;
   }
-  auto stream = detail::communication_stream(group, s);
-
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(stream, group, AllReduce::Sum),
+      std::make_shared<AllReduce>(
+          to_stream(s, Device::cpu), group, AllReduce::Sum),
       {x});
 }
 
@@ -49,12 +45,11 @@ array all_max(
   if (group.size() == 1) {
     return x;
   }
-  auto stream = detail::communication_stream(group, s);
-
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(stream, group, AllReduce::Max),
+      std::make_shared<AllReduce>(
+          to_stream(s, Device::cpu), group, AllReduce::Max),
       {x});
 }
 
@@ -67,12 +62,11 @@ array all_min(
   if (group.size() == 1) {
     return x;
   }
-  auto stream = detail::communication_stream(group, s);
-
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(stream, group, AllReduce::Min),
+      std::make_shared<AllReduce>(
+          to_stream(s, Device::cpu), group, AllReduce::Min),
       {x});
 }
 
@@ -85,7 +79,6 @@ array all_gather(
   if (group.size() == 1) {
     return x;
   }
-  auto stream = detail::communication_stream(group, s);
 
   auto result_shape = x.shape();
   if (result_shape.size() == 0) {
@@ -96,7 +89,7 @@ array all_gather(
   return array(
       std::move(result_shape),
       x.dtype(),
-      std::make_shared<AllGather>(stream, group),
+      std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
       {x});
 }
 
@@ -110,7 +103,6 @@ array send(
   if (group.size() == 1) {
     throw std::invalid_argument("Cannot send to a singleton group");
   }
-  auto stream = detail::communication_stream(group, s);
 
   if (dst < 0 || dst >= group.size()) {
     std::ostringstream msg;
@@ -120,7 +112,10 @@ array send(
   }
 
   return array(
-      x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});
+      x.shape(),
+      x.dtype(),
+      std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
+      {x});
 }
 
 array recv(
@@ -134,7 +129,6 @@ array recv(
   if (group.size() == 1) {
     throw std::invalid_argument("Cannot recv from a singleton group");
   }
-  auto stream = detail::communication_stream(group, s);
 
   if (src < 0 || src >= group.size()) {
     std::ostringstream msg;
@@ -145,7 +139,7 @@ array recv(
   return array(
       std::move(shape),
       std::move(dtype),
-      std::make_shared<Recv>(stream, group, src),
+      std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
       std::vector<array>{});
 }
 
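
Note: with communication_stream removed, these ops pin the collective's stream to the CPU via to_stream(s, Device::cpu) rather than asking the group which device it prefers. The stream is still an explicit argument in the Python API; a small usage sketch:

import mlx.core as mx

group = mx.distributed.init()
x = mx.ones((1024,))

# The CPU stream is the supported path for the remaining backends
# (ring and mpi).
y = mx.distributed.all_sum(x, stream=mx.cpu)
mx.eval(y)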

@@ -619,10 +619,6 @@ class RingGroup : public GroupImpl {
     }
   }
 
-  Stream communication_stream(StreamOrDevice s) override {
-    return to_stream(s, Device::cpu);
-  }
-
  int rank() override {
     return rank_;
   }

@@ -20,8 +20,6 @@ from select import select
 from subprocess import PIPE, Popen, run
 from typing import Optional
 
-import mlx.core as mx
-
 
 @dataclass
 class Host:
@@ -55,11 +53,6 @@ def parse_hardware_ports(ports_string):
     return ports
 
 
-def get_num_nvidia_gpus():
-    result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
-    return len(result.stdout.strip().split("\n"))
-
-
 def extract_rings(hosts, index):
     def usable_port(i, j, used_ports):
         return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
@@ -422,57 +415,6 @@ def launch_mpi(parser, hosts, args, command):
         pass
 
 
-def launch_nccl(parser, hosts, args, command):
-    master_host = hosts[0].ips[0]
-
-    if master_host != "127.0.0.1":
-        raise ValueError("The NCCL backend only supports localhost for now.")
-    master_port = args.nccl_port
-    world_size = len(hosts)
-
-    base_env = os.environ.copy()
-    base_env.update(
-        {
-            "NCCL_DEBUG": base_env.get(
-                "NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
-            ),
-            "NCCL_SOCKET_IFNAME": "lo",  # Use loopback for local communication
-            "NCCL_HOST_IP": master_host,
-            "NCCL_PORT": str(master_port),
-            "MLX_WORLD_SIZE": str(world_size),
-        }
-    )
-    procs = []
-    num_gpus = get_num_nvidia_gpus()
-    if num_gpus == 0:
-        raise RuntimeError("Cannot run NCCL backend with no GPUs.")
-    if args.repeat_hosts > num_gpus:
-        raise RuntimeError("NCCL requires a separate GPU per process.")
-
-    try:
-        for rank in range(world_size):
-            env = base_env.copy()
-            mlx_rank = str(rank % args.repeat_hosts)
-            env["MLX_RANK"] = mlx_rank
-            env["CUDA_VISIBLE_DEVICES"] = mlx_rank
-            p = Popen(command, env=env)
-            procs.append(p)
-
-        for p in procs:
-            ret = p.wait()
-            if ret != 0:
-                raise RuntimeError(f"Rank process exited with {ret}")
-
-    except (RuntimeError, KeyboardInterrupt) as err:
-        for p in procs:
-            if p.poll() is None:
-                try:
-                    p.kill()
-                except Exception:
-                    pass
-        raise
-
-
 def check_ssh_connections(hosts):
     results = [False] * len(hosts)
 
@@ -723,8 +665,8 @@ def distributed_config():
     )
     parser.add_argument(
         "--backend",
-        choices=["ring", "mpi", "nccl"],
-        default="nccl" if mx.cuda.is_available() else "ring",
+        choices=["ring", "mpi"],
+        default="ring",
         help="Which distributed backend to configure",
     )
     parser.add_argument(
@@ -795,8 +737,8 @@ def main():
     parser.add_argument("--hostfile", help="The file containing the hosts")
     parser.add_argument(
         "--backend",
-        choices=["ring", "mpi", "nccl"],
-        default="nccl" if mx.cuda.is_available() else "ring",
+        choices=["ring", "mpi"],
+        default="ring",
         help="Which distributed backend to launch",
     )
     parser.add_argument(
@@ -827,14 +769,9 @@ def main():
     parser.add_argument(
         "--cwd", help="Set the working directory on each node to the provided one"
     )
-    parser.add_argument(
-        "--nccl-port",
-        type=int,
-        default=12345,
-        help="The port to use for the NCCL communication (only for nccl backend)",
-    )
-
     args, rest = parser.parse_known_args()
+    if rest[0] == "--":
+        rest.pop(0)
 
     if args.print_python:
         print(sys.executable)
@@ -862,10 +799,8 @@ def main():
     # Launch
     if args.backend == "ring":
         launch_ring(parser, hosts, args, rest)
-    if args.backend == "mpi":
+    elif args.backend == "mpi":
         launch_mpi(parser, hosts, args, rest)
-    if args.backend == "nccl":
-        launch_nccl(parser, hosts, args, rest)
 
 
 if __name__ == "__main__":
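
Note: the removed launch_nccl above was a single-host launcher: one subprocess per GPU, with rank and device selection carried entirely in environment variables (MLX_RANK, MLX_WORLD_SIZE, NCCL_HOST_IP, NCCL_PORT, CUDA_VISIBLE_DEVICES). A stripped-down sketch of that pattern; launch_per_gpu is our name and command is a placeholder, this is not the shipped launcher:

import os
from subprocess import Popen

def launch_per_gpu(command, world_size, host="127.0.0.1", port=12345):
    base_env = os.environ.copy()
    base_env.update(
        {
            "NCCL_HOST_IP": host,  # where rank 0 serves the NCCL unique id
            "NCCL_PORT": str(port),
            "MLX_WORLD_SIZE": str(world_size),
        }
    )
    procs = []
    for rank in range(world_size):
        env = base_env.copy()
        env["MLX_RANK"] = str(rank)
        env["CUDA_VISIBLE_DEVICES"] = str(rank)  # one GPU per rank
        procs.append(Popen(command, env=env))
    for p in procs:
        if p.wait() != 0:
            raise RuntimeError("a rank exited with a non-zero status")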

@@ -76,7 +76,6 @@ def average_gradients(
     group: Optional[mx.distributed.Group] = None,
     all_reduce_size: int = 32 * 1024**2,
     communication_type: Optional[mx.Dtype] = None,
-    stream: mx.Stream = mx.cpu,
 ):
     """Average the gradients across the distributed processes in the passed group.
 
@@ -95,7 +94,6 @@ def average_gradients(
         communication_type (Optional[mlx.core.Dtype]): If provided cast to this
           type before performing the communication. Typically cast to a
           smaller float to reduce the communication size. Default: ``None``.
-        stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
     """
     group = group or mx.distributed.init()
     N = group.size()
@@ -106,7 +104,7 @@ def average_gradients(
     def _average(x):
         dt = x.dtype
         x = x.astype(communication_type) if communication_type is not None else x
-        return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
+        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
 
     if all_reduce_size <= 0:
         return tree_map(_average, gradients)
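
Note: with the stream parameter gone, average_gradients always reduces on the CPU stream; callers no longer choose. Usage is otherwise unchanged. A short sketch of the bucketing behavior the deleted NCCL test exercised, with toy gradients:

import mlx.core as mx
from mlx.nn.utils import average_gradients

grads = [mx.ones(10) for _ in range(10)]

# One fused all_sum for everything below the 32 MiB default bucket.
avg = average_gradients(grads)

# Or several all_sum calls with a small bucket, cast to float16 on the
# wire and back to the original dtype afterwards.
avg = average_gradients(
    grads,
    all_reduce_size=2 * 50,  # 100 bytes per fused reduction
    communication_type=mx.float16,
)
mx.eval(avg)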

@@ -6,7 +6,6 @@ auditwheel repair dist/* \
   --exclude libnvrtc* \
   --exclude libcuda* \
   --exclude libcudnn* \
-  --exclude libnccl* \
   -w wheel_tmp
 
 
@@ -18,7 +17,7 @@ rm "${repaired_wheel}"
 mlx_so="mlx/lib/libmlx.so"
 rpath=$(patchelf --print-rpath "${mlx_so}")
 base="\$ORIGIN/../../nvidia"
-rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib:${base}/nccl/lib
+rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
 patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
 python ../python/scripts/repair_record.py ${mlx_so}
 

@@ -79,7 +79,7 @@ void init_distributed(nb::module_& parent_module) {
           in case ``mx.distributed.is_available()`` returns False otherwise
           it throws a runtime error. Default: ``False``
         backend (str, optional): Which distributed backend to initialize.
-          Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
+          Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all
           available backends are tried and the first one that succeeds
           becomes the global group which will be returned in subsequent
           calls. Default: ``any``

@@ -1,284 +0,0 @@
-# Copyright © 2024 Apple Inc.
-import mlx.core as mx
-import mlx.nn as nn
-import mlx_tests
-from mlx.nn.layers.distributed import shard_inplace, shard_linear
-from mlx.nn.utils import average_gradients
-
-
-class TestNCCLDistributed(mlx_tests.MLXTestCase):
-    @classmethod
-    def setUpClass(cls):
-        world = mx.distributed.init(strict=True, backend="nccl")
-        rank = world.rank()
-        mx.set_default_device(mx.Device(mx.gpu, rank % 8))
-
-    def test_all_reduce(self):
-        world = mx.distributed.init()
-        dtypes = [
-            (mx.int8, 0),
-            (mx.uint8, 0),
-            (mx.int32, 0),
-            (mx.uint32, 0),
-            (mx.float32, 1e-6),
-            (mx.float16, 5e-3),
-            (mx.bfloat16, 1e-1),
-        ]
-        sizes = [
-            (7,),
-            (10,),
-            (1024,),
-            (1024, 1024),
-        ]
-        key = mx.random.key(0)
-
-        for dt, rtol in dtypes:
-            for sh in sizes:
-                x = (
-                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
-                ).astype(dt)
-
-                # All sum
-                y = mx.distributed.all_sum(x[world.rank()])
-                z = x.sum(0)
-                maxrelerror = (y - z).abs()
-                if rtol > 0:
-                    maxrelerror /= z.abs()
-                maxrelerror = maxrelerror.max()
-                self.assertLessEqual(maxrelerror, rtol)
-
-    def test_average_gradients(self):
-        original_all_sum = mx.distributed.all_sum
-        n_calls = 0
-        xtype = None
-
-        def new_all_sum(x, **kwargs):
-            nonlocal n_calls
-            nonlocal xtype
-
-            n_calls += 1
-            if xtype is not None:
-                self.assertEqual(xtype, x.dtype)
-
-            return original_all_sum(x, **kwargs)
-
-        mx.distributed.all_sum = new_all_sum
-        try:
-            grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads, stream=mx.gpu)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 1)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 10)
-
-            n_calls = 0
-            xtype = mx.float16
-            new_grads = average_gradients(
-                grads,
-                all_reduce_size=2 * 50,
-                communication_type=mx.float16,
-                stream=mx.gpu,
-            )
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-        finally:
-            mx.distributed.all_sum = original_all_sum
-
-    def test_donation(self):
-        x = mx.random.normal((1024,))
-        mx.eval(x)
-        mx.synchronize()
-
-        mx.reset_peak_memory()
-        scale = mx.array(2.0)
-        y = mx.distributed.all_sum(x)
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_only = mx.get_peak_memory()
-        y = mx.distributed.all_sum(x) * scale
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_with_binary = mx.get_peak_memory()
-
-        self.assertEqual(all_sum_only, all_sum_with_binary)
-
-    def test_shard_linear(self):
-        # Seed the prng to have the same inputs and weights generated everywhere
-        mx.random.seed(0xF0F0F0F0)
-
-        # Prepare inputs
-        world = mx.distributed.init()
-        part = (
-            slice(None),
-            slice(
-                world.rank() * 1024 // world.size(),
-                (world.rank() + 1) * 1024 // world.size(),
-            ),
-        )
-        x = mx.random.normal((4, 1024))
-
-        # Create and shard some linear layers
-        lin = nn.Linear(1024, 1024, bias=True)
-        slin1 = shard_linear(lin, "all-to-sharded")
-        slin2 = shard_linear(lin, "sharded-to-all")
-        y = lin(x)
-        y1 = slin1(x)
-        y2 = slin2(x[part])
-        self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
-        self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
-
-        # Check the backward works as expected
-        def dummy_loss(model, x, y):
-            return (model(x) * y).sum()
-
-        mod = nn.Sequential(
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-        )
-        smod = nn.Sequential(
-            shard_linear(mod.layers[0], "all-to-sharded"),
-            shard_linear(mod.layers[1], "sharded-to-all"),
-            shard_linear(mod.layers[2], "all-to-sharded"),
-            shard_linear(mod.layers[3], "sharded-to-all"),
-        )
-
-        grad1 = nn.value_and_grad(mod, dummy_loss)
-        grad2 = nn.value_and_grad(smod, dummy_loss)
-
-        x = mx.random.normal((4, 128))
-        y = mx.random.normal((4, 128))
-
-        l1, g1 = grad1(mod, x, y)
-        l2, g2 = grad2(smod, x, y)
-        mx.eval(l1, g1, l2, g2)
-
-        part = slice(
-            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
-        )
-
-        self.assertTrue(mx.allclose(l1, l2))
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][0]["weight"][part],
-                g2["layers"][0]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][2]["weight"][part],
-                g2["layers"][2]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][1]["weight"][:, part],
-                g2["layers"][1]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][3]["weight"][:, part],
-                g2["layers"][3]["weight"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][0]["bias"][part],
-                g2["layers"][0]["bias"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][2]["bias"][part],
-                g2["layers"][2]["bias"],
-                atol=1e-6,
-                rtol=1e-4,
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
-            )
-        )
-        self.assertTrue(
-            mx.allclose(
-                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
-            )
-        )
-
-    def test_shard_predicate(self):
-        mx.random.seed(0xF0F0F0F0)
-
-        class MyConv(nn.Module):
-            def __init__(self, *args, **kwargs):
-                super().__init__()
-                self.aggregate = kwargs.pop("aggregate", False)
-                self.conv = nn.Conv2d(*args, **kwargs)
-
-            def __call__(self, x):
-                x = self.conv(x)
-                if self.aggregate:
-                    x = mx.distributed.all_sum(x)
-                return x
-
-        def sharding(path, weight):
-            parts = path.split(".")
-            even = int(parts[1]) % 2 == 0
-            if even:
-                return 0
-            else:
-                return -1 if parts[-1] != "bias" else None
-
-        mod = nn.Sequential(
-            MyConv(3, 128, kernel_size=3),
-            MyConv(128, 128, kernel_size=3),
-            MyConv(128, 128, kernel_size=3),
-            MyConv(128, 3, kernel_size=3),
-        )
-        smod = nn.Sequential(
-            MyConv(3, 128, kernel_size=3),
-            MyConv(128, 128, kernel_size=3, aggregate=True),
-            MyConv(128, 128, kernel_size=3),
-            MyConv(128, 3, kernel_size=3, aggregate=True),
-        )
-        smod.update(mod.parameters())
-        shard_inplace(smod, sharding)
-
-        x = mx.random.normal((4, 16, 16, 3))
-        y1 = mod(x)
-        y2 = smod(x)
-        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
-
-
-if __name__ == "__main__":
-    mlx_tests.MLXTestRunner()
@@ -3068,13 +3068,6 @@ class TestOps(mlx_tests.MLXTestCase):
         d = mx.where(c, a[1:], b)
         self.assertTrue(mx.all(d == 1.0))
 
-    def test_integer_power(self):
-        x = mx.power(2, mx.array([8, 8, 8, 8, 8, 8, 8, 8]))
-        self.assertTrue(mx.all(x == 256))
-
-        # Doesn't hang
-        x = mx.power(2, -1)
-
 
 class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):
setup.py
@@ -297,7 +297,6 @@ if __name__ == "__main__":
             "nvidia-cublas-cu12==12.9.*",
             "nvidia-cuda-nvrtc-cu12==12.9.*",
             "nvidia-cudnn-cu12==9.*",
-            "nvidia-nccl-cu12",
         ]
     else:
         name = "mlx-cpu"