Compare commits


12 Commits

Author SHA1 Message Date
Angelos Katharopoulos  4987e7615a  Improve the cutlass gemm  2025-08-25 18:18:19 -07:00
Angelos Katharopoulos  e1303f6160  Reset cutlass gemm to working state again  2025-08-21 01:29:43 -07:00
Angelos Katharopoulos  cf5eef095d  tmp  2025-08-20 23:51:25 -07:00
Angelos Katharopoulos  395d582719  Add a cutlass gemm  2025-08-20 23:51:25 -07:00
Angelos Katharopoulos  05583bcd10  More pipelining for the sm_80 gemm  2025-08-20 23:51:25 -07:00
Angelos Katharopoulos  6fce01593a  Improve gemm  2025-08-20 23:51:25 -07:00
Angelos Katharopoulos  97afe40b7b  Remove duplicate register tile  2025-08-20 23:51:25 -07:00
Angelos Katharopoulos  f70c62d69c  Simple gemm example  2025-08-20 23:51:25 -07:00
Angelos Katharopoulos  0c5fc63a36  Fix docs omission (#2524)  2025-08-20 17:56:06 -07:00
Angelos Katharopoulos  e397177f6e  Custom cuda kernel (#2517)  2025-08-20 17:20:22 -07:00
Cheng  f4c8888cbe  [CUDA] Fix stride of singleton dims before passing to cuDNN (#2521)  2025-08-21 08:55:26 +09:00
Angelos Katharopoulos  25c1e03205  Fix overflow in large filter small channels (#2520)  2025-08-20 08:03:29 -07:00
47 changed files with 1869 additions and 1213 deletions

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
+ python/cuda
python/memory_management
python/nn
python/optimizers

View File

@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
- apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
+ apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag

docs/src/python/cuda.rst Normal file
View File

@@ -0,0 +1,9 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@@ -13,3 +13,4 @@ Fast
rope
scaled_dot_product_attention
metal_kernel
+ cuda_kernel

View File

@@ -20,12 +20,14 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
+ ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cutlass_gemm.cu
+ ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -88,6 +90,9 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
+ # Keep ptx around for inspection
+ target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>")
# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
@@ -173,3 +178,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
+ # Fetch and make available cutlass
+ FetchContent_Declare(
+   cutlass
+   GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
+   GIT_TAG v4.1.0)
+ FetchContent_Populate(cutlass)
+ target_include_directories(
+   mlx PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)

View File

@@ -267,7 +267,8 @@ void Compiled::eval_gpu(
}
}
- return std::make_pair(std::move(builder.os), std::move(kernel_names));
+ return std::make_tuple(
+     false, std::move(builder.os), std::move(kernel_names));
});
// Collapse contiguous dims to route to a faster kernel if possible. Also

View File

@@ -23,6 +23,24 @@ inline cudnn_frontend::Tensor build_cudnn_tensor(
.build();
}
+ // In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
+ // whether a tensor is contiguous is determined with:
+ // shape[dim] == shape[dim + 1] * strides[dim + 1]
+ // So a contiguous array with singleton dims in MLX may be mistakenly treated
+ // as strided in cuDNN, and we work around it by normalizing the strides.
+ Strides normalized_strides(const array& x) {
+   if (!x.flags().row_contiguous || x.ndim() < 2) {
+     return x.strides();
+   }
+   Strides strides = x.strides();
+   for (int i = x.ndim() - 2; i >= 0; --i) {
+     if (x.shape(i) == 1) {
+       strides[i] = x.shape(i + 1) * strides[i + 1];
+     }
+   }
+   return strides;
+ }
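// Worked example (illustrative only, not part of the change): a row-contiguous
// array with shape (1, 4, 1, 8) may carry arbitrary strides such as (7, 8, 3, 1)
// on its singleton axes. The loop above rewrites them to (32, 8, 8, 1), so that
// strides[d] == shape[d + 1] * strides[d + 1] holds for every leading dim and
// cuDNN classifies the tensor as fully packed.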
// Return the shape and strides after transposing from NHWC to NCHW.
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
assert(shape.size() >= 3);
@@ -33,8 +51,9 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
return std::make_tuple(std::move(shape), std::move(strides));
}
- auto nhwc_to_nchw(const array& x) {
-   return nhwc_to_nchw(convert_vector<int64_t>(x.shape()), x.strides());
+ inline auto nhwc_to_nchw(const array& x) {
+   return nhwc_to_nchw(
+       convert_vector<int64_t>(x.shape()), normalized_strides(x));
}
// Return available engines for a |op_graph|.
@@ -140,7 +159,7 @@ bool prepare_cudnn_plan(
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
- return build_cudnn_tensor(id, x, shape, x.strides());
+ return build_cudnn_tensor(id, x, shape, normalized_strides(x));
}
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
@@ -160,7 +179,8 @@ cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 2) {
- int64_t s = x.strides(0);
+ int64_t s =
+     x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
return build_cudnn_tensor(id, x, shape, strides);

View File

@@ -0,0 +1,379 @@
// Copyright © 2025 Apple Inc.
#include <iostream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::fast {
namespace {
constexpr const char* default_header = R"(
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#define inf cuda::std::numeric_limits<float>::infinity()
)";
std::string template_arguments_hash(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
if (template_args.empty()) {
return "";
}
std::string hash;
hash.reserve(512);
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
hash += fmt::format("_{}", std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
hash += (std::get<bool>(arg)) ? "_t" : "_f";
} else if (std::holds_alternative<Dtype>(arg)) {
hash += "_";
hash += get_type_string(std::get<Dtype>(arg));
}
}
return hash;
}
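// Illustrative example (hypothetical argument names, not part of the file):
// template_args {{"BLOCK", 32}, {"vectorized", true}, {"T", float32}} hashes
// to "_32_t_float32", which is appended to the kernel name below, e.g.
// "custom_kernel_myop_32_t_float32".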
std::string build_kernel(
const std::string& func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
kernel_source += header;
kernel_source +=
"namespace mlx::core::cu {\n\n"
"namespace cg = cooperative_groups;\n\n";
kernel_source += "__global__ void ";
kernel_source += func_name;
kernel_source += "(\n";
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
kernel_source += " const ";
kernel_source += dtype_to_cuda_type(arr.dtype());
kernel_source += "* ";
kernel_source += name;
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " ";
kernel_source += dtype_to_cuda_type(dtype);
kernel_source += "* ";
kernel_source += name;
if (i < output_names.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
}
// Set compile time constants
if (!template_args.empty()) {
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
kernel_source +=
fmt::format(" constexpr int {} = {};\n", name, std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
kernel_source += fmt::format(
" constexpr bool {} = {};\n", name, std::get<bool>(arg));
} else {
kernel_source += fmt::format(
" using {} = {};\n",
name,
dtype_to_cuda_type(std::get<Dtype>(arg)));
}
}
kernel_source += "\n";
}
kernel_source += source;
kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
return kernel_source;
}
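// Sketch of the generated source (illustrative only; the names "inp", "out"
// and the template argument "N" are hypothetical). For one float input whose
// "_shape" is referenced in the body and one float output, build_kernel
// produces roughly:
//
//   #include "mlx/backend/cuda/device/utils.cuh"
//   #include <cooperative_groups.h>
//   #define inf cuda::std::numeric_limits<float>::infinity()
//   namespace mlx::core::cu {
//   namespace cg = cooperative_groups;
//   __global__ void custom_kernel_myop_4(
//       const float* inp,
//       const __grid_constant__ Shape inp_shape,
//       float* out) {
//     constexpr int N = 4;
//     <user source>
//   }
//   } // namespace mlx::core::cu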
} // namespace
CustomKernelFunction cuda_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
int shared_memory) {
if (output_names.empty()) {
throw std::invalid_argument(
"[custom_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
return [=, shape_infos = std::move(shape_infos)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
}
std::string kernel_name =
"custom_kernel_" + name + template_arguments_hash(template_args);
std::string kernel_source = build_kernel(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
shape_infos);
if (verbose) {
std::cout << "Generated source code for `" << kernel_name
<< "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
std::vector<ScalarArg>{},
false,
shared_memory),
std::move(inputs));
};
}
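// A minimal usage sketch of the functor returned above (illustrative, not part
// of the change; the kernel body, names and shapes are hypothetical, and it
// assumes the MLX C++ API mlx::core::ones / mlx::core::float32):
//
//   auto fn = mlx::core::fast::cuda_kernel(
//       "scale2", {"inp"}, {"out"},
//       R"(
//         auto elem = blockIdx.x * blockDim.x + threadIdx.x;
//         out[elem] = 2.0f * inp[elem];
//       )",
//       /* header */ "", /* ensure_row_contiguous */ true,
//       /* shared_memory */ 0);
//   auto x = mlx::core::ones({1024}, mlx::core::float32);
//   auto outs = fn(
//       {x}, /* output_shapes */ {{1024}},
//       /* output_dtypes */ {mlx::core::float32},
//       /* grid */ {1024, 1, 1}, /* threadgroup */ {256, 1, 1});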
std::vector<array> precompiled_cuda_kernel(
const std::string& name,
const std::string& compiled_source,
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<ScalarArg>& scalars,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
std::make_shared<CustomKernel>(
to_stream(s),
name,
compiled_source,
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
scalars,
true,
shared_memory),
inputs);
}
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
std::vector<array> copies;
// Allocate and initialize the output arrays
for (auto& out : outputs) {
if (init_value_) {
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
// Create the input arrays and copy if needed
auto check_input = [&copies, &s, this](const array& x) -> const array {
bool no_copy = x.flags().row_contiguous;
if (!ensure_row_contiguous_ || no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
std::vector<array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
// Compile the custom kernel
std::string kernel_name =
(is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
cu::JitModule& mod = cu::get_jit_module(
s.device,
name_,
[&]() {
return std::make_tuple(
is_precompiled_, source_, std::vector{kernel_name});
},
false);
// Make the arguments
cu::KernelArgs args;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
args.append<int32_t>(in.ndim());
}
}
for (auto& out : outputs) {
args.append(out);
}
for (auto& s : scalar_arguments_) {
if (std::holds_alternative<bool>(s)) {
args.append(std::get<bool>(s));
} else if (std::holds_alternative<int>(s)) {
args.append(std::get<int>(s));
} else if (std::holds_alternative<float>(s)) {
args.append(std::get<float>(s));
}
}
// Make the grid
const auto [tx, ty, tz] = threadgroup_;
const auto [gx, gy, gz] = grid_;
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
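// Note on the launch math above (hypothetical numbers for illustration): as
// with the Metal custom-kernel API, grid_ counts threads, not thread blocks.
// For grid_ = (4096, 1, 1) and threadgroup_ = (256, 1, 1) this yields
// block = (256, 1, 1) and grid = (16, 1, 1) thread blocks.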
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}
for (const auto& out : outputs) {
encoder.set_output_array(out);
}
for (const auto& t : copies) {
encoder.add_temporary(t);
}
auto kernel =
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
if (smem > 0 && smem > 48000) {
cuFuncSetAttribute(
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
}
});
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
}
} // namespace mlx::core::fast

View File

@@ -1,51 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core {
namespace distributed {
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& input = inputs[0];
auto& output = outputs[0];
auto& encoder = cu::get_command_encoder(stream());
if (input.is_donatable()) {
output.copy_shared_buffer(input);
} else {
output.set_data(allocator::malloc(output.nbytes()));
}
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), input, output, s);
break;
case Max:
distributed::detail::all_max(group(), input, output, s);
break;
case Min:
distributed::detail::all_min(group(), input, output, s);
break;
default:
throw std::runtime_error(
"Only all reduce sum, max, and min are supported.");
}
}
} // namespace distributed
} // namespace mlx::core

View File

@@ -0,0 +1,396 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cute/tensor.hpp>
#include <cutlass/arch/arch.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <iostream>
namespace mlx::core::cu {
namespace {
using namespace cute;
using bf16 = cute::bfloat16_t;
template <typename Kernel>
void configure_matmul(Kernel kernel, int smem_size) {
static bool initialized = false;
if (!initialized) {
initialized = true;
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
}
template <bool transpose, typename Tiler>
constexpr int get_feature_size(Tiler smem) {
int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
return (feature_size >= 64) ? 64 : feature_size;
}
constexpr int constexpr_log2(int x) {
return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
}
template <int feature_size, int itemsize, int copy_bits>
constexpr int get_swizzle_bits() {
constexpr int swizzle_bits =
constexpr_log2(feature_size * itemsize / copy_bits);
return (swizzle_bits > 3) ? 3 : swizzle_bits;
}
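// Worked example (illustrative): for bf16 tiles with a 64-element leading
// dimension, itemsize = 16 bits and copy_bits = 128, the ratio is
// 64 * 16 / 128 = 8, so constexpr_log2 returns 3 and the layouts below use
// Swizzle<3, 3, 3>.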
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_smem_layout(Tiler smem) {
constexpr int feature_size = get_feature_size<transpose>(smem);
constexpr int swizzle_bits =
get_swizzle_bits<feature_size, itemsize, copy_bits>();
using F = Int<feature_size>;
using BaseLayout = std::conditional_t<
transpose,
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
auto swizzled =
make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
return tile_to_shape(swizzled, smem);
}
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_result_smem_layout(Tiler smem) {
constexpr int feature_size = get_feature_size<transpose>(smem);
constexpr int swizzle_bits =
get_swizzle_bits<feature_size, itemsize, copy_bits>();
using F = Int<feature_size>;
using BaseLayout = std::conditional_t<
transpose,
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
auto swizzled = make_composed_layout(
Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});
return tile_to_shape(swizzled, smem);
}
template <
int num_threads,
int itemsize,
bool transpose,
int copy_bits,
typename Copier,
typename Tiler>
constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
constexpr int num_elements = copy_bits / itemsize;
constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
constexpr int copies_per_feature = feature_size / num_elements;
using E = Int<num_elements>;
using C = Int<copies_per_feature>;
using R = Int<num_threads / copies_per_feature>;
using ThreadLayout = std::conditional_t<
transpose,
Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
using ValueLayout = std::conditional_t<
transpose,
Layout<cute::Shape<E, _1>>,
Layout<cute::Shape<_1, E>>>;
return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
}
template <int rasterization_factor>
__device__ inline int2 raster_tile(int x, int y) {
return {
x / rasterization_factor,
(x % rasterization_factor) + y * rasterization_factor};
}
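// Worked example (illustrative): with rasterization_factor = 8, block
// (x = 10, y = 3) maps to tile (10 / 8, 10 % 8 + 3 * 8) = (1, 26).
// Consecutive x values therefore share the same row of output tiles and sweep
// an 8-tile band of columns, which improves L2 reuse of the A and B tiles.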
template <
typename T,
typename SLayoutA,
typename SLayoutB,
typename SLayoutC,
typename CopyA,
typename CopyB,
typename CopyC,
typename MMA,
int rasterization_factor>
__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
const T* __restrict__ A,
const T* __restrict__ B,
T* __restrict__ C,
SLayoutA SA,
SLayoutB SB,
SLayoutC SC,
CopyA copy_a,
CopyB copy_b,
CopyC copy_c,
MMA mma,
int M,
int N,
int K) {
constexpr auto BM = size<0>(SA);
constexpr auto BN = size<0>(SB);
constexpr auto BK = size<1>(SA);
constexpr auto PIPE = size<2>(SA);
const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
const int blocks_m = ceil_div(M, BM);
const int blocks_n = ceil_div(N, BN);
// Exit early if the tile is OOB
if (tile.x >= blocks_m || tile.y >= blocks_n) {
return;
}
// Make the full tensors
Tensor full_A =
make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
Tensor full_B =
make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
Tensor full_C =
make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
// Partition the tensors into tiles and select the ones for this threadblock
Tensor local_A =
local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
Tensor local_B =
local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
Tensor local_C =
local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
// Make shared memory tensors
extern __shared__ char shared_memory[];
T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
T* shared_B_ptr =
reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
// Get the copies that correspond to this thread
auto thread_copy_a = copy_a.get_slice(threadIdx.x);
Tensor local_A_src = thread_copy_a.partition_S(local_A);
Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
auto thread_copy_b = copy_b.get_slice(threadIdx.x);
Tensor local_B_src = thread_copy_a.partition_S(local_B);
Tensor local_B_dst = thread_copy_a.partition_D(shared_B);
auto thread_copy_c = copy_c.get_slice(threadIdx.x);
Tensor local_C_src = thread_copy_c.partition_S(shared_C);
Tensor local_C_dst = thread_copy_c.partition_D(local_C);
// Start fetches
int k_tile_count = size<2>(local_A);
int k_tile_next = 0;
CUTE_UNROLL
for (int k = 0; k < PIPE - 1; k++) {
copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
cp_async_fence();
k_tile_count--;
k_tile_next += (k_tile_count > 0);
}
// Get the MMA that corresponds to this thread and allocate registers
auto thread_mma = mma.get_slice(threadIdx.x);
Tensor mma_shared_A = thread_mma.partition_A(shared_A);
Tensor mma_shared_B = thread_mma.partition_B(shared_B);
Tensor mma_shared_C = thread_mma.partition_C(shared_C);
Tensor mma_global_C = thread_mma.partition_C(local_C);
Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
clear(mma_frag_C);
// Make shared to register copies
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
constexpr auto RPIPE = size<2>(mma_shared_A);
int smem_read = 0;
int smem_write = PIPE - 1;
Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
// Start the register pipeline
if constexpr (RPIPE > 1) {
cp_async_wait<PIPE - 2>();
__syncthreads();
copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
}
CUTE_NO_UNROLL
while (k_tile_count > -(PIPE - 1)) {
CUTE_UNROLL
for (int k_block = 0; k_block < RPIPE; k_block++) {
if (k_block == RPIPE - 1) {
mma_A_src_p = mma_A_src(_, _, _, smem_read);
mma_B_src_p = mma_B_src(_, _, _, smem_read);
cp_async_wait<PIPE - 2>();
__syncthreads();
}
// Load the next register tile
auto k_block_next = (k_block + 1) % RPIPE;
copy(
s2r_copy_a,
mma_A_src_p(_, _, k_block_next),
mma_A_dst(_, _, k_block_next));
copy(
s2r_copy_b,
mma_B_src_p(_, _, k_block_next),
mma_B_dst(_, _, k_block_next));
if (k_block == 0) {
copy(
copy_a,
local_A_src(_, _, _, k_tile_next),
local_A_dst(_, _, _, smem_write));
copy(
copy_b,
local_B_src(_, _, _, k_tile_next),
local_B_dst(_, _, _, smem_write));
cp_async_fence();
k_tile_count--;
k_tile_next += (k_tile_count > 0);
smem_write = smem_read;
smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
}
gemm(
mma,
mma_frag_A(_, _, k_block),
mma_frag_B(_, _, k_block),
mma_frag_C);
}
}
copy(mma_frag_C, mma_shared_C);
__syncthreads();
copy(copy_c, local_C_src, local_C_dst);
// if (threadIdx.x == 0) {
// print("fC: "); print(mma_frag_C); print("\n");
// print("sC: "); print(mma_shared_C); print("\n");
// print("dC: "); print(local_C_dst); print("\n");
//
// print(s2r_atom_a); print("\n");
// }
}
} // namespace
void cutlass_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc) {
enc.set_input_array(a);
enc.set_input_array(b);
enc.set_output_array(out);
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
using namespace cute;
// Tile definitions
auto BM = Int<128>{};
auto BN = Int<128>{};
auto BK = Int<64>{};
auto BP = Int<3>{};
auto GM = Int<8>{};
// Thread definitions
using TM = Int<2>;
using TN = Int<2>;
using TK = Int<1>;
constexpr int num_threads = TM::value * TN::value * 32;
auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
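// Illustrative arithmetic: each of SA and SB covers 128 x 64 x 3 bf16
// elements, so smem_size comes out to roughly 2 * 128 * 64 * 3 * 2 bytes,
// about 96 KiB, well above the 48 KiB default; that is why configure_matmul
// raises cudaFuncAttributeMaxDynamicSharedMemorySize before the first launch.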
auto async_copy_op =
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
async_copy_op, make_shape(BM, BK));
auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
async_copy_op, make_shape(BN, BK));
auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
sync_copy_op, make_shape(BM, BN));
auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
auto tiled_mma = make_tiled_mma(
mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
auto kernel = matmul_kernel<
bf16,
decltype(SA),
decltype(SB),
decltype(SC),
decltype(tiled_copy_a),
decltype(tiled_copy_b),
decltype(tiled_copy_c),
decltype(tiled_mma),
GM.value>;
configure_matmul(kernel, smem_size);
dim3 block(size(tiled_mma));
dim3 grid(
size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
enc.add_kernel_node(
kernel,
grid,
block,
smem_size,
a.data<bf16>(),
b.data<bf16>(),
out.data<bf16>(),
SA,
SB,
SC,
tiled_copy_a,
tiled_copy_b,
tiled_copy_c,
tiled_mma,
M,
N,
K);
} else {
throw std::runtime_error("Only bfloat16 supported");
}
});
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
void cutlass_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc);
}

View File

@@ -0,0 +1,69 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/dtype_utils.h"
#include <iostream>
namespace mlx::core::cu {
namespace {
template <typename Kernel>
static void configure_smem(Kernel kernel, int SM) {
static bool done = false;
if (done) {
return;
}
std::cout << "configuring" << std::endl;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM);
cudaFuncSetAttribute(
kernel,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxShared);
done = true;
}
} // namespace
void simple_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc) {
enc.set_input_array(a);
enc.set_input_array(b);
enc.set_output_array(out);
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
constexpr int PIPE = 3;
constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK);
constexpr int WM = 2;
constexpr int WN = 4;
auto kernel = ab_t_aligned<DataType, BM, BN, BK, WM, WN, PIPE>;
configure_smem(kernel, SM);
dim3 grid(N / BN, M / BM);
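// Illustrative arithmetic for a 2-byte dtype such as bf16: SM =
// 3 * 2 * (128 * 32 + 128 * 32) = 49152 bytes of dynamic shared memory, and an
// M = N = K = 2048 problem launches a 16 x 16 grid of 256-thread blocks
// (WM * WN * WARP_SIZE). The divisibility checks in matmul.cpp guarantee that
// the N / BN and M / BM divisions are exact.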
enc.add_kernel_node(
kernel,
grid,
WM * WN * WARP_SIZE,
SM,
a.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
N,
K);
});
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
void simple_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc);
}

View File

@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t")); large ? "int64_t" : "int32_t"));
} }
} }
return std::make_pair(jit_source_gather, std::move(kernel_names)); return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
}); });
cu::KernelArgs args; cu::KernelArgs args;
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t")); large ? "int64_t" : "int32_t"));
} }
} }
return std::make_pair(jit_source_scatter, std::move(kernel_names)); return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
}); });
cu::KernelArgs args; cu::KernelArgs args;
@@ -268,7 +268,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
} }
return std::make_pair(jit_source_gather_axis, std::move(kernel_names)); return std::make_tuple(
false, jit_source_gather_axis, std::move(kernel_names));
}); });
size_t idx_size_pre = 1; size_t idx_size_pre = 1;
@@ -371,7 +372,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
} }
return std::make_pair(jit_source_scatter_axis, std::move(kernel_names)); return std::make_tuple(
false, jit_source_scatter_axis, std::move(kernel_names));
}); });
size_t idx_size_pre = 1; size_t idx_size_pre = 1;

View File

@@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() {
bool read_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
- std::vector<char>* ptx,
- std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
+ std::string& ptx,
+ std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) {
return false;
}
@@ -117,15 +117,15 @@ bool read_cached_ptx(
if (!ptx_file.good()) {
return false;
}
- ptx->resize(ptx_size);
- ptx_file.read(ptx->data(), ptx_size);
+ ptx.resize(ptx_size);
+ ptx_file.read(ptx.data(), ptx_size);
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::string line;
while (std::getline(txt_file, line)) {
auto tab = line.find('\t');
if (tab != std::string::npos) {
- ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
+ ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
}
}
return true;
@@ -135,7 +135,7 @@ bool read_cached_ptx(
void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
- const std::vector<char>& ptx,
+ const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
if (cache_dir.empty()) {
@@ -217,22 +217,18 @@ constexpr const char* g_headers[] = {
jit_source_utils,
};
- } // namespace
- JitModule::JitModule(
+ void compile(
Device& device,
const std::string& module_name,
- const KernelBuilder& builder) {
- // Check cache.
- std::vector<char> ptx;
- std::vector<std::pair<std::string, std::string>> ptx_kernels;
- if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
- // Create program.
- auto [source_code, kernel_names] = builder();
+ const std::string& source,
+ const std::vector<std::string>& kernel_names,
+ std::string& ptx,
+ std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
+ // Create the program
nvrtcProgram prog;
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
&prog,
- source_code.c_str(),
+ source.c_str(),
(module_name + ".cu").c_str(),
std::size(g_headers),
g_headers,
@@ -286,16 +282,20 @@ JitModule::JitModule(
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
- ptx.resize(ptx_size, 0);
+ ptx.resize(ptx_size);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
- write_cached_ptx(
-     ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
+ void load_module(
+ const std::string& module_name,
+ const std::string& ptx,
+ const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
+ CUmodule& module_,
+ std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
// Load module.
char jit_log[4089] = {};
CUjit_option options[] = {
@@ -312,21 +312,69 @@ JitModule::JitModule(
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
- kernels_[name] = kernel;
+ kernels[name] = std::make_pair(kernel, false);
}
}
+ } // namespace
+ JitModule::JitModule(
+ Device& device,
+ const std::string& module_name,
+ const KernelBuilder& builder,
+ bool use_disk_cache) {
+ // Will hold the actual device executable source code and kernel names
+ std::string ptx;
+ std::vector<std::pair<std::string, std::string>> ptx_kernels;
+ // Try to load them from the file cache
+ if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
+ auto [precompiled, source_code, kernel_names] = builder();
+ // Get the PTX or cubin
+ if (precompiled) {
+ ptx = std::move(source_code);
+ for (auto& name : kernel_names) {
+ ptx_kernels.emplace_back(name, name);
+ }
+ } else {
+ compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
+ }
+ // If requested save them in the file cache for the next launch
+ if (use_disk_cache) {
+ write_cached_ptx(
+ ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
+ }
+ }
+ // Load the module
+ load_module(module_name, ptx, ptx_kernels, module_, kernels_);
+ }
JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
- CUfunction JitModule::get_kernel(const std::string& kernel_name) {
+ CUfunction JitModule::get_kernel(
+ const std::string& kernel_name,
+ std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name);
if (it == kernels_.end()) {
throw std::runtime_error(
fmt::format("There is no kernel named {}.", kernel_name));
}
- return it->second;
+ // If it is the first time we run this kernel then configure it. Do it only
+ // once!
+ if (!it->second.second) {
+ if (configure_kernel) {
+ configure_kernel(it->second.first);
+ }
+ it->second.second = true;
+ }
+ return it->second.first;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
@@ -337,11 +385,12 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
- const KernelBuilder& builder) {
+ const KernelBuilder& builder,
+ bool cache) {
auto& map = get_jit_module_cache();
auto it = map.find(name);
if (it == map.end()) {
- it = map.try_emplace(name, cu::device(device), name, builder).first;
+ it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
}
return it->second;
}

View File

@@ -19,7 +19,8 @@ namespace mlx::core::cu {
class Device;
- using KernelBuilderResult = std::pair<
+ using KernelBuilderResult = std::tuple<
+     /* precompiled */ bool,
/* source code */ std::string,
/* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>;
@@ -63,14 +64,16 @@ struct KernelArgs {
private:
std::vector<void*> args_;
- // The cuLaunchKernel API requires passing pointers to arguments so store
- // temporary values untill kernel is launched.
+ // The cuGraphAddKernelNode API requires passing pointers to arguments so
+ // store temporary values until the node is created.
using Arg = std::variant<
std::monostate,
CUdeviceptr,
+ bool,
int32_t,
uint32_t,
int64_t,
+ float,
SmallVector<const void*>,
SmallVector<int32_t>,
SmallVector<int64_t>>;
@@ -82,16 +85,19 @@ class JitModule {
JitModule(
Device& device,
const std::string& module_name,
- const KernelBuilder& builder);
+ const KernelBuilder& builder,
+ bool cache);
~JitModule();
JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete;
- CUfunction get_kernel(const std::string& kernel_name);
+ CUfunction get_kernel(
+     const std::string& kernel_name,
+     std::function<void(CUfunction)> configure_kernel = nullptr);
private:
CUmodule module_{nullptr};
- std::unordered_map<std::string, CUfunction> kernels_;
+ std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
};
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
@@ -99,6 +105,7 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
- const KernelBuilder& builder);
+ const KernelBuilder& builder,
+ bool use_disk_cache = true);
} // namespace mlx::core::cu

View File

@@ -3,7 +3,9 @@
#include "mlx/backend/common/matmul.h" #include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/gemms/cutlass_gemm.h"
#include "mlx/backend/cuda/gemms/gemv.h" #include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/cuda/gemms/simple_gemm.h"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -11,8 +13,14 @@
#include <numeric>
namespace mlx::core {
namespace {
+ int get_test_gemm() {
+   static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0);
+   return t;
+ }
std::tuple<bool, int64_t, array>
check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
@@ -95,6 +103,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
+ if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
+     b_transposed && batch_count == 1 && get_test_gemm() == 1) {
+   cu::simple_gemm(a, b, out, M, N, K, encoder);
+   return;
+ }
+ if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
+     b_transposed && batch_count == 1 && get_test_gemm() == 2) {
+   cu::cutlass_gemm(a, b, out, M, N, K, encoder);
+   return;
+ }
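// Summary of the gating above (not part of the change): setting the
// environment variable MLX_ENABLE_TEST_GEMM to 1 routes eligible matmuls to
// cu::simple_gemm and 2 routes them to cu::cutlass_gemm; a matmul is eligible
// only when M, N and K are all multiples of 512, A is not transposed, B is
// transposed and there is no batching. Everything else falls through to the
// cuBLASLt path below.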
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(

View File

@@ -1,11 +1,47 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h"
+ #include "mlx/fast.h"
- namespace mlx::core::cu {
+ namespace mlx::core {
+ namespace cu {
bool is_available() {
return false;
}
- } // namespace mlx::core::cu
+ } // namespace cu
+ namespace fast {
+ CustomKernelFunction cuda_kernel(
+ const std::string&,
+ const std::vector<std::string>&,
+ const std::vector<std::string>&,
+ const std::string&,
+ const std::string&,
+ bool,
+ int) {
+ throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
+ }
+ std::vector<array> precompiled_cuda_kernel(
+ const std::string&,
+ const std::string&,
+ const std::vector<array>&,
+ const std::vector<Shape>&,
+ const std::vector<Dtype>&,
+ const std::vector<ScalarArg>&,
+ std::tuple<int, int, int>,
+ std::tuple<int, int, int>,
+ int shared_memory,
+ std::optional<float> init_value,
+ bool ensure_row_contiguous,
+ StreamOrDevice) {
+ throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
+ }
+ } // namespace fast
+ } // namespace mlx::core

View File

@@ -41,11 +41,8 @@ NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
- namespace fast {
- NO_GPU_MULTI(CustomKernel)
- } // namespace fast
namespace distributed {
+ NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)

View File

@@ -4,95 +4,189 @@
namespace mlx::core::cu {
+ template <typename T, int BM, int BN, int BK, int WM, int WN>
+ __device__ inline void gemm_ab_t(
+ RegisterTile<float, BM / WM, BN / WN>& C,
+ SharedTile<T, BM, BK>& As,
+ SharedTile<T, BN, BK>& Bs,
+ RegisterTileLoader<SharedTile<T, BM, BK>>& rloader_a,
+ RegisterTileLoader<SharedTile<T, BN, BK>>& rloader_b) {
+ RegisterTile<T, BM / WM, 16> A[2];
+ RegisterTile<T, BN / WN, 16> B[2];
+ rloader_a.load(A[0], As.base_addr(), 0);
+ rloader_b.load(B[0], Bs.base_addr(), 0);
+ MLX_UNROLL
+ for (int k = 1; k < BK / 16; k++) {
+ rloader_a.load(A[k & 1], As.base_addr(), k);
+ rloader_b.load(B[k & 1], Bs.base_addr(), k);
+ mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
+ }
+ mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
+ }
/**
* An example gemm written with the utils.
*
* Computes A @ B.T when A and B are all aligned with the block sizes.
*/
- template <typename T, int BM, int BN, int BK>
- __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
- constexpr int WARPS_M = 2;
- constexpr int WARPS_N = 2;
- constexpr int NUM_WARPS = WARPS_M * WARPS_N;
- constexpr int WARP_STEP_M = BM / WARPS_M;
- constexpr int WARP_STEP_N = BN / WARPS_N;
+ // template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
+ //__global__ __launch_bounds__(WM * WN * WARP_SIZE, 1)
+ // void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
+ // constexpr int NUM_WARPS = WM * WN;
+ // constexpr int WARP_STEP_M = BM / WM;
+ // constexpr int WARP_STEP_N = BN / WN;
+ //
+ // // Precompute some offsets for each thread
+ // const int warpid = threadIdx.x / 32;
+ // const int laneid = threadIdx.x % 32;
+ // const int wm = warpid / WN;
+ // const int wn = warpid % WN;
+ // const int offset_m = wm * WARP_STEP_M;
+ // const int offset_n = wn * WARP_STEP_N;
+ //
+ // // Allocate shared memory
+ // extern __shared__ char shmem[];
+ // SharedTile<T, BM, BK>(&as)[PIPE] =
+ // *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
+ // SharedTile<T, BN, BK>(&bs)[PIPE] =
+ // *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
+ //
+ // // Move the global pointers to the tile
+ // a += blockIdx.y * BM * K;
+ // b += blockIdx.x * BN * K;
+ // y += blockIdx.y * BM * N + blockIdx.x * BN;
+ //
+ // // Make the loaders to/from SMEM
+ // SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>> sloader_a(a, K);
+ // SharedTileLoader<NUM_WARPS, SharedTile<T, BN, BK>> sloader_b(b, K);
+ // RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
+ // RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
+ //
+ // // Start the SM pipeline
+ // MLX_UNROLL
+ // for (int i = 0; i < PIPE - 1; i++) {
+ // sloader_a.load_async(as[i].base_addr());
+ // sloader_b.load_async(bs[i].base_addr());
+ // cp_async_commit();
+ // sloader_a.next();
+ // sloader_b.next();
+ // }
+ //
+ // // Allocate and zero the MMA accumulator
+ // RegisterTile<float, BM / WM, BN / WN> C;
+ // C.fill(0);
+ //
+ // // Matmul loop
+ // int num_blocks = K / BK;
+ // int sread = 0;
+ // int swrite = PIPE - 1;
+ // for (int i = 0; i < num_blocks; i++) {
+ // cp_async_wait<PIPE - 1>();
+ //
+ // gemm_ab_t<T, BM, BN, BK, WM, WN>(
+ // C, as[sread], bs[sread], rloader_a, rloader_b);
+ //
+ // sloader_a.load_async(as[swrite].base_addr());
+ // sloader_b.load_async(bs[swrite].base_addr());
+ // cp_async_commit();
+ // sloader_a.next(i + PIPE < num_blocks);
+ // sloader_b.next(i + PIPE < num_blocks);
+ //
+ // swrite = sread;
+ // sread = (sread + 1) % PIPE;
+ // }
+ //
+ // C.store_global(y, N, offset_m, offset_n);
+ // }
+ template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
+ __global__ __launch_bounds__(
+ WM* WN* WARP_SIZE,
+ 1) void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
+ constexpr int NUM_WARPS = WM * WN;
+ constexpr int WARP_STEP_M = BM / WM;
+ constexpr int WARP_STEP_N = BN / WN;
// Precompute some offsets for each thread
const int warpid = threadIdx.x / 32;
const int laneid = threadIdx.x % 32;
- const int wm = warpid / WARPS_N;
- const int wn = warpid % WARPS_N;
+ const int wm = warpid / WN;
+ const int wn = warpid % WN;
const int offset_m = wm * WARP_STEP_M;
const int offset_n = wn * WARP_STEP_N;
// Allocate shared memory
extern __shared__ char shmem[];
- SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
- SharedTile<T, BN, BK>(&bs)[2] =
- *(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
- // Allocate registers for the MMA
- RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
- RegisterTile<T, BM / WARPS_M, 16> A;
- RegisterTile<T, BN / WARPS_N, 16> B;
+ SharedTile<T, BM, BK>(&as)[PIPE] =
+ *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
+ SharedTile<T, BN, BK>(&bs)[PIPE] =
+ *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
// Move the global pointers to the tile
a += blockIdx.y * BM * K;
b += blockIdx.x * BN * K;
y += blockIdx.y * BM * N + blockIdx.x * BN;
- // Zero the accumulators
- C.fill(0);
+ // Make the loaders to/from SMEM
+ using sloader = SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>>;
+ constexpr int SSTEP = sloader::STEP_ROWS * sizeof(T) * BK;
+ const int srow = threadIdx.x / sloader::NUM_LOADS_PER_ROW;
+ const int scol =
+ (threadIdx.x % sloader::NUM_LOADS_PER_ROW) * sloader::ELEMENTS_PER_LOAD;
+ a += srow * K + scol;
+ b += srow * K + scol;
+ uint32_t sm_offsets[PIPE][2];
+ MLX_UNROLL
+ for (int s = 0; s < PIPE; s++) {
+ sm_offsets[s][0] = as[s].loc(as[s].base_addr(), srow, scol);
+ sm_offsets[s][1] = bs[s].loc(bs[s].base_addr(), srow, scol);
+ }
+ RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
+ RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
// Start the SM pipeline
- load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
- load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
- cp_async_commit();
- int tic = 0;
- for (int k_block = BK; k_block < K; k_block += BK) {
- load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
- load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
- cp_async_commit();
- cp_async_wait<1>();
- __syncthreads();
MLX_UNROLL
- for (int k = 0; k < BK / 16; k++) {
- A.load(
- as[tic],
- as[tic].base_addr(),
- offset_m + laneid % 16,
- k * 16 + laneid / 16 * 8);
- B.load(
- bs[tic],
- bs[tic].base_addr(),
- offset_n + laneid % 16,
- k * 16 + laneid / 16 * 8);
- mma_t(C, A, B);
- }
- tic ^= 1;
- }
- // Empty the pipeline
- cp_async_wait_all();
- __syncthreads();
+ for (int s = 0; s < PIPE - 1; s++) {
+ MLX_UNROLL
+ for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
+ cp_async<16>(sm_offsets[s][0] + l * SSTEP, a);
+ cp_async<16>(sm_offsets[s][1] + l * SSTEP, b);
+ a += sloader::STEP_ROWS * K;
+ b += sloader::STEP_ROWS * K;
+ }
+ cp_async_commit();
+ }
+ // Allocate and zero the MMA accumulator
+ RegisterTile<float, BM / WM, BN / WN> C;
+ C.fill(0);
+ // Matmul loop
+ int num_blocks = K / BK;
+ int sread = 0;
+ int swrite = PIPE - 1;
+ for (int i = 0; i < num_blocks; i++) {
+ cp_async_wait<PIPE - 1>();
+ gemm_ab_t<T, BM, BN, BK, WM, WN>(
+ C, as[sread], bs[sread], rloader_a, rloader_b);
+ if (false) {
MLX_UNROLL
- for (int k = 0; k < BK / 16; k++) {
- A.load(
- as[tic],
- as[tic].base_addr(),
- offset_m + laneid % 16,
- k * 16 + laneid / 16 * 8);
- B.load(
- bs[tic],
- bs[tic].base_addr(),
- offset_n + laneid % 16,
- k * 16 + laneid / 16 * 8);
- mma_t(C, A, B);
- }
+ for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
+ cp_async<16>(sm_offsets[swrite][0] + l * SSTEP, a);
+ cp_async<16>(sm_offsets[swrite][1] + l * SSTEP, b);
+ a += sloader::STEP_ROWS * K;
+ b += sloader::STEP_ROWS * K;
+ }
+ }
+ cp_async_commit();
+ swrite = sread;
+ sread = (sread + 1) % PIPE;
+ }
C.store_global(y, N, offset_m, offset_n);
View File

@@ -223,59 +223,10 @@ struct RegisterTile {
}
};
/**
* A simple container of multiple Tile16x16.
*
* Provides utility functions for loading and manipulating collections of basic
* tiles.
*/
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
Tile16x16<T> data[TILES_X * TILES_Y];
__device__ inline void fill(T v) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].fill(v);
}
}
}
template <typename Tile>
__device__ inline void
load(Tile& tile, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].load(
tile.loc(base_address, row + i * 16, col + j * 16));
}
}
}
template <typename U>
__device__ inline void store_global(U* x, int N, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global(
x + (row + i * 16) * N + col + j * 16, N);
}
}
}
};
template <typename T, int ROWS_, int COLS_>
struct SharedTile {
+ using value_type = T;
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
@@ -317,23 +268,26 @@ struct SharedTile {
  }
  }

Removed:

  // Return the location of the element at (row, col) using the swizzle.
  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
    if constexpr (swizzle_bytes > 0) {
      static constexpr int swizzle_repeat = swizzle_bytes * 8;
      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
      const int outer_idx = col / subtile_cols;
      const uint32_t addr = ptr +
          sizeof(T) *
              (outer_idx * ROWS * subtile_cols + row * subtile_cols +
               col % subtile_cols);
      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
      return (addr ^ swizzle);
    } else {
      return ptr + sizeof(T) * (row * COLS + col);
    }
  }

Added:

  __device__ static inline uint32_t offset(int row, int col) {
    if constexpr (swizzle_bytes > 0) {
      static constexpr int swizzle_repeat = swizzle_bytes * 8;
      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
      const int outer_idx = col / subtile_cols;
      const uint32_t addr = sizeof(T) *
          (outer_idx * ROWS * subtile_cols + row * subtile_cols +
           col % subtile_cols);
      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
      return (addr ^ swizzle);
    } else {
      return sizeof(T) * (row * COLS + col);
    }
  }

  // Return the location of the element at (row, col) using the swizzle.
  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
    return ptr + offset(row, col);
  }

  // Convenience functions to edit elements going through the swizzle.
  __device__ inline T& operator()(int row, int col) {
    return *ptr(data, row, col);
@@ -364,6 +318,76 @@ struct SharedTile {
} }
}; };
template <int NUM_WARPS, typename Tile>
struct SharedTileLoader {
using T = typename Tile::value_type;
static constexpr int NUM_THREADS = NUM_WARPS * 32;
static constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
static constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
static constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
static constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
static constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const T* x_;
int N_;
uint32_t offset_;
__device__ SharedTileLoader(const T* x, int N) : x_(x), N_(N) {
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x_ += row * N + col * ELEMENTS_PER_LOAD;
offset_ = Tile::offset(row, col * ELEMENTS_PER_LOAD);
}
__device__ inline void load_async(uint32_t base_address) {
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
cp_async<16>(
base_address + offset_ + i * STEP_ROWS * sizeof(T) * Tile::COLS,
x_ + i * STEP_ROWS * N_);
}
}
__device__ inline void next() {
x_ += Tile::COLS;
}
};
template <typename Tile>
struct RegisterTileLoader {
using T = typename Tile::value_type;
uint32_t offset_[Tile::COLS / 16];
__device__ RegisterTileLoader(int offset_row, int laneid) {
const int row = offset_row + (laneid & 15);
const int col = (laneid >> 4) << 3;
MLX_UNROLL
for (int i = 0; i < Tile::COLS / 16; i++) {
offset_[i] = Tile::offset(row, col + i * 16);
}
}
template <typename T, int ROWS, int COLS>
__device__ inline void
load(RegisterTile<T, ROWS, COLS>& x, uint32_t base_address, int col) {
constexpr int TILES_Y = RegisterTile<T, ROWS, COLS>::TILES_Y;
constexpr int TILES_X = RegisterTile<T, ROWS, COLS>::TILES_X;
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
x.data[i * TILES_X + j].load(
base_address + offset_[j + col] + i * 16 * Tile::COLS * sizeof(T));
}
}
}
};
/** /**
* Load the tile from global memory by loading 16 bytes at a time and storing * Load the tile from global memory by loading 16 bytes at a time and storing
* them immediately. * them immediately.
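The offset() swizzle above exists so that the 16-byte loads a warp issues down a column of the shared tile do not all hit the same bank group. Here is a standalone host-side check of the same arithmetic, assuming 16-bit elements, a 64x64 tile, and swizzle_bytes = 128; these parameters are illustrative and not taken from the kernel configuration.

#include <cstdint>
#include <cstdio>

constexpr int ROWS = 64;
constexpr int COLS = 64;           // 64 elements * 2 bytes = 128 bytes per row
constexpr int swizzle_bytes = 128;
using T = uint16_t;                // stands in for __half / bfloat16

uint32_t offset(int row, int col) {
  if constexpr (swizzle_bytes > 0) {
    constexpr int swizzle_repeat = swizzle_bytes * 8;
    constexpr int subtile_cols = swizzle_bytes / sizeof(T);
    const int outer_idx = col / subtile_cols;
    const uint32_t addr = sizeof(T) *
        (outer_idx * ROWS * subtile_cols + row * subtile_cols +
         col % subtile_cols);
    const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
    return addr ^ swizzle;
  } else {
    return sizeof(T) * (row * COLS + col);
  }
}

int main() {
  // A column of 16-byte vectors at col 0: print the starting bank per row.
  for (int row = 0; row < 8; row++) {
    uint32_t off = offset(row, 0);
    printf("row %d -> byte offset %4u, bank %2u\n",
           row, (unsigned)off, (unsigned)((off / 4) % 32));
  }
}

Without the XOR term every row of the first 16-byte column starts at bank 0; with it, rows 0 through 7 start at banks 0, 4, 8, ..., 28, so the eight transactions a warp issues down the column land in disjoint bank groups.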

View File

@@ -21,15 +21,15 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
#if defined(MLX_CUDA_SM_80_ENABLED) #if defined(MLX_CUDA_SM_80_ENABLED)
if constexpr (N == 16) { if constexpr (N == 16) {
asm volatile( asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address), "cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int4*>(x))); "l"(reinterpret_cast<const int4*>(x)));
} else if constexpr (N == 8) { } else if constexpr (N == 8) {
asm volatile( asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address), "cp.async.cg.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int2*>(x))); "l"(reinterpret_cast<const int2*>(x)));
} else if constexpr (N == 4) { } else if constexpr (N == 4) {
asm volatile( asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address), "cp.async.cg.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int*>(x))); "l"(reinterpret_cast<const int*>(x)));
} }
#endif #endif

View File

@@ -172,7 +172,7 @@ std::string write_template(
return template_def.str(); return template_def.str();
} }
MetalKernelFunction metal_kernel( CustomKernelFunction metal_kernel(
const std::string& name, const std::string& name,
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
@@ -316,7 +316,10 @@ MetalKernelFunction metal_kernel(
threadgroup, threadgroup,
shape_infos, shape_infos,
ensure_row_contiguous, ensure_row_contiguous,
init_value), init_value,
std::vector<ScalarArg>{},
false,
0),
std::move(inputs)); std::move(inputs));
}; };
} }

View File

@@ -83,7 +83,7 @@ struct Conv2DInputBlockLoaderSmallChannels {
const constant MLXConvParams<2>* params; const constant MLXConvParams<2>* params;
const constant ImplicitGemmConv2DParams* gemm_params; const constant ImplicitGemmConv2DParams* gemm_params;
short weight_hw; int weight_hw;
const device T* src[n_rows]; const device T* src[n_rows];

View File

@@ -26,15 +26,15 @@ device_info() {
namespace fast { namespace fast {
MetalKernelFunction metal_kernel( CustomKernelFunction metal_kernel(
const std::string&, const std::string&,
const std::vector<std::string>&, const std::vector<std::string>&,
const std::vector<std::string>&, const std::vector<std::string>&,
const std::string&, const std::string&,
const std::string&, const std::string&,
bool ensure_row_contiguous, bool,
bool atomic_outputs) { bool) {
throw std::runtime_error("[metal_kernel] No GPU back-end."); throw std::runtime_error("[metal_kernel] No Metal back-end.");
} }
} // namespace fast } // namespace fast

View File

@@ -6,4 +6,3 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)

View File

@@ -2,11 +2,9 @@
#include <unordered_map> #include <unordered_map>
#include <iostream>
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h" #include "mlx/distributed/ring/ring.h"
namespace mlx::core::distributed { namespace mlx::core::distributed {
@@ -82,7 +80,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail } // namespace detail
bool is_available() { bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available(); return mpi::is_available() || ring::is_available();
} }
int Group::rank() const { int Group::rank() const {
@@ -113,8 +111,6 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = mpi::init(strict); group = mpi::init(strict);
} else if (bk == "ring") { } else if (bk == "ring") {
group = ring::init(strict); group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "any") { } else if (bk == "any") {
group = ring::init(false); group = ring::init(false);
bk_ = "ring"; bk_ = "ring";

View File

@@ -3,6 +3,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include "mlx/array.h" #include "mlx/array.h"
namespace mlx::core::distributed { namespace mlx::core::distributed {

View File

@@ -1,8 +0,0 @@
if(MLX_BUILD_CUDA)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
find_package(NCCL REQUIRED)
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
endif()

View File

@@ -1,354 +0,0 @@
#include <arpa/inet.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <type_traits>
#include "mlx/backend/cuda/device.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/dtype_utils.h"
namespace mlx::core::distributed::nccl {
#define CHECK_CUDA(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
fprintf( \
stderr, \
"CUDA error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString(e)); \
exit(1); \
} \
} while (0)
#define CHECK_NCCL(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
fprintf( \
stderr, \
"NCCL error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
ncclGetErrorString(r)); \
exit(1); \
} \
} while (0)
#define MLX_NCCL_TYPE_LIST(X) \
X(int8_t, ncclChar) \
X(uint8_t, ncclUint8) \
X(int32_t, ncclInt) \
X(uint32_t, ncclUint32) \
X(int64_t, ncclInt64) \
X(uint64_t, ncclUint64) \
X(float16_t, ncclHalf) \
X(bfloat16_t, ncclBfloat16) \
X(float, ncclFloat) \
X(double, ncclDouble)
template <class>
struct nccl_map {
static constexpr bool ok = false; // default: unsupported
};
#define MLX_DEF_NCCL_MAP(T, E) \
template <> \
struct nccl_map<T> { \
static constexpr bool ok = true; \
static constexpr ncclDataType_t value = E; \
};
MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)
#undef MLX_DEF_NCCL_MAP
namespace detail {
template <typename F>
void dispatch_dtype(const array& arr, F&& f) {
dispatch_all_types(arr.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
if constexpr (nccl_map<T>::ok) {
f(type_tag, nccl_map<T>::value);
} else {
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
}
});
}
inline void sendAll(int sock, const void* buf, size_t len) {
const char* ptr = reinterpret_cast<const char*>(buf);
while (len > 0) {
ssize_t sent = send(sock, ptr, len, 0);
if (sent <= 0) {
perror("send");
exit(1);
}
ptr += sent;
len -= sent;
}
}
inline void recvAll(int sock, void* buf, size_t len) {
char* ptr = reinterpret_cast<char*>(buf);
while (len > 0) {
ssize_t rec = recv(sock, ptr, len, 0);
if (rec <= 0) {
perror("recv");
exit(1);
}
ptr += rec;
len -= rec;
}
}
inline void bootstrap_unique_id(
ncclUniqueId& id,
int rank,
int size,
const std::string& initMethod) {
// Parse the init method to extract the host and port
if (initMethod.rfind("tcp://", 0) != 0)
throw;
auto hostport = initMethod.substr(6);
auto colon = hostport.find(':');
std::string host = hostport.substr(0, colon);
int port = std::stoi(hostport.substr(colon + 1));
if (rank == 0) {
// create a unique id on the rank 0
CHECK_NCCL(ncclGetUniqueId(&id));
// create a socket to send the unique id to all other ranks
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockaddr_in serv = {};
serv.sin_family = AF_INET;
serv.sin_addr.s_addr = htonl(INADDR_ANY);
serv.sin_port = htons(port);
int reuse = 1;
// Without this, if rank-0 crashes or restarts process quickly,
// the OS might refuse to let binding to the same port, so reuse
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
std::ostringstream msg;
msg << "[nccl] setsockopt() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
std::ostringstream msg;
msg << "[nccl] bind() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
if (listen(sock, size - 1) < 0) {
std::ostringstream msg;
msg << "[nccl] listen() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
for (int peer = 1; peer < size; ++peer) {
int conn = accept(sock, nullptr, nullptr);
if (conn < 0) {
std::ostringstream msg;
msg << "[nccl] accept() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
sendAll(conn, &id, sizeof(id));
close(conn);
}
close(sock);
} else {
// Here just wanted to make show that rank 0 has enough time to bind
// so we will retry to connect until max attempts
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[nccl] socket() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
hostent* he = gethostbyname(host.c_str());
if (!he) {
throw std::runtime_error("[nccl] lookup failed for host: " + host);
}
sockaddr_in serv = {};
serv.sin_family = AF_INET;
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
serv.sin_port = htons(port);
const int max_retries = 30;
int attempt = 0;
bool connected = false;
for (attempt = 0; attempt < max_retries; ++attempt) {
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
0) {
connected = true;
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
<< attempt + 1 << std::endl;
break;
}
if (errno != ECONNREFUSED) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
if (!connected) {
std::ostringstream msg;
msg << "[Rank " << rank << "] connect() failed after " << attempt
<< " retries: " << strerror(errno);
close(sock);
throw std::runtime_error(msg.str());
}
recvAll(sock, &id, sizeof(id));
close(sock);
}
}
} // namespace detail
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
class NCCLGroup : public GroupImpl {
public:
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
: rank_(worldRank),
size_(worldSize),
comm_(nullptr),
initMethod_(initMethod) {
if (initialized_)
return;
int ndev;
CHECK_CUDA(cudaGetDeviceCount(&ndev));
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
initialized_ = true;
}
~NCCLGroup() {
ncclCommDestroy(comm_);
ncclGroupEnd();
initialized_ = false;
}
int rank() override {
return rank_;
}
int size() override {
return size_;
}
void all_sum(const array& input, array& output, Stream stream) override {
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
});
}
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[nccl] Group split not supported.");
}
void all_gather(const array& input, array& output, Stream stream) override {
throw std::runtime_error(
"[nccl] All gather not supported in NCCL backend.");
}
void send(const array& input, int dst, Stream stream) override {
throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
}
void recv(array& output, int src, Stream stream) override {
throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
}
void all_max(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
}
void all_min(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
}
template <typename T>
void all_reduce_impl(
const array& input,
array& output,
Stream stream,
ncclDataType_t dt,
ncclRedOp_t op) {
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclAllReduce(
input.data<T>(),
output.data<T>(),
input.size(),
dt,
op,
comm_,
encoder.stream()));
}
int rank_, size_;
std::string initMethod_;
ncclUniqueId uniqueId_;
ncclComm_t comm_;
bool initialized_ = false;
};
bool is_available() {
return true;
}
namespace detail {
static std::string get_env_var_or_throw(const char* env_var_name) {
const char* value = std::getenv(env_var_name);
if (value == nullptr) {
std::ostringstream msg;
msg << "[nccl] Required environment variable '" << env_var_name
<< "' is not set. "
<< "Please set it before initializing the distributed backend.";
throw std::runtime_error(msg.str());
}
return std::string(value);
}
} // namespace detail
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
int rank = std::stoi(rank_str);
int n_nodes = std::stoi(n_nodes_str);
std::string init_method = "tcp://" + host + ":" + port;
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
}
} // namespace mlx::core::distributed::nccl

View File

@@ -1,12 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::nccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::nccl

View File

@@ -1,20 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/nccl/nccl.h"
namespace mlx::core::distributed::nccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available() {
return false;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) {
throw std::runtime_error("Cannot initialize nccl distributed backend.");
}
return nullptr;
}
} // namespace mlx::core::distributed::nccl

View File

@@ -2,20 +2,9 @@
#include <sstream> #include <sstream>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/distributed/ops.h" #include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h" #include "mlx/distributed/primitives.h"
inline mlx::core::Device get_device() {
if (mlx::core::metal::is_available()) {
return mlx::core::Device::cpu;
} else if (mlx::core::cu::is_available()) {
return mlx::core::Device::gpu;
}
throw std::runtime_error("No available device for distributed operations.");
}
namespace mlx::core::distributed { namespace mlx::core::distributed {
namespace { namespace {
@@ -35,7 +24,6 @@ array all_sum(
std::optional<Group> group_ /* = std::nullopt */, std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto group = to_group(group_); auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) { if (group.size() == 1) {
return x; return x;
@@ -43,7 +31,8 @@ array all_sum(
return array( return array(
x.shape(), x.shape(),
x.dtype(), x.dtype(),
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Sum), std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Sum),
{x}); {x});
} }
@@ -52,7 +41,6 @@ array all_max(
std::optional<Group> group_ /* = std::nullopt */, std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto group = to_group(group_); auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) { if (group.size() == 1) {
return x; return x;
@@ -60,7 +48,8 @@ array all_max(
return array( return array(
x.shape(), x.shape(),
x.dtype(), x.dtype(),
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Max), std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Max),
{x}); {x});
} }
@@ -69,7 +58,6 @@ array all_min(
std::optional<Group> group_ /* = std::nullopt */, std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto group = to_group(group_); auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) { if (group.size() == 1) {
return x; return x;
@@ -77,7 +65,8 @@ array all_min(
return array( return array(
x.shape(), x.shape(),
x.dtype(), x.dtype(),
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Min), std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Min),
{x}); {x});
} }
@@ -86,7 +75,6 @@ array all_gather(
std::optional<Group> group_ /* = std::nullopt */, std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto group = to_group(group_); auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) { if (group.size() == 1) {
return x; return x;
@@ -101,7 +89,7 @@ array all_gather(
return array( return array(
std::move(result_shape), std::move(result_shape),
x.dtype(), x.dtype(),
std::make_shared<AllGather>(to_stream(s, dev), group), std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
{x}); {x});
} }
@@ -111,7 +99,6 @@ array send(
std::optional<Group> group_ /* = std::nullopt */, std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto group = to_group(group_); auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) { if (group.size() == 1) {
throw std::invalid_argument("Cannot send to a singleton group"); throw std::invalid_argument("Cannot send to a singleton group");
@@ -127,7 +114,7 @@ array send(
return array( return array(
x.shape(), x.shape(),
x.dtype(), x.dtype(),
std::make_shared<Send>(to_stream(s, dev), group, dst), std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
{x}); {x});
} }
@@ -138,7 +125,6 @@ array recv(
std::optional<Group> group_ /* = std::nullopt */, std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto group = to_group(group_); auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) { if (group.size() == 1) {
throw std::invalid_argument("Cannot recv from a singleton group"); throw std::invalid_argument("Cannot recv from a singleton group");
@@ -153,7 +139,7 @@ array recv(
return array( return array(
std::move(shape), std::move(shape),
std::move(dtype), std::move(dtype),
std::make_shared<Recv>(to_stream(s, dev), group, src), std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
std::vector<array>{}); std::vector<array>{});
} }

View File

@@ -66,9 +66,10 @@ array affine_dequantize(
int bits = 4, int bits = 4,
StreamOrDevice s = {}); StreamOrDevice s = {});
typedef std::variant<int, bool, Dtype> TemplateArg; using TemplateArg = std::variant<int, bool, Dtype>;
using ScalarArg = std::variant<bool, int, float>;
typedef std::function<std::vector<array>( using CustomKernelFunction = std::function<std::vector<array>(
const std::vector<array>&, const std::vector<array>&,
const std::vector<Shape>&, const std::vector<Shape>&,
const std::vector<Dtype>&, const std::vector<Dtype>&,
@@ -77,10 +78,9 @@ typedef std::function<std::vector<array>(
std::vector<std::pair<std::string, TemplateArg>>, std::vector<std::pair<std::string, TemplateArg>>,
std::optional<float>, std::optional<float>,
bool, bool,
StreamOrDevice)> StreamOrDevice)>;
MetalKernelFunction;
MetalKernelFunction metal_kernel( CustomKernelFunction metal_kernel(
const std::string& name, const std::string& name,
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
@@ -89,4 +89,27 @@ MetalKernelFunction metal_kernel(
bool ensure_row_contiguous = true, bool ensure_row_contiguous = true,
bool atomic_outputs = false); bool atomic_outputs = false);
CustomKernelFunction cuda_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header = "",
bool ensure_row_contiguous = true,
int shared_memory = 0);
std::vector<array> precompiled_cuda_kernel(
const std::string& name,
const std::string& compiled_source,
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<ScalarArg>& scalars,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory = 0,
std::optional<float> init_value = std::nullopt,
bool ensure_row_contiguous = false,
StreamOrDevice s = {});
} // namespace mlx::core::fast } // namespace mlx::core::fast
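Since only a Python example is shown elsewhere in this change, here is a hedged sketch of driving the same entry point from C++ through the declarations above. The argument order of the returned CustomKernelFunction follows the typedef and the Python wrapper's call, the exp body is copied from the docstring example, and treating the grid as a thread count mirrors the documented behavior; this is an illustration, not code from the change.

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto kernel = fast::cuda_kernel(
      "myexp",
      /* input_names */ {"inp"},
      /* output_names */ {"out"},
      /* source */ R"(
        auto elem = cooperative_groups::this_grid().thread_rank();
        T tmp = inp[elem];
        out[elem] = exp(tmp);
      )");

  auto a = random::normal({16, 16});
  auto outs = kernel(
      {a},                                 // inputs
      {a.shape()},                         // output_shapes
      {a.dtype()},                         // output_dtypes
      {static_cast<int>(a.size()), 1, 1},  // grid, in threads
      {256, 1, 1},                         // threadgroup
      {{"T", a.dtype()}},                  // template arguments
      std::nullopt,                        // init_value
      false,                               // verbose
      {});                                 // default stream/device
  eval(outs);
}

The header, ensure_row_contiguous, and shared_memory parameters are left at their declared defaults here.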

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <optional> #include <optional>
#include <variant>
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -283,6 +284,8 @@ struct CustomKernelShapeInfo {
bool ndim = false; bool ndim = false;
}; };
using ScalarArg = std::variant<bool, int, float>;
class CustomKernel : public Primitive { class CustomKernel : public Primitive {
public: public:
CustomKernel( CustomKernel(
@@ -293,7 +296,10 @@ class CustomKernel : public Primitive {
std::tuple<int, int, int> threadgroup, std::tuple<int, int, int> threadgroup,
std::vector<CustomKernelShapeInfo> shape_infos, std::vector<CustomKernelShapeInfo> shape_infos,
bool ensure_row_contiguous, bool ensure_row_contiguous,
std::optional<float> init_value) std::optional<float> init_value,
std::vector<ScalarArg> scalar_arguments,
bool is_precompiled,
int shared_memory)
: Primitive(stream), : Primitive(stream),
source_(std::move(source)), source_(std::move(source)),
name_(std::move(name)), name_(std::move(name)),
@@ -301,11 +307,14 @@ class CustomKernel : public Primitive {
threadgroup_(threadgroup), threadgroup_(threadgroup),
shape_infos_(std::move(shape_infos)), shape_infos_(std::move(shape_infos)),
ensure_row_contiguous_(ensure_row_contiguous), ensure_row_contiguous_(ensure_row_contiguous),
init_value_(init_value) {} init_value_(init_value),
scalar_arguments_(std::move(scalar_arguments)),
is_precompiled_(is_precompiled),
shared_memory_(shared_memory) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override { override {
throw std::runtime_error("Custom Metal kernels only run on GPU."); throw std::runtime_error("Custom kernels only run on GPU.");
} }
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -321,6 +330,9 @@ class CustomKernel : public Primitive {
std::vector<CustomKernelShapeInfo> shape_infos_; std::vector<CustomKernelShapeInfo> shape_infos_;
bool ensure_row_contiguous_; bool ensure_row_contiguous_;
std::optional<float> init_value_; std::optional<float> init_value_;
std::vector<ScalarArg> scalar_arguments_;
bool is_precompiled_;
int shared_memory_;
}; };
} // namespace mlx::core::fast } // namespace mlx::core::fast

View File

@@ -415,48 +415,6 @@ def launch_mpi(parser, hosts, args, command):
pass pass
def launch_nccl(parser, hosts, args, command):
master_host = hosts[0].ips[0]
if master_host != "127.0.0.1":
raise ValueError("The NCCL backend only supports localhost for now. ")
master_port = args.nccl_port
world_size = len(hosts)
base_env = os.environ.copy()
base_env.update(
{
"NCCL_DEBUG": "INFO",
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
"NCCL_HOST_IP": master_host,
"NCCL_PORT": str(master_port),
"MLX_WORLD_SIZE": str(world_size),
}
)
procs = []
try:
for rank in range(world_size):
env = base_env.copy()
env["MLX_RANK"] = str(rank)
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
p = Popen(command, env=env)
procs.append(p)
for p in procs:
ret = p.wait()
if ret != 0:
raise RuntimeError(f"Rank process exited with {ret}")
except (RuntimeError, KeyboardInterrupt) as err:
for p in procs:
if p.poll() is None:
try:
p.kill()
except Exception:
pass
raise
def check_ssh_connections(hosts): def check_ssh_connections(hosts):
results = [False] * len(hosts) results = [False] * len(hosts)
@@ -707,7 +665,7 @@ def distributed_config():
) )
parser.add_argument( parser.add_argument(
"--backend", "--backend",
choices=["ring", "mpi", "nccl"], choices=["ring", "mpi"],
default="ring", default="ring",
help="Which distributed backend to configure", help="Which distributed backend to configure",
) )
@@ -779,7 +737,7 @@ def main():
parser.add_argument("--hostfile", help="The file containing the hosts") parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument( parser.add_argument(
"--backend", "--backend",
choices=["ring", "mpi", "nccl"], choices=["ring", "mpi"],
default="ring", default="ring",
help="Which distributed backend to launch", help="Which distributed backend to launch",
) )
@@ -811,13 +769,6 @@ def main():
parser.add_argument( parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one" "--cwd", help="Set the working directory on each node to the provided one"
) )
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
args, rest = parser.parse_known_args() args, rest = parser.parse_known_args()
if rest[0] == "--": if rest[0] == "--":
rest.pop(0) rest.pop(0)
@@ -848,10 +799,8 @@ def main():
# Launch # Launch
if args.backend == "ring": if args.backend == "ring":
launch_ring(parser, hosts, args, rest) launch_ring(parser, hosts, args, rest)
if args.backend == "mpi": elif args.backend == "mpi":
launch_mpi(parser, hosts, args, rest) launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -76,7 +76,6 @@ def average_gradients(
group: Optional[mx.distributed.Group] = None, group: Optional[mx.distributed.Group] = None,
all_reduce_size: int = 32 * 1024**2, all_reduce_size: int = 32 * 1024**2,
communication_type: Optional[mx.Dtype] = None, communication_type: Optional[mx.Dtype] = None,
stream: mx.Stream = mx.cpu,
): ):
"""Average the gradients across the distributed processes in the passed group. """Average the gradients across the distributed processes in the passed group.
@@ -95,7 +94,6 @@ def average_gradients(
communication_type (Optional[mlx.core.Dtype]): If provided cast to this communication_type (Optional[mlx.core.Dtype]): If provided cast to this
type before performing the communication. Typically cast to a type before performing the communication. Typically cast to a
smaller float to reduce the communication size. Default: ``None``. smaller float to reduce the communication size. Default: ``None``.
stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
""" """
group = group or mx.distributed.init() group = group or mx.distributed.init()
N = group.size() N = group.size()
@@ -106,7 +104,7 @@ def average_gradients(
def _average(x): def _average(x):
dt = x.dtype dt = x.dtype
x = x.astype(communication_type) if communication_type is not None else x x = x.astype(communication_type) if communication_type is not None else x
return mx.distributed.all_sum(x, stream=stream).astype(dt) / N return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
if all_reduce_size <= 0: if all_reduce_size <= 0:
return tree_map(_average, gradients) return tree_map(_average, gradients)

View File

@@ -17,6 +17,7 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp

python/src/cuda.cpp Normal file
View File

@@ -0,0 +1,19 @@
// Copyright © 2023-2025 Apple Inc.
#include <nanobind/nanobind.h>
#include "mlx/backend/cuda/cuda.h"
namespace mx = mlx::core;
namespace nb = nanobind;
void init_cuda(nb::module_& m) {
nb::module_ cuda = m.def_submodule("cuda", "mlx.cuda");
cuda.def(
"is_available",
&mx::cu::is_available,
R"pbdoc(
Check if the CUDA back-end is available.
)pbdoc");
}

View File

@@ -79,7 +79,7 @@ void init_distributed(nb::module_& parent_module) {
in case ``mx.distributed.is_available()`` returns False otherwise in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False`` it throws a runtime error. Default: ``False``
backend (str, optional): Which distributed backend to initialize. backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all
available backends are tried and the first one that succeeds available backends are tried and the first one that succeeds
becomes the global group which will be returned in subsequent becomes the global group which will be returned in subsequent
calls. Default: ``any`` calls. Default: ``any``

View File

@@ -17,6 +17,66 @@ namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
namespace {
struct PyCustomKernelFunction {
PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const char* tag)
: kernel_(std::move(kernel)), tag_(tag) {}
std::vector<mx::array> operator()(
const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::optional<std::vector<std::pair<std::string, nb::object>>>&
template_args_ = std::nullopt,
std::optional<float> init_value = std::nullopt,
bool verbose = false,
mx::StreamOrDevice s = {}) const {
std::vector<mx::array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
std::vector<std::pair<std::string, mx::fast::TemplateArg>> template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.emplace_back(name, bool_val);
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.emplace_back(name, int_val);
} else if (nb::isinstance<mx::Dtype>(value)) {
mx::Dtype dtype = nb::cast<mx::Dtype>(value);
template_args.emplace_back(name, dtype);
} else {
std::ostringstream msg;
msg << tag_
<< " Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.";
throw std::invalid_argument(msg.str());
}
}
}
return kernel_(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
}
mx::fast::CustomKernelFunction kernel_;
const char* tag_;
};
} // namespace
void init_fast(nb::module_& parent_module) { void init_fast(nb::module_& parent_module) {
auto m = auto m =
parent_module.def_submodule("fast", "mlx.core.fast: fast operations"); parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
@@ -240,53 +300,7 @@ void init_fast(nb::module_& parent_module) {
ensure_row_contiguous, ensure_row_contiguous,
atomic_outputs); atomic_outputs);
return nb::cpp_function( return nb::cpp_function(
[kernel = std::move(kernel)]( PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"),
const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::optional<
std::vector<std::pair<std::string, nb::object>>>&
template_args_ = std::nullopt,
std::optional<float> init_value = std::nullopt,
bool verbose = false,
mx::StreamOrDevice s = {}) {
std::vector<mx::array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
std::vector<std::pair<std::string, mx::fast::TemplateArg>>
template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.emplace_back(name, bool_val);
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.emplace_back(name, int_val);
} else if (nb::isinstance<mx::Dtype>(value)) {
mx::Dtype dtype = nb::cast<mx::Dtype>(value);
template_args.emplace_back(name, dtype);
} else {
throw std::invalid_argument(
"[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
return kernel(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
},
nb::kw_only(), nb::kw_only(),
"inputs"_a, "inputs"_a,
"output_shapes"_a, "output_shapes"_a,
@@ -384,4 +398,216 @@ void init_fast(nb::module_& parent_module) {
b = exp_elementwise(a) b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a)) assert mx.allclose(b, mx.exp(a))
)pbdoc"); )pbdoc");
m.def(
"cuda_kernel",
[](const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
int shared_mem) {
auto kernel = mx::fast::cuda_kernel(
name,
input_names,
output_names,
source,
header,
ensure_row_contiguous,
shared_mem);
return nb::cpp_function(
PyCustomKernelFunction(std::move(kernel), "[cuda_kernel]"),
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"init_value"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Args:
inputs (List[array]): The inputs passed to the CUDA kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
For compatibility with :func:`metal_kernel` the grid is in threads and not in threadgroups.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
List[array]: The list of output arrays.)pbdoc");
},
"name"_a,
"input_names"_a,
"output_names"_a,
"source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true,
"shared_memory"_a = 0,
R"pbdoc(
A jit-compiled custom CUDA kernel defined from a source string.
This is the CUDA equivalent of :ref:`custom_metal_kernels`.
Args:
name (str): Name for the kernel.
input_names (List[str]): The parameter names of the inputs in the
function signature.
output_names (List[str]): The parameter names of the outputs in the
function signature.
source (str): Source code. This is the body of a function in CUDA,
the function signature will be automatically generated.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of
the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
shared_memory (int): The dynamic shared memory to request for the
kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
Returns:
Callable ``cuda_kernel``.
Example:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = '''
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = inp[elem];
out[elem] = exp(tmp);
'''
kernel = mx.fast.cuda_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
verbose=True,
)
return outputs[0]
a = mx.random.normal(shape=(16, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
)pbdoc");
m.def(
"precompiled_cuda_kernel",
[](const std::string& name,
const nb::bytes compiled_source,
const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
const std::vector<nb::object>& scalars_,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value = std::nullopt,
bool ensure_row_contiguous = false,
mx::StreamOrDevice s = {}) {
// Collect the inputs and cast them to array
std::vector<mx::array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
// Collect the scalar inputs
std::vector<mx::fast::ScalarArg> scalars;
scalars.reserve(scalars_.size());
for (const auto& v : scalars_) {
if (nb::isinstance<bool>(v)) {
scalars.push_back(nb::cast<bool>(v));
} else if (nb::isinstance<int>(v)) {
scalars.push_back(nb::cast<int>(v));
} else if (nb::isinstance<float>(v)) {
scalars.push_back(nb::cast<float>(v));
} else {
nb::object vtype = v.attr("__class__");
std::string vtype_name =
nb::cast<std::string>(vtype.attr("__name__"));
std::ostringstream msg;
msg << "[precompiled_cuda_kernel] Invalid scalar argument type. "
<< "Received " << vtype_name
<< " but must be one of bool, int or float";
throw std::invalid_argument(msg.str());
}
}
return mx::fast::precompiled_cuda_kernel(
name,
std::string(
static_cast<const char*>(compiled_source.data()),
compiled_source.size()),
inputs,
output_shapes,
output_dtypes,
scalars,
grid,
threadgroup,
shared_memory,
init_value,
ensure_row_contiguous,
s);
},
nb::kw_only(),
"name"_a,
"compiled_source"_a,
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"scalars"_a,
"grid"_a,
"threadgroup"_a,
"shared_memory"_a = 0,
"init_value"_a = nb::none(),
"ensure_row_contiguous"_a = false,
"stream"_a = nb::none(),
R"pbdoc(
Run a precompiled CUDA kernel defined from PTX or cubin.
This op is still experimental and various parts of the API may change.
Args:
name (str): Name for the kernel
compiled_source (bytes): The precompiled kernel in raw bytes.
inputs (List[array]): The inputs passed to the CUDA kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output.
output_dtypes (List[Dtype]): The list of data types for each output.
scalars (List[Union[bool, int, float]]): A list of scalar arguments to
pass to the kernel.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
shared_memory (int): The dynamic shared memory to request for the
kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
)pbdoc");
} }
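For the experimental precompiled path, a hypothetical C++ sketch of the corresponding mx::fast::precompiled_cuda_kernel call follows. The PTX file name, the entry-point name, and the assumption that the kernel expects the input pointer, the output pointer, and one float scalar are all made up for illustration and depend entirely on how the kernel was compiled.

#include <fstream>
#include <sstream>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Load PTX produced offline, e.g. with `nvcc -ptx scale.cu -o scale.ptx`,
  // whose entry point matches the name passed below.
  std::ifstream f("scale.ptx", std::ios::binary);
  std::stringstream ss;
  ss << f.rdbuf();

  auto x = ones({1024});
  float scale = 2.0f;

  auto outs = fast::precompiled_cuda_kernel(
      "scale_kernel",   // kernel name inside the PTX
      ss.str(),         // raw compiled source bytes
      {x},              // array inputs
      {x.shape()},      // output_shapes
      {x.dtype()},      // output_dtypes
      {scale},          // scalar arguments
      {1024, 1, 1},     // grid, in threads (see the docstring above)
      {256, 1, 1},      // threadgroup
      0);               // no dynamic shared memory
  eval(outs);
}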

View File

@@ -12,6 +12,7 @@ void init_array(nb::module_&);
void init_device(nb::module_&); void init_device(nb::module_&);
void init_stream(nb::module_&); void init_stream(nb::module_&);
void init_metal(nb::module_&); void init_metal(nb::module_&);
void init_cuda(nb::module_&);
void init_memory(nb::module_&); void init_memory(nb::module_&);
void init_ops(nb::module_&); void init_ops(nb::module_&);
void init_transforms(nb::module_&); void init_transforms(nb::module_&);
@@ -35,6 +36,7 @@ NB_MODULE(core, m) {
init_stream(m); init_stream(m);
init_array(m); init_array(m);
init_metal(m); init_metal(m);
init_cuda(m);
init_memory(m); init_memory(m);
init_ops(m); init_ops(m);
init_transforms(m); init_transforms(m);

View File

@@ -13,8 +13,6 @@ cuda_skip = {
# Hadamard NYI # Hadamard NYI
"TestOps.test_hadamard", "TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap", "TestOps.test_hadamard_grad_vmap",
# Convolutions NYI
"TestConv.test_1d_conv_with_2d",
# FFTs NYI # FFTs NYI
"TestFFT.test_fft", "TestFFT.test_fft",
"TestFFT.test_fft_big_powers_of_two", "TestFFT.test_fft_big_powers_of_two",

View File

@@ -1,284 +0,0 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients
class TestNCCLDistributed(mlx_tests.MLXTestCase):
@classmethod
def setUpClass(cls):
world = mx.distributed.init(strict=True, backend="nccl")
rank = world.rank()
mx.set_default_device(mx.Device(mx.gpu, rank % 8))
def test_all_reduce(self):
world = mx.distributed.init()
dtypes = [
(mx.int8, 0),
(mx.uint8, 0),
(mx.int32, 0),
(mx.uint32, 0),
(mx.float32, 1e-6),
(mx.float16, 5e-3),
(mx.bfloat16, 1e-1),
]
sizes = [
(7,),
(10,),
(1024,),
(1024, 1024),
]
key = mx.random.key(0)
for dt, rtol in dtypes:
for sh in sizes:
x = (
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
).astype(dt)
# All sum
y = mx.distributed.all_sum(x[world.rank()])
z = x.sum(0)
maxrelerror = (y - z).abs()
if rtol > 0:
maxrelerror /= z.abs()
maxrelerror = maxrelerror.max()
self.assertLessEqual(maxrelerror, rtol)
def test_average_gradients(self):
original_all_sum = mx.distributed.all_sum
n_calls = 0
xtype = None
def new_all_sum(x, **kwargs):
nonlocal n_calls
nonlocal xtype
n_calls += 1
if xtype is not None:
self.assertEqual(xtype, x.dtype)
return original_all_sum(x, **kwargs)
mx.distributed.all_sum = new_all_sum
try:
grads = [mx.ones(10) for i in range(10)]
new_grads = average_gradients(grads, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 1)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 10)
n_calls = 0
xtype = mx.float16
new_grads = average_gradients(
grads,
all_reduce_size=2 * 50,
communication_type=mx.float16,
stream=mx.gpu,
)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
finally:
mx.distributed.all_sum = original_all_sum
def test_donation(self):
x = mx.random.normal((1024,))
mx.eval(x)
mx.synchronize()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize()
all_sum_only = mx.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize()
all_sum_with_binary = mx.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)
def test_shard_linear(self):
# Seed the prng to have the same inputs and weights generated everywhere
mx.random.seed(0xF0F0F0F0)
# Prepare inputs
world = mx.distributed.init()
part = (
slice(None),
slice(
world.rank() * 1024 // world.size(),
(world.rank() + 1) * 1024 // world.size(),
),
)
x = mx.random.normal((4, 1024))
# Create and shard some linear layers
lin = nn.Linear(1024, 1024, bias=True)
slin1 = shard_linear(lin, "all-to-sharded")
slin2 = shard_linear(lin, "sharded-to-all")
y = lin(x)
y1 = slin1(x)
y2 = slin2(x[part])
self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
# Check the backward works as expected
def dummy_loss(model, x, y):
return (model(x) * y).sum()
mod = nn.Sequential(
nn.Linear(128, 128),
nn.Linear(128, 128),
nn.Linear(128, 128),
nn.Linear(128, 128),
)
smod = nn.Sequential(
shard_linear(mod.layers[0], "all-to-sharded"),
shard_linear(mod.layers[1], "sharded-to-all"),
shard_linear(mod.layers[2], "all-to-sharded"),
shard_linear(mod.layers[3], "sharded-to-all"),
)
grad1 = nn.value_and_grad(mod, dummy_loss)
grad2 = nn.value_and_grad(smod, dummy_loss)
x = mx.random.normal((4, 128))
y = mx.random.normal((4, 128))
l1, g1 = grad1(mod, x, y)
l2, g2 = grad2(smod, x, y)
mx.eval(l1, g1, l2, g2)
part = slice(
world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
)
self.assertTrue(mx.allclose(l1, l2))
self.assertTrue(
mx.allclose(
g1["layers"][0]["weight"][part],
g2["layers"][0]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][2]["weight"][part],
g2["layers"][2]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["weight"][:, part],
g2["layers"][1]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["weight"][:, part],
g2["layers"][3]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][0]["bias"][part],
g2["layers"][0]["bias"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][2]["bias"][part],
g2["layers"][2]["bias"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
)
)
def test_shard_predicate(self):
mx.random.seed(0xF0F0F0F0)
class MyConv(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.aggregate = kwargs.pop("aggregate", False)
self.conv = nn.Conv2d(*args, **kwargs)
def __call__(self, x):
x = self.conv(x)
if self.aggregate:
x = mx.distributed.all_sum(x)
return x
def sharding(path, weight):
parts = path.split(".")
even = int(parts[1]) % 2 == 0
if even:
return 0
else:
return -1 if parts[-1] != "bias" else None
mod = nn.Sequential(
MyConv(3, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3),
MyConv(128, 3, kernel_size=3),
)
smod = nn.Sequential(
MyConv(3, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3, aggregate=True),
MyConv(128, 128, kernel_size=3),
MyConv(128, 3, kernel_size=3, aggregate=True),
)
smod.update(mod.parameters())
shard_inplace(smod, sharding)
x = mx.random.normal((4, 16, 16, 3))
y1 = mod(x)
y2 = smod(x)
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()

View File

@@ -1186,6 +1186,13 @@ class TestConv(mlx_tests.MLXTestCase):
y_hat = mx.conv2d(x, w) y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat)) self.assertTrue(mx.allclose(y, y_hat))
def test_conv2d_large_filter_small_channels(self):
x = mx.random.normal(shape=(1, 181, 181, 1))
w = mx.random.normal(shape=(1, 182, 182, 1))
y = mx.conv2d(x, w, (1, 1), (1, 1), stream=mx.cpu)
y_hat = mx.conv2d(x, w, (1, 1), (1, 1))
self.assertTrue(mx.allclose(y, y_hat, rtol=1e-3, atol=1e-3))
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()

View File

@@ -581,18 +581,28 @@ class TestFast(mlx_tests.MLXTestCase):
)(x) )(x)
self.assertTrue(mx.allclose(vmap_out, vmap_fast_out)) self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_basic(self): def test_custom_kernel_basic(self):
mx.random.seed(7) if mx.metal.is_available():
a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel(
name="basic",
input_names=["a"],
output_names=["out1"],
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
out1[elem] = a[elem]; out1[elem] = a[elem];
""", """
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))
kernel = custom_kernel(
name="basic",
input_names=["a"],
output_names=["out1"],
source=source,
) )
out = kernel( out = kernel(
inputs=[a], inputs=[a],
@@ -604,16 +614,9 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(out[0], a)) self.assertTrue(mx.allclose(out[0], a))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_args(self): def test_custom_kernel_args(self):
mx.random.seed(7) if mx.metal.is_available():
a = mx.random.normal(shape=(3, 6))
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
kernel = mx.fast.metal_kernel(
name="arg_test",
input_names=["a", "b", "c", "d"],
output_names=["out1", "out2"],
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
T tmp = a[0]; T tmp = a[0];
@@ -623,7 +626,30 @@ class TestFast(mlx_tests.MLXTestCase):
out1[elem] = 1; out1[elem] = 1;
} }
out2[elem] = a[1] + b[2] + c[1] - d; out2[elem] = a[1] + b[2] + c[1] - d;
""", """
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = a[0];
if (e) {
out1[elem] = a[1] + b[2] + static_cast<float>(c[3]) + d[0] + f;
} else {
out1[elem] = 1;
}
out2[elem] = a[1] + b[2] + static_cast<float>(c[1]) - d[0];
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(3, 6))
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
kernel = custom_kernel(
name="arg_test",
input_names=["a", "b", "c", "d"],
output_names=["out1", "out2"],
source=source,
) )
out = kernel( out = kernel(
inputs=[ inputs=[
@@ -647,10 +673,9 @@ class TestFast(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484))) self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32))) self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_strides(self): def test_custom_kernel_strides(self):
mx.random.seed(7) if mx.metal.is_available():
a = mx.random.normal(shape=(3, 6))
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
@@ -662,12 +687,29 @@ class TestFast(mlx_tests.MLXTestCase):
T tmp = inp[elem]; T tmp = inp[elem];
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup; out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
""" """
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim);
T tmp = inp[loc];
out[elem] = exp(tmp) * WARP_SIZE;
"""
source_contig = """
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = inp[elem];
out[elem] = exp(tmp) * WARP_SIZE;
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(3, 6))
# non contiguous # non contiguous
a = mx.tile(a[::2], [4, 1]) a = mx.tile(a[::2], [4, 1])
for contig in [True, False]: for contig in [True, False]:
kernel = mx.fast.metal_kernel( kernel = custom_kernel(
name="myexp" + str(contig), name="myexp" + str(contig),
input_names=["inp"], input_names=["inp"],
output_names=["out"], output_names=["out"],
@@ -685,24 +727,41 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0])) self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_helper(self): def test_custom_kernel_helper(self):
mx.random.seed(7) if mx.metal.is_available():
a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel(
name="helper",
input_names=["a"],
output_names=["out1"],
header = """ header = """
template <typename T> template <typename T>
T do_exp(T x) { T do_exp(T x) {
return metal::precise::exp(x); return metal::precise::exp(x);
} }
""", """
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
out1[elem] = do_exp(a[elem]); out1[elem] = do_exp(a[elem]);
""", """
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
header = """
template <typename T>
__device__ T do_exp(T x) {
return exp(x);
}
"""
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
out1[elem] = do_exp(a[elem]);
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))
kernel = custom_kernel(
name="helper",
input_names=["a"],
output_names=["out1"],
header=header,
source=source,
) )
out = kernel( out = kernel(
inputs=[a], inputs=[a],
@@ -714,16 +773,21 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(out[0], mx.exp(a))) self.assertTrue(mx.allclose(out[0], mx.exp(a)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_attributes(self): def test_custom_kernel_attributes(self):
if mx.metal.is_available():
source = "out[0] = threads_per_threadgroup.x;"
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = "out[0] = blockDim.x;"
custom_kernel = mx.fast.cuda_kernel
a = mx.zeros(shape=(1, 1)) a = mx.zeros(shape=(1, 1))
kernel = mx.fast.metal_kernel( kernel = custom_kernel(
name="test_fun", name="test_fun",
input_names=["a"], input_names=["a"],
output_names=["out"], output_names=["out"],
source=""" source=source,
out[0] = threads_per_threadgroup.x;
""",
) )
out = kernel( out = kernel(
inputs=[a], inputs=[a],