Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: simple-gem ... 3bb6b1d44a (10 commits)
| SHA1 |
|---|
| 3bb6b1d44a |
| 4ee0d0bb55 |
| cd53eb1ae3 |
| f7c11b965e |
| 984cefb14d |
| dadf8d9c93 |
| 389276e2b8 |
| 2e255c8eb4 |
| 062aa80b84 |
| f540b1d612 |
cmake/FindNCCL.cmake (new file, 54 lines)
@@ -0,0 +1,54 @@
+# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
+# directories.
+
+set(NCCL_ROOT_DIR
+    $ENV{NCCL_ROOT_DIR}
+    CACHE PATH "Folder contains NVIDIA NCCL")
+
+find_path(
+  NCCL_INCLUDE_DIRS
+  NAMES nccl.h
+  HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
+        ${CUDA_TOOLKIT_ROOT_DIR}/include)
+
+if($ENV{USE_STATIC_NCCL})
+  message(
+    STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
+  set(NCCL_LIBNAME "libnccl_static.a")
+else()
+  set(NCCL_LIBNAME "nccl")
+endif()
+
+find_library(
+  NCCL_LIBRARIES
+  NAMES ${NCCL_LIBNAME}
+  HINTS ${NCCL_LIB_DIR}
+        ${NCCL_ROOT_DIR}
+        ${NCCL_ROOT_DIR}/lib
+        ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
+        ${NCCL_ROOT_DIR}/lib64
+        ${CUDA_TOOLKIT_ROOT_DIR}/lib
+        ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
+                                  NCCL_LIBRARIES)
+
+if(NCCL_FOUND)
+  set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
+  message(
+    STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
+  file(
+    STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
+    REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
+    LIMIT_COUNT 1)
+  if(NCCL_MAJOR_VERSION_DEFINED)
+    string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
+                         NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
+    message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
+  endif()
+  message(
+    STATUS
+      "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
+  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
+endif()
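The version probe in this module works by scanning `nccl.h` for the `NCCL_MAJOR` define. For orientation, the relevant part of that header looks roughly like the sketch below; the numbers are illustrative and depend on the installed NCCL release.

```cpp
// Illustrative excerpt of nccl.h; only the macro names matter here.
#define NCCL_MAJOR 2
#define NCCL_MINOR 21
#define NCCL_PATCH 5
// The module's REGEX matches the NCCL_MAJOR line and the REGEX REPLACE strips
// the "#define NCCL_MAJOR " prefix, leaving NCCL_MAJOR_VERSION = 2.
```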
@@ -70,7 +70,6 @@ are the CPU and GPU.
    python/fft
    python/linalg
    python/metal
-   python/cuda
    python/memory_management
    python/nn
    python/optimizers

@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
 dpkg -i cuda-keyring_1.1-1_all.deb
 apt-get update -y
 apt-get -y install cuda-toolkit-12-9
-apt-get install libblas-dev liblapack-dev liblapacke-dev -y
+apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y


 When building either the Python or C++ APIs make sure to pass the cmake flag

@@ -1,9 +0,0 @@
-CUDA
-=====
-
-.. currentmodule:: mlx.core.cuda
-
-.. autosummary::
-   :toctree: _autosummary
-
-   is_available

@@ -13,4 +13,3 @@ Fast
    rope
    scaled_dot_product_attention
    metal_kernel
-   cuda_kernel

@@ -20,14 +20,12 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cutlass_gemm.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp

@@ -90,9 +88,6 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
 target_compile_options(mlx
                        PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

-# Keep ptx around for inspection
-target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>")
-
 # Enable calling host constexpr functions from device. This is needed because
 # the constexpr version of isnan is host only.
 target_compile_options(

@@ -178,12 +173,3 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
 # Install CCCL headers for JIT.
 install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
         DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
-
-# Fetch and make available cutlass
-FetchContent_Declare(
-  cutlass
-  GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
-  GIT_TAG v4.1.0)
-FetchContent_Populate(cutlass)
-target_include_directories(
-  mlx PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)

@@ -267,8 +267,7 @@ void Compiled::eval_gpu(
     }
   }

-    return std::make_tuple(
-        false, std::move(builder.os), std::move(kernel_names));
+    return std::make_pair(std::move(builder.os), std::move(kernel_names));
   });

   // Collapse contiguous dims to route to a faster kernel if possible. Also
@@ -23,24 +23,6 @@ inline cudnn_frontend::Tensor build_cudnn_tensor(
       .build();
 }

-// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
-// whether a tensor is contiguous is determined with:
-//   shape[dim] == shape[dim + 1] * strides[dim + 1]
-// So a contiguous array with singleton dims in MLX may be mistakenly treated
-// as strided in cuDNN, and we work around it by normalizing the strides.
-Strides normalized_strides(const array& x) {
-  if (!x.flags().row_contiguous || x.ndim() < 2) {
-    return x.strides();
-  }
-  Strides strides = x.strides();
-  for (int i = x.ndim() - 2; i >= 0; --i) {
-    if (x.shape(i) == 1) {
-      strides[i] = x.shape(i + 1) * strides[i + 1];
-    }
-  }
-  return strides;
-}
-
 // Return the shape and strides after transposing from NHWC to NCHW.
 auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
   assert(shape.size() >= 3);
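The removed comment and helper describe a mismatch between MLX's and cuDNN's contiguity rules. A small worked example of what `normalized_strides` did, assuming a row-contiguous array of shape (4, 1, 8) whose middle singleton dimension happens to carry stride 1:

```cpp
// Standalone re-statement of the deleted normalization loop; the shape and
// initial strides are assumptions for illustration only.
#include <array>
#include <cstdint>

int main() {
  std::array<int64_t, 3> shape = {4, 1, 8};
  std::array<int64_t, 3> strides = {8, 1, 1}; // singleton dim with stride 1
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    if (shape[i] == 1) {
      // Force the stride cuDNN expects for a contiguous layout.
      strides[i] = shape[i + 1] * strides[i + 1];
    }
  }
  // strides is now {8, 8, 1}: cuDNN's check
  //   strides[dim] == shape[dim + 1] * strides[dim + 1]
  // holds for every dim, so the tensor is recognized as contiguous.
  return 0;
}
```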
@@ -51,9 +33,8 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
   return std::make_tuple(std::move(shape), std::move(strides));
 }

-inline auto nhwc_to_nchw(const array& x) {
-  return nhwc_to_nchw(
-      convert_vector<int64_t>(x.shape()), normalized_strides(x));
+auto nhwc_to_nchw(const array& x) {
+  return nhwc_to_nchw(convert_vector<int64_t>(x.shape()), x.strides());
 }

 // Return available engines for a |op_graph|.

@@ -159,7 +140,7 @@ bool prepare_cudnn_plan(

 cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
   auto shape = convert_vector<int64_t>(x.shape());
-  return build_cudnn_tensor(id, x, shape, normalized_strides(x));
+  return build_cudnn_tensor(id, x, shape, x.strides());
 }

 cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
@@ -179,8 +160,7 @@ cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
     return build_cudnn_tensor(id, x, shape, strides);
   }
   if (x.ndim() == 2) {
-    int64_t s =
-        x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
+    int64_t s = x.strides(0);
     SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
     SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
     return build_cudnn_tensor(id, x, shape, strides);
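For the 2-D branch above, both the old and the new expression produce the row stride when the matrix is row-contiguous; they only differ for arrays whose leading stride is not the row length. A concrete instance with a hypothetical row-contiguous (4, 8) matrix:

```cpp
// Hypothetical (M, K) = (4, 8) row-contiguous matrix: shape {4, 8}, strides {8, 1}.
#include <cstdint>
#include <vector>

int main() {
  int64_t shape1 = 8, strides0 = 8, strides1 = 1;
  int64_t s_new = strides0;           // new form: x.strides(0)
  int64_t s_old = shape1 * strides1;  // old form when row_contiguous
  // Padding to 4-D NCHW for cuDNN then gives:
  std::vector<int64_t> shape4d = {4, 8, 1, 1};
  std::vector<int64_t> strides4d = {s_new, strides1, s_new, s_new}; // {8, 1, 8, 8}
  return static_cast<int>(s_new - s_old); // 0: both forms agree in this case
}
```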
@@ -1,379 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include <iostream>
-
-#include "mlx/backend/common/compiled.h"
-#include "mlx/backend/cuda/jit_module.h"
-#include "mlx/backend/cuda/utils.h"
-#include "mlx/backend/gpu/copy.h"
-#include "mlx/fast.h"
-#include "mlx/fast_primitives.h"
-
-#include <fmt/format.h>
-#include <nvtx3/nvtx3.hpp>
-
-namespace mlx::core::fast {
-
-namespace {
-
-constexpr const char* default_header = R"(
-#include "mlx/backend/cuda/device/utils.cuh"
-
-#include <cooperative_groups.h>
-
-#define inf cuda::std::numeric_limits<float>::infinity()
-
-)";
-
-std::string template_arguments_hash(
-    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
-  if (template_args.empty()) {
-    return "";
-  }
-
-  std::string hash;
-  hash.reserve(512);
-
-  for (const auto& [name, arg] : template_args) {
-    if (std::holds_alternative<int>(arg)) {
-      hash += fmt::format("_{}", std::get<int>(arg));
-    } else if (std::holds_alternative<bool>(arg)) {
-      hash += (std::get<bool>(arg)) ? "_t" : "_f";
-    } else if (std::holds_alternative<Dtype>(arg)) {
-      hash += "_";
-      hash += get_type_string(std::get<Dtype>(arg));
-    }
-  }
-
-  return hash;
-}
-
-std::string build_kernel(
-    const std::string& func_name,
-    const std::string& header,
-    const std::string& source,
-    const std::vector<std::string>& input_names,
-    const std::vector<array>& inputs,
-    const std::vector<std::string>& output_names,
-    const std::vector<Dtype>& output_dtypes,
-    const std::vector<std::pair<std::string, TemplateArg>>& template_args,
-    const std::vector<CustomKernelShapeInfo>& shape_infos) {
-  std::string kernel_source;
-  kernel_source.reserve(header.size() + source.size() + 8192);
-  kernel_source += default_header;
-  kernel_source += header;
-  kernel_source +=
-      "namespace mlx::core::cu {\n\n"
-      "namespace cg = cooperative_groups;\n\n";
-
-  kernel_source += "__global__ void ";
-  kernel_source += func_name;
-  kernel_source += "(\n";
-
-  // Add inputs
-  for (int i = 0; i < inputs.size(); ++i) {
-    const auto& name = input_names[i];
-    const auto& arr = inputs[i];
-    kernel_source += "  const ";
-    kernel_source += dtype_to_cuda_type(arr.dtype());
-    kernel_source += "* ";
-    kernel_source += name;
-    kernel_source += ",\n";
-    // Add input shape, strides and ndim if present in the source
-    if (arr.ndim() > 0) {
-      if (shape_infos[i].shape) {
-        kernel_source += "  const __grid_constant__ Shape ";
-        kernel_source += name;
-        kernel_source += "_shape,\n";
-      }
-      if (shape_infos[i].strides) {
-        kernel_source += "  const __grid_constant__ Strides ";
-        kernel_source += name;
-        kernel_source += "_strides,\n";
-      }
-      if (shape_infos[i].ndim) {
-        kernel_source += "  const __grid_constant__ int ";
-        kernel_source += name;
-        kernel_source += "_ndim,\n";
-      }
-    }
-  }
-
-  // Add outputs
-  for (int i = 0; i < output_names.size(); ++i) {
-    const auto& name = output_names[i];
-    const auto& dtype = output_dtypes[i];
-    kernel_source += "  ";
-    kernel_source += dtype_to_cuda_type(dtype);
-    kernel_source += "* ";
-    kernel_source += name;
-    if (i < output_names.size() - 1) {
-      kernel_source += ",\n";
-    } else {
-      kernel_source += ") {\n";
-    }
-  }
-
-  // Set compile time constants
-  if (!template_args.empty()) {
-    for (const auto& [name, arg] : template_args) {
-      if (std::holds_alternative<int>(arg)) {
-        kernel_source +=
-            fmt::format("  constexpr int {} = {};\n", name, std::get<int>(arg));
-      } else if (std::holds_alternative<bool>(arg)) {
-        kernel_source += fmt::format(
-            "  constexpr bool {} = {};\n", name, std::get<bool>(arg));
-      } else {
-        kernel_source += fmt::format(
-            "  using {} = {};\n",
-            name,
-            dtype_to_cuda_type(std::get<Dtype>(arg)));
-      }
-    }
-    kernel_source += "\n";
-  }
-
-  kernel_source += source;
-  kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
-
-  return kernel_source;
-}
-
-} // namespace
-
-CustomKernelFunction cuda_kernel(
-    const std::string& name,
-    const std::vector<std::string>& input_names,
-    const std::vector<std::string>& output_names,
-    const std::string& source,
-    const std::string& header,
-    bool ensure_row_contiguous,
-    int shared_memory) {
-  if (output_names.empty()) {
-    throw std::invalid_argument(
-        "[custom_kernel] Must specify at least one output.");
-  }
-
-  std::vector<CustomKernelShapeInfo> shape_infos;
-  for (auto& n : input_names) {
-    CustomKernelShapeInfo shape_info;
-    shape_info.shape = source.find(n + "_shape") != std::string::npos;
-    shape_info.strides = source.find(n + "_strides") != std::string::npos;
-    shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
-    shape_infos.push_back(shape_info);
-  }
-
-  return [=, shape_infos = std::move(shape_infos)](
-             const std::vector<array>& inputs,
-             const std::vector<Shape>& output_shapes,
-             const std::vector<Dtype>& output_dtypes,
-             std::tuple<int, int, int> grid,
-             std::tuple<int, int, int> threadgroup,
-             const std::vector<std::pair<std::string, TemplateArg>>&
-                 template_args = {},
-             std::optional<float> init_value = std::nullopt,
-             bool verbose = false,
-             StreamOrDevice s_ = {}) {
-    if (inputs.size() != input_names.size()) {
-      std::ostringstream msg;
-      msg << "[custom_kernel] Expected `inputs` to have size "
-          << input_names.size() << " but got size " << inputs.size() << "."
-          << std::endl;
-      throw std::invalid_argument(msg.str());
-    }
-    if (output_shapes.size() != output_names.size()) {
-      std::ostringstream msg;
-      msg << "[custom_kernel] Expected `output_shapes` to have size "
-          << output_names.size() << " but got size " << output_shapes.size()
-          << "." << std::endl;
-      throw std::invalid_argument(msg.str());
-    }
-    if (output_dtypes.size() != output_names.size()) {
-      std::ostringstream msg;
-      msg << "[custom_kernel] Expected `output_dtypes` to have size "
-          << output_names.size() << " but got size " << output_dtypes.size()
-          << "." << std::endl;
-      throw std::invalid_argument(msg.str());
-    }
-
-    auto s = to_stream(s_);
-    if (s.device != Device::gpu) {
-      throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
-    }
-
-    std::string kernel_name =
-        "custom_kernel_" + name + template_arguments_hash(template_args);
-    std::string kernel_source = build_kernel(
-        kernel_name,
-        header,
-        source,
-        input_names,
-        inputs,
-        output_names,
-        output_dtypes,
-        template_args,
-        shape_infos);
-
-    if (verbose) {
-      std::cout << "Generated source code for `" << kernel_name
-                << "`:" << std::endl
-                << "```" << std::endl
-                << kernel_source << std::endl
-                << "```" << std::endl;
-    }
-
-    return array::make_arrays(
-        std::move(output_shapes),
-        std::move(output_dtypes),
-        std::make_shared<CustomKernel>(
-            s,
-            std::move(kernel_name),
-            std::move(kernel_source),
-            grid,
-            threadgroup,
-            shape_infos,
-            ensure_row_contiguous,
-            init_value,
-            std::vector<ScalarArg>{},
-            false,
-            shared_memory),
-        std::move(inputs));
-  };
-}
-
-std::vector<array> precompiled_cuda_kernel(
-    const std::string& name,
-    const std::string& compiled_source,
-    const std::vector<array>& inputs,
-    const std::vector<Shape>& output_shapes,
-    const std::vector<Dtype>& output_dtypes,
-    const std::vector<ScalarArg>& scalars,
-    std::tuple<int, int, int> grid,
-    std::tuple<int, int, int> threadgroup,
-    int shared_memory,
-    std::optional<float> init_value,
-    bool ensure_row_contiguous,
-    StreamOrDevice s) {
-  std::vector<CustomKernelShapeInfo> shape_infos(
-      inputs.size(), CustomKernelShapeInfo{false, false, false});
-  return array::make_arrays(
-      output_shapes,
-      output_dtypes,
-      std::make_shared<CustomKernel>(
-          to_stream(s),
-          name,
-          compiled_source,
-          grid,
-          threadgroup,
-          shape_infos,
-          ensure_row_contiguous,
-          init_value,
-          scalars,
-          true,
-          shared_memory),
-      inputs);
-}
-
-void CustomKernel::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  nvtx3::scoped_range r("CustomKernel::eval_gpu");
-  auto& s = stream();
-
-  std::vector<array> copies;
-
-  // Allocate and initialize the output arrays
-  for (auto& out : outputs) {
-    if (init_value_) {
-      copies.emplace_back(init_value_.value(), out.dtype());
-      fill_gpu(copies.back(), out, s);
-    } else {
-      out.set_data(allocator::malloc(out.nbytes()));
-    }
-  }
-
-  // Create the input arrays and copy if needed
-  auto check_input = [&copies, &s, this](const array& x) -> const array {
-    bool no_copy = x.flags().row_contiguous;
-    if (!ensure_row_contiguous_ || no_copy) {
-      return x;
-    } else {
-      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
-      copy_gpu(x, copies.back(), CopyType::General, s);
-      return copies.back();
-    }
-  };
-  std::vector<array> checked_inputs;
-  for (const array& in : inputs) {
-    checked_inputs.push_back(check_input(in));
-  }
-
-  // Compile the custom kernel
-  std::string kernel_name =
-      (is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
-  cu::JitModule& mod = cu::get_jit_module(
-      s.device,
-      name_,
-      [&]() {
-        return std::make_tuple(
-            is_precompiled_, source_, std::vector{kernel_name});
-      },
-      false);
-
-  // Make the arguments
-  cu::KernelArgs args;
-  for (int i = 0; i < checked_inputs.size(); i++) {
-    const array& in = checked_inputs[i];
-    auto& shape_info = shape_infos_[i];
-    args.append(in);
-    if (shape_info.shape) {
-      args.append_ndim(in.shape());
-    }
-    if (shape_info.strides) {
-      args.append_ndim(in.strides());
-    }
-    if (shape_info.ndim) {
-      args.append<int32_t>(in.ndim());
-    }
-  }
-  for (auto& out : outputs) {
-    args.append(out);
-  }
-  for (auto& s : scalar_arguments_) {
-    if (std::holds_alternative<bool>(s)) {
-      args.append(std::get<bool>(s));
-    } else if (std::holds_alternative<int>(s)) {
-      args.append(std::get<int>(s));
-    } else if (std::holds_alternative<float>(s)) {
-      args.append(std::get<float>(s));
-    }
-  }
-
-  // Make the grid
-  const auto [tx, ty, tz] = threadgroup_;
-  const auto [gx, gy, gz] = grid_;
-  dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
-  dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
-
-  // Call the kernel
-  auto& encoder = cu::get_command_encoder(s);
-  for (const auto& in : checked_inputs) {
-    encoder.set_input_array(in);
-  }
-  for (const auto& out : outputs) {
-    encoder.set_output_array(out);
-  }
-  for (const auto& t : copies) {
-    encoder.add_temporary(t);
-  }
-  auto kernel =
-      mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
-        if (smem > 0 && smem > 48000) {
-          cuFuncSetAttribute(
-              kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
-        }
-      });
-  encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
-}
-
-} // namespace mlx::core::fast
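The deleted file above defined the `mlx::core::fast::cuda_kernel` entry point that the earlier docs hunk also drops from the `fast` API listing. Going purely by the signatures shown in the removed code, it would have been used roughly as in the following sketch; the kernel name, body, array contents and launch geometry are made up, and since the API is removed this no longer builds against current MLX.

```cpp
// Hedged usage sketch of the deleted cuda_kernel() factory.
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto x = mx::ones({1024});
  auto y = mx::ones({1024});

  // The source string became the body of a generated __global__ function
  // whose parameters are the named input/output pointers.
  auto axpy = mx::fast::cuda_kernel(
      "my_axpy",
      /* input_names */ {"x", "y"},
      /* output_names */ {"out"},
      /* source */ R"(
        unsigned elem = blockIdx.x * blockDim.x + threadIdx.x;
        out[elem] = 2.0f * x[elem] + y[elem];
      )",
      /* header */ "",
      /* ensure_row_contiguous */ true,
      /* shared_memory */ 0);

  // The returned CustomKernelFunction is a callable taking inputs, output
  // shapes/dtypes and the (grid, threadgroup) launch geometry.
  auto outs = axpy(
      {x, y},
      /* output_shapes */ {{1024}},
      /* output_dtypes */ {mx::float32},
      /* grid */ {1024, 1, 1},
      /* threadgroup */ {256, 1, 1});
  mx::eval(outs);
  return 0;
}
```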
mlx/backend/cuda/distributed.cu (new file, 51 lines)

@@ -0,0 +1,51 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/distributed/primitives.h"
+#include "mlx/primitives.h"
+
+#include <cassert>
+
+namespace mlx::core {
+namespace distributed {
+void AllReduce::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& input = inputs[0];
+  auto& output = outputs[0];
+
+  auto& encoder = cu::get_command_encoder(stream());
+
+  if (input.is_donatable()) {
+    output.copy_shared_buffer(input);
+  } else {
+    output.set_data(allocator::malloc(output.nbytes()));
+  }
+
+  encoder.set_input_array(input);
+  encoder.set_output_array(output);
+
+  auto capture = encoder.capture_context();
+  auto& s = stream();
+
+  switch (reduce_type_) {
+    case Sum:
+      distributed::detail::all_sum(group(), input, output, s);
+      break;
+    case Max:
+      distributed::detail::all_max(group(), input, output, s);
+      break;
+    case Min:
+      distributed::detail::all_min(group(), input, output, s);
+      break;
+    default:
+      throw std::runtime_error(
+          "Only all reduce sum, max, and min are supported.");
+  }
+}
+} // namespace distributed
+} // namespace mlx::core
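The new `AllReduce::eval_gpu` delegates to `distributed::detail::all_sum` / `all_max` / `all_min`; together with the new `FindNCCL.cmake`, this suggests an NCCL-backed path, though the detail layer itself is not shown in this diff. For orientation only, a generic sketch of the underlying library call such a backend would ultimately issue (communicator setup omitted; this is not MLX code):

```cpp
// Generic NCCL all-reduce sketch; assumes an already-initialized ncclComm_t
// and a CUDA stream.
#include <cuda_runtime.h>
#include <nccl.h>

void all_sum_f32(const float* send, float* recv, size_t count,
                 ncclComm_t comm, cudaStream_t stream) {
  // Sums `count` floats across all ranks in `comm`; every rank receives the
  // reduced result in `recv`. The call is asynchronous on `stream`.
  ncclAllReduce(send, recv, count, ncclFloat32, ncclSum, comm, stream);
}
```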
@@ -1,396 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/kernel_utils.cuh"
-#include "mlx/dtype_utils.h"
-
-#include <cute/tensor.hpp>
-#include <cutlass/arch/arch.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm.h>
-#include <cutlass/layout/matrix.h>
-#include <cutlass/numeric_types.h>
-
-#include <iostream>
-
-namespace mlx::core::cu {
-
-namespace {
-
-using namespace cute;
-using bf16 = cute::bfloat16_t;
-
-template <typename Kernel>
-void configure_matmul(Kernel kernel, int smem_size) {
-  static bool initialized = false;
-  if (!initialized) {
-    initialized = true;
-    cudaFuncSetAttribute(
-        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-  }
-}
-
-template <bool transpose, typename Tiler>
-constexpr int get_feature_size(Tiler smem) {
-  int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
-  return (feature_size >= 64) ? 64 : feature_size;
-}
-
-constexpr int constexpr_log2(int x) {
-  return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
-}
-
-template <int feature_size, int itemsize, int copy_bits>
-constexpr int get_swizzle_bits() {
-  constexpr int swizzle_bits =
-      constexpr_log2(feature_size * itemsize / copy_bits);
-  return (swizzle_bits > 3) ? 3 : swizzle_bits;
-}
-
-template <int itemsize, bool transpose, int copy_bits, typename Tiler>
-constexpr auto make_smem_layout(Tiler smem) {
-  constexpr int feature_size = get_feature_size<transpose>(smem);
-  constexpr int swizzle_bits =
-      get_swizzle_bits<feature_size, itemsize, copy_bits>();
-
-  using F = Int<feature_size>;
-  using BaseLayout = std::conditional_t<
-      transpose,
-      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
-      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
-
-  auto swizzled =
-      make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
-
-  return tile_to_shape(swizzled, smem);
-}
-
-template <int itemsize, bool transpose, int copy_bits, typename Tiler>
-constexpr auto make_result_smem_layout(Tiler smem) {
-  constexpr int feature_size = get_feature_size<transpose>(smem);
-  constexpr int swizzle_bits =
-      get_swizzle_bits<feature_size, itemsize, copy_bits>();
-
-  using F = Int<feature_size>;
-  using BaseLayout = std::conditional_t<
-      transpose,
-      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
-      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
-
-  auto swizzled = make_composed_layout(
-      Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});
-
-  return tile_to_shape(swizzled, smem);
-}
-
-template <
-    int num_threads,
-    int itemsize,
-    bool transpose,
-    int copy_bits,
-    typename Copier,
-    typename Tiler>
-constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
-  constexpr int num_elements = copy_bits / itemsize;
-  constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
-  constexpr int copies_per_feature = feature_size / num_elements;
-
-  using E = Int<num_elements>;
-  using C = Int<copies_per_feature>;
-  using R = Int<num_threads / copies_per_feature>;
-
-  using ThreadLayout = std::conditional_t<
-      transpose,
-      Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
-      Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
-  using ValueLayout = std::conditional_t<
-      transpose,
-      Layout<cute::Shape<E, _1>>,
-      Layout<cute::Shape<_1, E>>>;
-
-  return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
-}
-
-template <int rasterization_factor>
-__device__ inline int2 raster_tile(int x, int y) {
-  return {
-      x / rasterization_factor,
-      (x % rasterization_factor) + y * rasterization_factor};
-}
-
-template <
-    typename T,
-    typename SLayoutA,
-    typename SLayoutB,
-    typename SLayoutC,
-    typename CopyA,
-    typename CopyB,
-    typename CopyC,
-    typename MMA,
-    int rasterization_factor>
-__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
-    const T* __restrict__ A,
-    const T* __restrict__ B,
-    T* __restrict__ C,
-    SLayoutA SA,
-    SLayoutB SB,
-    SLayoutC SC,
-    CopyA copy_a,
-    CopyB copy_b,
-    CopyC copy_c,
-    MMA mma,
-    int M,
-    int N,
-    int K) {
-  constexpr auto BM = size<0>(SA);
-  constexpr auto BN = size<0>(SB);
-  constexpr auto BK = size<1>(SA);
-  constexpr auto PIPE = size<2>(SA);
-
-  const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
-  const int blocks_m = ceil_div(M, BM);
-  const int blocks_n = ceil_div(N, BN);
-
-  // Exit early if the tile is OOB
-  if (tile.x >= blocks_m || tile.y >= blocks_n) {
-    return;
-  }
-
-  // Make the full tensors
-  Tensor full_A =
-      make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
-  Tensor full_B =
-      make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
-  Tensor full_C =
-      make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
-
-  // Partition the tensors into tiles and select the ones for this threadblock
-  Tensor local_A =
-      local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
-  Tensor local_B =
-      local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
-  Tensor local_C =
-      local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
-
-  // Make shared memory tensors
-  extern __shared__ char shared_memory[];
-  T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
-  T* shared_B_ptr =
-      reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
-  T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
-  Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
-  Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
-  Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
-
-  // Get the copies that correspond to this thread
-  auto thread_copy_a = copy_a.get_slice(threadIdx.x);
-  Tensor local_A_src = thread_copy_a.partition_S(local_A);
-  Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
-  auto thread_copy_b = copy_b.get_slice(threadIdx.x);
-  Tensor local_B_src = thread_copy_a.partition_S(local_B);
-  Tensor local_B_dst = thread_copy_a.partition_D(shared_B);
-  auto thread_copy_c = copy_c.get_slice(threadIdx.x);
-  Tensor local_C_src = thread_copy_c.partition_S(shared_C);
-  Tensor local_C_dst = thread_copy_c.partition_D(local_C);
-
-  // Start fetches
-  int k_tile_count = size<2>(local_A);
-  int k_tile_next = 0;
-  CUTE_UNROLL
-  for (int k = 0; k < PIPE - 1; k++) {
-    copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
-    copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
-    cp_async_fence();
-    k_tile_count--;
-    k_tile_next += (k_tile_count > 0);
-  }
-
-  // Get the MMA that corresponds to this thread and allocate registers
-  auto thread_mma = mma.get_slice(threadIdx.x);
-  Tensor mma_shared_A = thread_mma.partition_A(shared_A);
-  Tensor mma_shared_B = thread_mma.partition_B(shared_B);
-  Tensor mma_shared_C = thread_mma.partition_C(shared_C);
-  Tensor mma_global_C = thread_mma.partition_C(local_C);
-  Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
-  Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
-  Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
-  clear(mma_frag_C);
-
-  // Make shared to register copies
-  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
-  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
-  auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
-  auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
-  auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
-  auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
-  Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
-  Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
-  Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
-  Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
-
-  constexpr auto RPIPE = size<2>(mma_shared_A);
-  int smem_read = 0;
-  int smem_write = PIPE - 1;
-  Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
-  Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
-
-  // Start the register pipeline
-  if constexpr (RPIPE > 1) {
-    cp_async_wait<PIPE - 2>();
-    __syncthreads();
-    copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
-    copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
-  }
-
-  CUTE_NO_UNROLL
-  while (k_tile_count > -(PIPE - 1)) {
-    CUTE_UNROLL
-    for (int k_block = 0; k_block < RPIPE; k_block++) {
-      if (k_block == RPIPE - 1) {
-        mma_A_src_p = mma_A_src(_, _, _, smem_read);
-        mma_B_src_p = mma_B_src(_, _, _, smem_read);
-        cp_async_wait<PIPE - 2>();
-        __syncthreads();
-      }
-
-      // Load the next register tile
-      auto k_block_next = (k_block + 1) % RPIPE;
-      copy(
-          s2r_copy_a,
-          mma_A_src_p(_, _, k_block_next),
-          mma_A_dst(_, _, k_block_next));
-      copy(
-          s2r_copy_b,
-          mma_B_src_p(_, _, k_block_next),
-          mma_B_dst(_, _, k_block_next));
-
-      if (k_block == 0) {
-        copy(
-            copy_a,
-            local_A_src(_, _, _, k_tile_next),
-            local_A_dst(_, _, _, smem_write));
-        copy(
-            copy_b,
-            local_B_src(_, _, _, k_tile_next),
-            local_B_dst(_, _, _, smem_write));
-        cp_async_fence();
-        k_tile_count--;
-        k_tile_next += (k_tile_count > 0);
-        smem_write = smem_read;
-        smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
-      }
-
-      gemm(
-          mma,
-          mma_frag_A(_, _, k_block),
-          mma_frag_B(_, _, k_block),
-          mma_frag_C);
-    }
-  }
-
-  copy(mma_frag_C, mma_shared_C);
-  __syncthreads();
-  copy(copy_c, local_C_src, local_C_dst);
-
-  // if (threadIdx.x == 0) {
-  //   print("fC: "); print(mma_frag_C); print("\n");
-  //   print("sC: "); print(mma_shared_C); print("\n");
-  //   print("dC: "); print(local_C_dst); print("\n");
-  //
-  //   print(s2r_atom_a); print("\n");
-  // }
-}
-
-} // namespace
-
-void cutlass_gemm(
-    const array& a,
-    const array& b,
-    array& out,
-    int M,
-    int N,
-    int K,
-    cu::CommandEncoder& enc) {
-  enc.set_input_array(a);
-  enc.set_input_array(b);
-  enc.set_output_array(out);
-  dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
-    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-    if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
-      using namespace cute;
-
-      // Tile definitions
-      auto BM = Int<128>{};
-      auto BN = Int<128>{};
-      auto BK = Int<64>{};
-      auto BP = Int<3>{};
-      auto GM = Int<8>{};
-
-      // Thread definitions
-      using TM = Int<2>;
-      using TN = Int<2>;
-      using TK = Int<1>;
-      constexpr int num_threads = TM::value * TN::value * 32;
-
-      auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
-      auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
-      auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
-
-      constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
-
-      auto async_copy_op =
-          Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
-      auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
-          async_copy_op, make_shape(BM, BK));
-      auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
-          async_copy_op, make_shape(BN, BK));
-
-      auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
-      auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
-          sync_copy_op, make_shape(BM, BN));
-
-      auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
-      auto tiled_mma = make_tiled_mma(
-          mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
-
-      auto kernel = matmul_kernel<
-          bf16,
-          decltype(SA),
-          decltype(SB),
-          decltype(SC),
-          decltype(tiled_copy_a),
-          decltype(tiled_copy_b),
-          decltype(tiled_copy_c),
-          decltype(tiled_mma),
-          GM.value>;
-      configure_matmul(kernel, smem_size);
-
-      dim3 block(size(tiled_mma));
-      dim3 grid(
-          size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
-
-      enc.add_kernel_node(
-          kernel,
-          grid,
-          block,
-          smem_size,
-          a.data<bf16>(),
-          b.data<bf16>(),
-          out.data<bf16>(),
-          SA,
-          SB,
-          SC,
-          tiled_copy_a,
-          tiled_copy_b,
-          tiled_copy_c,
-          tiled_mma,
-          M,
-          N,
-          K);
-    } else {
-      throw std::runtime_error("Only bfloat16 supported");
-    }
-  });
-}
-
-} // namespace mlx::core::cu
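The `raster_tile` helper in the deleted kernel remaps the raw `(blockIdx.x, blockIdx.y)` pair so that consecutive blocks stay within the same block-row of A while sweeping several block-columns of B, presumably to improve cache reuse of the A tiles. A quick host-side check of the mapping for the `rasterization_factor = 8` (the `GM` used above); the block indices are arbitrary:

```cpp
// Host-side re-statement of the deleted raster_tile<> device helper, used
// only to print which output tile each (blockIdx.x, blockIdx.y) pair owns.
#include <cstdio>

constexpr int rasterization_factor = 8; // GM in the deleted cutlass_gemm()

void raster_tile(int x, int y, int* tile_m, int* tile_n) {
  *tile_m = x / rasterization_factor;
  *tile_n = (x % rasterization_factor) + y * rasterization_factor;
}

int main() {
  // Blocks x = 0..7 (with y == 0) all get tile_m == 0 but tile_n == 0..7,
  // so neighbouring blocks reread the same block-row of A.
  for (int x = 0; x < 10; ++x) {
    int tm, tn;
    raster_tile(x, /*y=*/0, &tm, &tn);
    std::printf("block (%d, 0) -> tile (%d, %d)\n", x, tm, tn);
  }
  return 0;
}
```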
@@ -1,18 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#pragma once
-
-#include "mlx/backend/cuda/device.h"
-
-namespace mlx::core::cu {
-
-void cutlass_gemm(
-    const array& a,
-    const array& b,
-    array& out,
-    int M,
-    int N,
-    int K,
-    cu::CommandEncoder& enc);
-
-}

@@ -1,69 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/kernel_utils.cuh"
-#include "mlx/backend/cuda/steel/gemm.cuh"
-#include "mlx/dtype_utils.h"
-
-#include <iostream>
-
-namespace mlx::core::cu {
-
-namespace {
-
-template <typename Kernel>
-static void configure_smem(Kernel kernel, int SM) {
-  static bool done = false;
-  if (done) {
-    return;
-  }
-  std::cout << "configuring" << std::endl;
-  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM);
-  cudaFuncSetAttribute(
-      kernel,
-      cudaFuncAttributePreferredSharedMemoryCarveout,
-      cudaSharedmemCarveoutMaxShared);
-  done = true;
-}
-
-} // namespace
-
-void simple_gemm(
-    const array& a,
-    const array& b,
-    array& out,
-    int M,
-    int N,
-    int K,
-    cu::CommandEncoder& enc) {
-  enc.set_input_array(a);
-  enc.set_input_array(b);
-  enc.set_output_array(out);
-  dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
-    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-    constexpr int BM = 128;
-    constexpr int BN = 128;
-    constexpr int BK = 32;
-    constexpr int PIPE = 3;
-    constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK);
-    constexpr int WM = 2;
-    constexpr int WN = 4;
-
-    auto kernel = ab_t_aligned<DataType, BM, BN, BK, WM, WN, PIPE>;
-    configure_smem(kernel, SM);
-
-    dim3 grid(N / BN, M / BM);
-    enc.add_kernel_node(
-        kernel,
-        grid,
-        WM * WN * WARP_SIZE,
-        SM,
-        a.data<DataType>(),
-        b.data<DataType>(),
-        out.data<DataType>(),
-        N,
-        K);
-  });
-}
-
-} // namespace mlx::core::cu
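For the tile sizes in the deleted `simple_gemm` above, the dynamic shared-memory request with a 16-bit `DataType` works out to 49152 bytes, right at the common 48 KiB per-block default, which is presumably why `configure_smem` raises `cudaFuncAttributeMaxDynamicSharedMemorySize` before the first launch. The arithmetic, assuming `sizeof(DataType) == 2` (e.g. bfloat16):

```cpp
// SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK)
//    = 3 * 2 * (128 * 32 + 128 * 32) = 3 * 2 * 8192 = 49152 bytes (48 KiB)
static_assert(3 * 2 * (128 * 32 + 128 * 32) == 49152,
              "three-stage 16-bit pipeline uses 48 KiB of shared memory");
```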
@@ -1,18 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#pragma once
-
-#include "mlx/backend/cuda/device.h"
-
-namespace mlx::core::cu {
-
-void simple_gemm(
-    const array& a,
-    const array& b,
-    array& out,
-    int M,
-    int N,
-    int K,
-    cu::CommandEncoder& enc);
-
-}

@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
             large ? "int64_t" : "int32_t"));
       }
     }
-    return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
+    return std::make_pair(jit_source_gather, std::move(kernel_names));
   });

   cu::KernelArgs args;

@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
             large ? "int64_t" : "int32_t"));
      }
    }
-    return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
+    return std::make_pair(jit_source_scatter, std::move(kernel_names));
   });

   cu::KernelArgs args;

@@ -268,8 +268,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       }
     }
   }
-    return std::make_tuple(
-        false, jit_source_gather_axis, std::move(kernel_names));
+    return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
   });

   size_t idx_size_pre = 1;

@@ -372,8 +371,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       }
     }
   }
-    return std::make_tuple(
-        false, jit_source_scatter_axis, std::move(kernel_names));
+    return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
   });

   size_t idx_size_pre = 1;

@@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() {
 bool read_cached_ptx(
     const std::filesystem::path& cache_dir,
     const std::string& module_name,
-    std::string& ptx,
-    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
+    std::vector<char>* ptx,
+    std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
   if (cache_dir.empty()) {
     return false;
   }

@@ -117,15 +117,15 @@ bool read_cached_ptx(
   if (!ptx_file.good()) {
     return false;
   }
-  ptx.resize(ptx_size);
-  ptx_file.read(ptx.data(), ptx_size);
+  ptx->resize(ptx_size);
+  ptx_file.read(ptx->data(), ptx_size);

   std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
   std::string line;
   while (std::getline(txt_file, line)) {
     auto tab = line.find('\t');
     if (tab != std::string::npos) {
-      ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
+      ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
     }
   }
   return true;

@@ -135,7 +135,7 @@ bool read_cached_ptx(
 void write_cached_ptx(
     const std::filesystem::path& cache_dir,
     const std::string& module_name,
-    const std::string& ptx,
+    const std::vector<char>& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     const std::string& source_code) {
   if (cache_dir.empty()) {

@@ -217,18 +217,22 @@ constexpr const char* g_headers[] = {
     jit_source_utils,
 };

-void compile(
+} // namespace
+
+JitModule::JitModule(
     Device& device,
     const std::string& module_name,
-    const std::string& source,
-    const std::vector<std::string>& kernel_names,
-    std::string& ptx,
-    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
-  // Create the program
+    const KernelBuilder& builder) {
+  // Check cache.
+  std::vector<char> ptx;
+  std::vector<std::pair<std::string, std::string>> ptx_kernels;
+  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
+    // Create program.
+    auto [source_code, kernel_names] = builder();
   nvrtcProgram prog;
   CHECK_NVRTC_ERROR(nvrtcCreateProgram(
       &prog,
-      source.c_str(),
+      source_code.c_str(),
       (module_name + ".cu").c_str(),
       std::size(g_headers),
       g_headers,

@@ -282,20 +286,16 @@ void compile(
   } else {
     CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
   }
-  ptx.resize(ptx_size);
+    ptx.resize(ptx_size, 0);
   if (use_sass) {
     CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
   } else {
     CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
   }
+    write_cached_ptx(
+        ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
   }

-void load_module(
-    const std::string& module_name,
-    const std::string& ptx,
-    const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
-    CUmodule& module_,
-    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
   // Load module.
   char jit_log[4089] = {};
   CUjit_option options[] = {
@@ -312,69 +312,21 @@ void load_module(
   for (const auto& [name, mangled] : ptx_kernels) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-    kernels[name] = std::make_pair(kernel, false);
+    kernels_[name] = kernel;
   }
 }

-} // namespace
-
-JitModule::JitModule(
-    Device& device,
-    const std::string& module_name,
-    const KernelBuilder& builder,
-    bool use_disk_cache) {
-  // Will hold the actual device executable source code and kernel names
-  std::string ptx;
-  std::vector<std::pair<std::string, std::string>> ptx_kernels;
-
-  // Try to load them from the file cache
-  if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
-    auto [precompiled, source_code, kernel_names] = builder();
-
-    // Get the PTX or cubin
-    if (precompiled) {
-      ptx = std::move(source_code);
-      for (auto& name : kernel_names) {
-        ptx_kernels.emplace_back(name, name);
-      }
-    } else {
-      compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
-    }
-
-    // If requested save them in the file cache for the next launch
-    if (use_disk_cache) {
-      write_cached_ptx(
-          ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
-    }
-  }
-
-  // Load the module
-  load_module(module_name, ptx, ptx_kernels, module_, kernels_);
-}
-
 JitModule::~JitModule() {
   CHECK_CUDA_ERROR(cuModuleUnload(module_));
 }

-CUfunction JitModule::get_kernel(
-    const std::string& kernel_name,
-    std::function<void(CUfunction)> configure_kernel) {
+CUfunction JitModule::get_kernel(const std::string& kernel_name) {
   auto it = kernels_.find(kernel_name);
   if (it == kernels_.end()) {
     throw std::runtime_error(
         fmt::format("There is no kernel named {}.", kernel_name));
   }
-
-  // If it is the first time we run this kernel then configure it. Do it only
-  // once!
-  if (!it->second.second) {
-    if (configure_kernel) {
-      configure_kernel(it->second.first);
-    }
-    it->second.second = true;
-  }
-
-  return it->second.first;
+  return it->second;
 }

 std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
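The `configure_kernel` hook removed from `get_kernel` above existed so a caller could tweak a `CUfunction` exactly once before its first launch; the deleted `CustomKernel::eval_gpu` used it to raise the dynamic shared-memory cap via `cuFuncSetAttribute`. A standalone sketch of that one-time configuration pattern against the CUDA driver API (the 64 KiB value is an arbitrary example above the 48 KiB default):

```cpp
// One-time kernel configuration sketch using the CUDA driver API. The
// CUfunction would come from cuModuleGetFunction, as in JitModule.
#include <cuda.h>

void configure_once(CUfunction kernel) {
  static bool configured = false; // mirrors the bool the removed map carried
  if (configured) {
    return;
  }
  // Kernels that want more than ~48 KB of dynamic shared memory must opt in.
  cuFuncSetAttribute(
      kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, 64 * 1024);
  configured = true;
}
```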
@@ -385,12 +337,11 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
 JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,
-    const KernelBuilder& builder,
-    bool cache) {
+    const KernelBuilder& builder) {
   auto& map = get_jit_module_cache();
   auto it = map.find(name);
   if (it == map.end()) {
-    it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
+    it = map.try_emplace(name, cu::device(device), name, builder).first;
   }
   return it->second;
 }

@@ -19,8 +19,7 @@ namespace mlx::core::cu {

 class Device;

-using KernelBuilderResult = std::tuple<
-    /* precompiled */ bool,
+using KernelBuilderResult = std::pair<
     /* source code */ std::string,
     /* kernel names */ std::vector<std::string>>;
 using KernelBuilder = std::function<KernelBuilderResult()>;
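This typedef change is what drives the mechanical `std::make_tuple` to `std::make_pair` edits in the `Compiled::eval_gpu` and `Gather`/`Scatter` hunks above: a builder now returns only the generated source and the kernel names, and the `precompiled` flag is gone. A hedged before/after sketch of a builder lambda; the source string and kernel name are placeholders:

```cpp
// Old shape: a 3-tuple whose leading bool said whether `source` was already
// PTX/cubin rather than CUDA C++ destined for NVRTC.
auto old_builder = []() {
  return std::make_tuple(
      /* precompiled */ false,
      /* source code */ std::string("__global__ void k() {}"),
      /* kernel names */ std::vector<std::string>{"k"});
};

// New shape: just (source, kernel names); every builder is assumed to produce
// CUDA C++ that goes through NVRTC.
auto new_builder = []() {
  return std::make_pair(
      std::string("__global__ void k() {}"),
      std::vector<std::string>{"k"});
};
```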
@@ -64,16 +63,14 @@ struct KernelArgs {
  private:
  std::vector<void*> args_;

-  // The cuGraphAddKernelNode API requires passing pointers to arguments so
-  // store temporary values until the node is created.
+  // The cuLaunchKernel API requires passing pointers to arguments so store
+  // temporary values untill kernel is launched.
   using Arg = std::variant<
       std::monostate,
       CUdeviceptr,
-      bool,
       int32_t,
       uint32_t,
       int64_t,
-      float,
       SmallVector<const void*>,
       SmallVector<int32_t>,
       SmallVector<int64_t>>;
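Both comments point at the same constraint: `cuLaunchKernel` and `cuGraphAddKernelNode` take a `void**` whose entries must point at storage that is still alive at launch time, so `KernelArgs` owns the argument values itself. A minimal, MLX-independent sketch of that pattern (names are made up):

```cpp
// Minimal "own the storage, expose void**" pattern that KernelArgs implements.
// A deque is used because its elements keep stable addresses as more
// arguments are appended.
#include <cstdint>
#include <deque>
#include <variant>
#include <vector>

class ArgPack {
 public:
  template <typename T>
  void append(T value) {
    storage_.emplace_back(value);                         // keep the value alive
    pointers_.push_back(&std::get<T>(storage_.back()));   // and point at it
  }
  // Suitable for the kernelParams argument of the driver launch APIs.
  void** data() {
    return pointers_.data();
  }

 private:
  std::deque<std::variant<int32_t, int64_t, float, void*>> storage_;
  std::vector<void*> pointers_;
};
```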
@@ -85,19 +82,16 @@ class JitModule {
|
|||||||
JitModule(
|
JitModule(
|
||||||
Device& device,
|
Device& device,
|
||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const KernelBuilder& builder,
|
const KernelBuilder& builder);
|
||||||
bool cache);
|
|
||||||
~JitModule();
|
~JitModule();
|
||||||
|
|
||||||
JitModule(const JitModule&) = delete;
|
JitModule(const JitModule&) = delete;
|
||||||
JitModule& operator=(const JitModule&) = delete;
|
JitModule& operator=(const JitModule&) = delete;
|
||||||
CUfunction get_kernel(
|
CUfunction get_kernel(const std::string& kernel_name);
|
||||||
const std::string& kernel_name,
|
|
||||||
std::function<void(CUfunction)> configure_kernel = nullptr);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CUmodule module_{nullptr};
|
CUmodule module_{nullptr};
|
||||||
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
|
std::unordered_map<std::string, CUfunction> kernels_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||||
@@ -105,7 +99,6 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
|||||||
JitModule& get_jit_module(
|
JitModule& get_jit_module(
|
||||||
const mlx::core::Device& device,
|
const mlx::core::Device& device,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
const KernelBuilder& builder,
|
const KernelBuilder& builder);
|
||||||
bool use_disk_cache = true);
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
|||||||
@@ -3,9 +3,7 @@
 #include "mlx/backend/common/matmul.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
-#include "mlx/backend/cuda/gemms/cutlass_gemm.h"
 #include "mlx/backend/cuda/gemms/gemv.h"
-#include "mlx/backend/cuda/gemms/simple_gemm.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/primitives.h"

@@ -13,14 +11,8 @@
 #include <numeric>

 namespace mlx::core {

 namespace {

-int get_test_gemm() {
-  static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0);
-  return t;
-}
-
 std::tuple<bool, int64_t, array>
 check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
   auto stx = arr.strides()[arr.ndim() - 2];

@@ -103,18 +95,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }

-  if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
-      b_transposed && batch_count == 1 && get_test_gemm() == 1) {
-    cu::simple_gemm(a, b, out, M, N, K, encoder);
-    return;
-  }
-
-  if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
-      b_transposed && batch_count == 1 && get_test_gemm() == 2) {
-    cu::cutlass_gemm(a, b, out, M, N, K, encoder);
-    return;
-  }
-
   /////////////////////////////////////////////////////////////////////////////
   // Invoke cublasLt
   CublasGemm gemm(
@@ -1,47 +1,11 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/cuda/cuda.h"
-#include "mlx/fast.h"

-namespace mlx::core {
-
-namespace cu {
+namespace mlx::core::cu {

 bool is_available() {
   return false;
 }

-} // namespace cu
-
-namespace fast {
-
-CustomKernelFunction cuda_kernel(
-    const std::string&,
-    const std::vector<std::string>&,
-    const std::vector<std::string>&,
-    const std::string&,
-    const std::string&,
-    bool,
-    int) {
-  throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
-}
-
-std::vector<array> precompiled_cuda_kernel(
-    const std::string&,
-    const std::string&,
-    const std::vector<array>&,
-    const std::vector<Shape>&,
-    const std::vector<Dtype>&,
-    const std::vector<ScalarArg>&,
-    std::tuple<int, int, int>,
-    std::tuple<int, int, int>,
-    int shared_memory,
-    std::optional<float> init_value,
-    bool ensure_row_contiguous,
-    StreamOrDevice) {
-  throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
-}
-
-} // namespace fast
-
-} // namespace mlx::core
+} // namespace mlx::core::cu
@@ -41,8 +41,11 @@ NO_GPU(Cholesky)
 NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)

+namespace fast {
+NO_GPU_MULTI(CustomKernel)
+} // namespace fast
+
 namespace distributed {
-NO_GPU_MULTI(AllReduce)
 NO_GPU_MULTI(AllGather)
 NO_GPU_MULTI(Send)
 NO_GPU_MULTI(Recv)
@@ -4,189 +4,95 @@

 namespace mlx::core::cu {

-template <typename T, int BM, int BN, int BK, int WM, int WN>
-__device__ inline void gemm_ab_t(
-    RegisterTile<float, BM / WM, BN / WN>& C,
-    SharedTile<T, BM, BK>& As,
-    SharedTile<T, BN, BK>& Bs,
-    RegisterTileLoader<SharedTile<T, BM, BK>>& rloader_a,
-    RegisterTileLoader<SharedTile<T, BN, BK>>& rloader_b) {
-  RegisterTile<T, BM / WM, 16> A[2];
-  RegisterTile<T, BN / WN, 16> B[2];
-
-  rloader_a.load(A[0], As.base_addr(), 0);
-  rloader_b.load(B[0], Bs.base_addr(), 0);
-
-  MLX_UNROLL
-  for (int k = 1; k < BK / 16; k++) {
-    rloader_a.load(A[k & 1], As.base_addr(), k);
-    rloader_b.load(B[k & 1], Bs.base_addr(), k);
-
-    mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
-  }
-  mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
-}
-
 /**
  * An example gemm written with the utils.
  *
  * Computes A @ B.T when A and B are all aligned with the block sizes.
  */
// template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
|
template <typename T, int BM, int BN, int BK>
|
||||||
//__global__ __launch_bounds__(WM * WN * WARP_SIZE, 1)
|
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
||||||
// void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
constexpr int WARPS_M = 2;
|
||||||
// constexpr int NUM_WARPS = WM * WN;
|
constexpr int WARPS_N = 2;
|
||||||
// constexpr int WARP_STEP_M = BM / WM;
|
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
||||||
// constexpr int WARP_STEP_N = BN / WN;
|
constexpr int WARP_STEP_M = BM / WARPS_M;
|
||||||
//
|
constexpr int WARP_STEP_N = BN / WARPS_N;
|
||||||
// // Precompute some offsets for each thread
|
|
||||||
// const int warpid = threadIdx.x / 32;
|
|
||||||
// const int laneid = threadIdx.x % 32;
|
|
||||||
// const int wm = warpid / WN;
|
|
||||||
// const int wn = warpid % WN;
|
|
||||||
// const int offset_m = wm * WARP_STEP_M;
|
|
||||||
// const int offset_n = wn * WARP_STEP_N;
|
|
||||||
//
|
|
||||||
// // Allocate shared memory
|
|
||||||
// extern __shared__ char shmem[];
|
|
||||||
// SharedTile<T, BM, BK>(&as)[PIPE] =
|
|
||||||
// *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
|
|
||||||
// SharedTile<T, BN, BK>(&bs)[PIPE] =
|
|
||||||
// *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
|
|
||||||
//
|
|
||||||
// // Move the global pointers to the tile
|
|
||||||
// a += blockIdx.y * BM * K;
|
|
||||||
// b += blockIdx.x * BN * K;
|
|
||||||
// y += blockIdx.y * BM * N + blockIdx.x * BN;
|
|
||||||
//
|
|
||||||
// // Make the loaders to/from SMEM
|
|
||||||
// SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>> sloader_a(a, K);
|
|
||||||
// SharedTileLoader<NUM_WARPS, SharedTile<T, BN, BK>> sloader_b(b, K);
|
|
||||||
// RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
|
|
||||||
// RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
|
|
||||||
//
|
|
||||||
// // Start the SM pipeline
|
|
||||||
// MLX_UNROLL
|
|
||||||
// for (int i = 0; i < PIPE - 1; i++) {
|
|
||||||
// sloader_a.load_async(as[i].base_addr());
|
|
||||||
// sloader_b.load_async(bs[i].base_addr());
|
|
||||||
// cp_async_commit();
|
|
||||||
// sloader_a.next();
|
|
||||||
// sloader_b.next();
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // Allocate and zero the MMA accumulator
|
|
||||||
// RegisterTile<float, BM / WM, BN / WN> C;
|
|
||||||
// C.fill(0);
|
|
||||||
//
|
|
||||||
// // Matmul loop
|
|
||||||
// int num_blocks = K / BK;
|
|
||||||
// int sread = 0;
|
|
||||||
// int swrite = PIPE - 1;
|
|
||||||
// for (int i = 0; i < num_blocks; i++) {
|
|
||||||
// cp_async_wait<PIPE - 1>();
|
|
||||||
//
|
|
||||||
// gemm_ab_t<T, BM, BN, BK, WM, WN>(
|
|
||||||
// C, as[sread], bs[sread], rloader_a, rloader_b);
|
|
||||||
//
|
|
||||||
// sloader_a.load_async(as[swrite].base_addr());
|
|
||||||
// sloader_b.load_async(bs[swrite].base_addr());
|
|
||||||
// cp_async_commit();
|
|
||||||
// sloader_a.next(i + PIPE < num_blocks);
|
|
||||||
// sloader_b.next(i + PIPE < num_blocks);
|
|
||||||
//
|
|
||||||
// swrite = sread;
|
|
||||||
// sread = (sread + 1) % PIPE;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// C.store_global(y, N, offset_m, offset_n);
|
|
||||||
// }
|
|
||||||
|
|
||||||
template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
|
|
||||||
__global__ __launch_bounds__(
|
|
||||||
WM* WN* WARP_SIZE,
|
|
||||||
1) void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
|
||||||
constexpr int NUM_WARPS = WM * WN;
|
|
||||||
constexpr int WARP_STEP_M = BM / WM;
|
|
||||||
constexpr int WARP_STEP_N = BN / WN;
|
|
||||||
|
|
||||||
// Precompute some offsets for each thread
|
// Precompute some offsets for each thread
|
||||||
const int warpid = threadIdx.x / 32;
|
const int warpid = threadIdx.x / 32;
|
||||||
const int laneid = threadIdx.x % 32;
|
const int laneid = threadIdx.x % 32;
|
||||||
const int wm = warpid / WN;
|
const int wm = warpid / WARPS_N;
|
||||||
const int wn = warpid % WN;
|
const int wn = warpid % WARPS_N;
|
||||||
const int offset_m = wm * WARP_STEP_M;
|
const int offset_m = wm * WARP_STEP_M;
|
||||||
const int offset_n = wn * WARP_STEP_N;
|
const int offset_n = wn * WARP_STEP_N;
|
||||||
|
|
||||||
// Allocate shared memory
|
// Allocate shared memory
|
||||||
extern __shared__ char shmem[];
|
extern __shared__ char shmem[];
|
||||||
SharedTile<T, BM, BK>(&as)[PIPE] =
|
SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
|
||||||
*(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
|
SharedTile<T, BN, BK>(&bs)[2] =
|
||||||
SharedTile<T, BN, BK>(&bs)[PIPE] =
|
*(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
|
||||||
*(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
|
|
||||||
|
// Allocate registers for the MMA
|
||||||
|
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
||||||
|
RegisterTile<T, BM / WARPS_M, 16> A;
|
||||||
|
RegisterTile<T, BN / WARPS_N, 16> B;
|
||||||
|
|
||||||
// Move the global pointers to the tile
|
// Move the global pointers to the tile
|
||||||
a += blockIdx.y * BM * K;
|
a += blockIdx.y * BM * K;
|
||||||
b += blockIdx.x * BN * K;
|
b += blockIdx.x * BN * K;
|
||||||
y += blockIdx.y * BM * N + blockIdx.x * BN;
|
y += blockIdx.y * BM * N + blockIdx.x * BN;
|
||||||
|
|
||||||
// Make the loaders to/from SMEM
|
// Zero the accumulators
|
||||||
using sloader = SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>>;
|
|
||||||
constexpr int SSTEP = sloader::STEP_ROWS * sizeof(T) * BK;
|
|
||||||
const int srow = threadIdx.x / sloader::NUM_LOADS_PER_ROW;
|
|
||||||
const int scol =
|
|
||||||
(threadIdx.x % sloader::NUM_LOADS_PER_ROW) * sloader::ELEMENTS_PER_LOAD;
|
|
||||||
a += srow * K + scol;
|
|
||||||
b += srow * K + scol;
|
|
||||||
uint32_t sm_offsets[PIPE][2];
|
|
||||||
MLX_UNROLL
|
|
||||||
for (int s = 0; s < PIPE; s++) {
|
|
||||||
sm_offsets[s][0] = as[s].loc(as[s].base_addr(), srow, scol);
|
|
||||||
sm_offsets[s][1] = bs[s].loc(bs[s].base_addr(), srow, scol);
|
|
||||||
}
|
|
||||||
RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
|
|
||||||
RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
|
|
||||||
|
|
||||||
// Start the SM pipeline
|
|
||||||
MLX_UNROLL
|
|
||||||
for (int s = 0; s < PIPE - 1; s++) {
|
|
||||||
MLX_UNROLL
|
|
||||||
for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
|
|
||||||
cp_async<16>(sm_offsets[s][0] + l * SSTEP, a);
|
|
||||||
cp_async<16>(sm_offsets[s][1] + l * SSTEP, b);
|
|
||||||
a += sloader::STEP_ROWS * K;
|
|
||||||
b += sloader::STEP_ROWS * K;
|
|
||||||
}
|
|
||||||
cp_async_commit();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allocate and zero the MMA accumulator
|
|
||||||
RegisterTile<float, BM / WM, BN / WN> C;
|
|
||||||
C.fill(0);
|
C.fill(0);
|
||||||
|
|
||||||
// Matmul loop
|
// Start the SM pipeline
|
||||||
int num_blocks = K / BK;
|
load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
|
||||||
int sread = 0;
|
load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
|
||||||
int swrite = PIPE - 1;
|
|
||||||
for (int i = 0; i < num_blocks; i++) {
|
|
||||||
cp_async_wait<PIPE - 1>();
|
|
||||||
|
|
||||||
gemm_ab_t<T, BM, BN, BK, WM, WN>(
|
|
||||||
C, as[sread], bs[sread], rloader_a, rloader_b);
|
|
||||||
|
|
||||||
if (false) {
|
|
||||||
MLX_UNROLL
|
|
||||||
for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
|
|
||||||
cp_async<16>(sm_offsets[swrite][0] + l * SSTEP, a);
|
|
||||||
cp_async<16>(sm_offsets[swrite][1] + l * SSTEP, b);
|
|
||||||
a += sloader::STEP_ROWS * K;
|
|
||||||
b += sloader::STEP_ROWS * K;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cp_async_commit();
|
cp_async_commit();
|
||||||
|
|
||||||
swrite = sread;
|
int tic = 0;
|
||||||
sread = (sread + 1) % PIPE;
|
for (int k_block = BK; k_block < K; k_block += BK) {
|
||||||
|
load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
|
||||||
|
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
|
||||||
|
cp_async_commit();
|
||||||
|
cp_async_wait<1>();
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int k = 0; k < BK / 16; k++) {
|
||||||
|
A.load(
|
||||||
|
as[tic],
|
||||||
|
as[tic].base_addr(),
|
||||||
|
offset_m + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
B.load(
|
||||||
|
bs[tic],
|
||||||
|
bs[tic].base_addr(),
|
||||||
|
offset_n + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
|
||||||
|
mma_t(C, A, B);
|
||||||
|
}
|
||||||
|
|
||||||
|
tic ^= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty the pipeline
|
||||||
|
cp_async_wait_all();
|
||||||
|
__syncthreads();
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int k = 0; k < BK / 16; k++) {
|
||||||
|
A.load(
|
||||||
|
as[tic],
|
||||||
|
as[tic].base_addr(),
|
||||||
|
offset_m + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
B.load(
|
||||||
|
bs[tic],
|
||||||
|
bs[tic].base_addr(),
|
||||||
|
offset_n + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
|
||||||
|
mma_t(C, A, B);
|
||||||
}
|
}
|
||||||
|
|
||||||
C.store_global(y, N, offset_m, offset_n);
|
C.store_global(y, N, offset_m, offset_n);
|
||||||
|
|||||||
@@ -223,10 +223,59 @@ struct RegisterTile {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple container of multiple Tile16x16.
|
||||||
|
*
|
||||||
|
* Provides utility functions for loading and manipulating collections of basic
|
||||||
|
* tiles.
|
||||||
|
*/
|
||||||
|
template <typename T, int ROWS_, int COLS_>
|
||||||
|
struct RegisterTile {
|
||||||
|
static constexpr int ROWS = ROWS_;
|
||||||
|
static constexpr int COLS = COLS_;
|
||||||
|
static constexpr int TILES_X = COLS / 16;
|
||||||
|
static constexpr int TILES_Y = ROWS / 16;
|
||||||
|
|
||||||
|
Tile16x16<T> data[TILES_X * TILES_Y];
|
||||||
|
|
||||||
|
__device__ inline void fill(T v) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].fill(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Tile>
|
||||||
|
__device__ inline void
|
||||||
|
load(Tile& tile, uint32_t base_address, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].load(
|
||||||
|
tile.loc(base_address, row + i * 16, col + j * 16));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void store_global(U* x, int N, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].store_global(
|
||||||
|
x + (row + i * 16) * N + col + j * 16, N);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, int ROWS_, int COLS_>
|
template <typename T, int ROWS_, int COLS_>
|
||||||
struct SharedTile {
|
struct SharedTile {
|
||||||
using value_type = T;
|
|
||||||
|
|
||||||
static constexpr int ROWS = ROWS_;
|
static constexpr int ROWS = ROWS_;
|
||||||
static constexpr int COLS = COLS_;
|
static constexpr int COLS = COLS_;
|
||||||
static constexpr int TILES_X = COLS / 16;
|
static constexpr int TILES_X = COLS / 16;
|
||||||
@@ -268,26 +317,23 @@ struct SharedTile {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ static inline uint32_t offset(int row, int col) {
|
// Return the location of the element at (row, col) using the swizzle.
|
||||||
|
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
|
||||||
if constexpr (swizzle_bytes > 0) {
|
if constexpr (swizzle_bytes > 0) {
|
||||||
static constexpr int swizzle_repeat = swizzle_bytes * 8;
|
static constexpr int swizzle_repeat = swizzle_bytes * 8;
|
||||||
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
|
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
|
||||||
const int outer_idx = col / subtile_cols;
|
const int outer_idx = col / subtile_cols;
|
||||||
const uint32_t addr = sizeof(T) *
|
const uint32_t addr = ptr +
|
||||||
|
sizeof(T) *
|
||||||
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
|
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
|
||||||
col % subtile_cols);
|
col % subtile_cols);
|
||||||
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
|
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
|
||||||
return (addr ^ swizzle);
|
return (addr ^ swizzle);
|
||||||
} else {
|
} else {
|
||||||
return sizeof(T) * (row * COLS + col);
|
return ptr + sizeof(T) * (row * COLS + col);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the location of the element at (row, col) using the swizzle.
|
|
||||||
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
|
|
||||||
return ptr + offset(row, col);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convenience functions to edit elements going through the swizzle.
|
// Convenience functions to edit elements going through the swizzle.
|
||||||
__device__ inline T& operator()(int row, int col) {
|
__device__ inline T& operator()(int row, int col) {
|
||||||
return *ptr(data, row, col);
|
return *ptr(data, row, col);
|
||||||
@@ -318,76 +364,6 @@ struct SharedTile {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int NUM_WARPS, typename Tile>
|
|
||||||
struct SharedTileLoader {
|
|
||||||
using T = typename Tile::value_type;
|
|
||||||
|
|
||||||
static constexpr int NUM_THREADS = NUM_WARPS * 32;
|
|
||||||
static constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
|
||||||
static constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
|
||||||
static constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
|
||||||
static constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
|
||||||
static constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
|
||||||
|
|
||||||
const T* x_;
|
|
||||||
int N_;
|
|
||||||
uint32_t offset_;
|
|
||||||
|
|
||||||
__device__ SharedTileLoader(const T* x, int N) : x_(x), N_(N) {
|
|
||||||
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
|
||||||
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
|
||||||
|
|
||||||
x_ += row * N + col * ELEMENTS_PER_LOAD;
|
|
||||||
offset_ = Tile::offset(row, col * ELEMENTS_PER_LOAD);
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ inline void load_async(uint32_t base_address) {
|
|
||||||
MLX_UNROLL
|
|
||||||
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
|
||||||
cp_async<16>(
|
|
||||||
base_address + offset_ + i * STEP_ROWS * sizeof(T) * Tile::COLS,
|
|
||||||
x_ + i * STEP_ROWS * N_);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ inline void next() {
|
|
||||||
x_ += Tile::COLS;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Tile>
|
|
||||||
struct RegisterTileLoader {
|
|
||||||
using T = typename Tile::value_type;
|
|
||||||
|
|
||||||
uint32_t offset_[Tile::COLS / 16];
|
|
||||||
|
|
||||||
__device__ RegisterTileLoader(int offset_row, int laneid) {
|
|
||||||
const int row = offset_row + laneid & 15;
|
|
||||||
const int col = (laneid >> 4) << 3;
|
|
||||||
|
|
||||||
MLX_UNROLL
|
|
||||||
for (int i = 0; i < Tile::COLS / 16; i++) {
|
|
||||||
offset_[i] = Tile::offset(row, col + i * 16);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, int ROWS, int COLS>
|
|
||||||
__device__ inline void
|
|
||||||
load(RegisterTile<T, ROWS, COLS>& x, uint32_t base_address, int col) {
|
|
||||||
constexpr int TILES_Y = RegisterTile<T, ROWS, COLS>::TILES_Y;
|
|
||||||
constexpr int TILES_X = RegisterTile<T, ROWS, COLS>::TILES_X;
|
|
||||||
|
|
||||||
MLX_UNROLL
|
|
||||||
for (int i = 0; i < TILES_Y; i++) {
|
|
||||||
MLX_UNROLL
|
|
||||||
for (int j = 0; j < TILES_X; j++) {
|
|
||||||
x.data[i * TILES_X + j].load(
|
|
||||||
base_address + offset_[j + col] + i * 16 * Tile::COLS * sizeof(T));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load the tile from global memory by loading 16 bytes at a time and storing
|
* Load the tile from global memory by loading 16 bytes at a time and storing
|
||||||
* them immediately.
|
* them immediately.
|
||||||
|
|||||||
@@ -21,15 +21,15 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
 #if defined(MLX_CUDA_SM_80_ENABLED)
   if constexpr (N == 16) {
     asm volatile(
-        "cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
+        "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
         "l"(reinterpret_cast<const int4*>(x)));
   } else if constexpr (N == 8) {
     asm volatile(
-        "cp.async.cg.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
+        "cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
         "l"(reinterpret_cast<const int2*>(x)));
   } else if constexpr (N == 4) {
     asm volatile(
-        "cp.async.cg.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
+        "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
         "l"(reinterpret_cast<const int*>(x)));
   }
 #endif
@@ -172,7 +172,7 @@ std::string write_template(
   return template_def.str();
 }

-CustomKernelFunction metal_kernel(
+MetalKernelFunction metal_kernel(
     const std::string& name,
     const std::vector<std::string>& input_names,
     const std::vector<std::string>& output_names,

@@ -316,10 +316,7 @@ CustomKernelFunction metal_kernel(
             threadgroup,
             shape_infos,
             ensure_row_contiguous,
-            init_value,
-            std::vector<ScalarArg>{},
-            false,
-            0),
+            init_value),
         std::move(inputs));
   };
 }
@@ -83,7 +83,7 @@ struct Conv2DInputBlockLoaderSmallChannels {
   const constant MLXConvParams<2>* params;
   const constant ImplicitGemmConv2DParams* gemm_params;

-  int weight_hw;
+  short weight_hw;

   const device T* src[n_rows];
@@ -26,15 +26,15 @@ device_info() {

 namespace fast {

-CustomKernelFunction metal_kernel(
+MetalKernelFunction metal_kernel(
     const std::string&,
     const std::vector<std::string>&,
     const std::vector<std::string>&,
     const std::string&,
     const std::string&,
-    bool,
-    bool) {
-  throw std::runtime_error("[metal_kernel] No Metal back-end.");
+    bool ensure_row_contiguous,
+    bool atomic_outputs) {
+  throw std::runtime_error("[metal_kernel] No GPU back-end.");
 }

 } // namespace fast
@@ -6,3 +6,4 @@ target_sources(

 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
@@ -2,9 +2,11 @@

 #include <unordered_map>

+#include <iostream>
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/mpi/mpi.h"
+#include "mlx/distributed/nccl/nccl.h"
 #include "mlx/distributed/ring/ring.h"

 namespace mlx::core::distributed {

@@ -80,7 +82,7 @@ class EmptyGroup : public GroupImpl {
 } // namespace detail

 bool is_available() {
-  return mpi::is_available() || ring::is_available();
+  return mpi::is_available() || ring::is_available() || nccl::is_available();
 }

 int Group::rank() const {

@@ -111,6 +113,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
     group = mpi::init(strict);
   } else if (bk == "ring") {
     group = ring::init(strict);
+  } else if (bk == "nccl") {
+    group = nccl::init(strict);
   } else if (bk == "any") {
     group = ring::init(false);
     bk_ = "ring";
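A quick usage sketch of the new backend selection (not part of the diff; it assumes a CUDA build with NCCL compiled in and uses the existing mx.distributed Python API):

    import mlx.core as mx

    # Explicitly request the NCCL backend wired in above; backend="any"
    # still falls back to the ring backend as before.
    group = mx.distributed.init(backend="nccl")
    x = mx.distributed.all_sum(mx.ones((4,)), group=group)
    print(group.rank(), group.size(), x)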
@@ -3,7 +3,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
|
|||||||
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
if(MLX_BUILD_CUDA)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
  find_package(NCCL REQUIRED)
  target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
  target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
endif()
354
mlx/distributed/nccl/nccl.cpp
Normal file
@@ -0,0 +1,354 @@
|
#include <arpa/inet.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <netdb.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mutex>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
#define CHECK_CUDA(cmd) \
|
||||||
|
do { \
|
||||||
|
cudaError_t e = cmd; \
|
||||||
|
if (e != cudaSuccess) { \
|
||||||
|
fprintf( \
|
||||||
|
stderr, \
|
||||||
|
"CUDA error %s:%d '%s'\n", \
|
||||||
|
__FILE__, \
|
||||||
|
__LINE__, \
|
||||||
|
cudaGetErrorString(e)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define CHECK_NCCL(cmd) \
|
||||||
|
do { \
|
||||||
|
ncclResult_t r = cmd; \
|
||||||
|
if (r != ncclSuccess) { \
|
||||||
|
fprintf( \
|
||||||
|
stderr, \
|
||||||
|
"NCCL error %s:%d '%s'\n", \
|
||||||
|
__FILE__, \
|
||||||
|
__LINE__, \
|
||||||
|
ncclGetErrorString(r)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define MLX_NCCL_TYPE_LIST(X) \
|
||||||
|
X(int8_t, ncclChar) \
|
||||||
|
X(uint8_t, ncclUint8) \
|
||||||
|
X(int32_t, ncclInt) \
|
||||||
|
X(uint32_t, ncclUint32) \
|
||||||
|
X(int64_t, ncclInt64) \
|
||||||
|
X(uint64_t, ncclUint64) \
|
||||||
|
X(float16_t, ncclHalf) \
|
||||||
|
X(bfloat16_t, ncclBfloat16) \
|
||||||
|
X(float, ncclFloat) \
|
||||||
|
X(double, ncclDouble)
|
||||||
|
|
||||||
|
template <class>
|
||||||
|
struct nccl_map {
|
||||||
|
static constexpr bool ok = false; // default: unsupported
|
||||||
|
};
|
||||||
|
|
||||||
|
#define MLX_DEF_NCCL_MAP(T, E) \
|
||||||
|
template <> \
|
||||||
|
struct nccl_map<T> { \
|
||||||
|
static constexpr bool ok = true; \
|
||||||
|
static constexpr ncclDataType_t value = E; \
|
||||||
|
};
|
||||||
|
|
||||||
|
MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)
|
||||||
|
#undef MLX_DEF_NCCL_MAP
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_dtype(const array& arr, F&& f) {
|
||||||
|
dispatch_all_types(arr.dtype(), [&](auto type_tag) {
|
||||||
|
using T = MLX_GET_TYPE(type_tag);
|
||||||
|
if constexpr (nccl_map<T>::ok) {
|
||||||
|
f(type_tag, nccl_map<T>::value);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void sendAll(int sock, const void* buf, size_t len) {
|
||||||
|
const char* ptr = reinterpret_cast<const char*>(buf);
|
||||||
|
while (len > 0) {
|
||||||
|
ssize_t sent = send(sock, ptr, len, 0);
|
||||||
|
if (sent <= 0) {
|
||||||
|
perror("send");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
ptr += sent;
|
||||||
|
len -= sent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void recvAll(int sock, void* buf, size_t len) {
|
||||||
|
char* ptr = reinterpret_cast<char*>(buf);
|
||||||
|
while (len > 0) {
|
||||||
|
ssize_t rec = recv(sock, ptr, len, 0);
|
||||||
|
if (rec <= 0) {
|
||||||
|
perror("recv");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
ptr += rec;
|
||||||
|
len -= rec;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void bootstrap_unique_id(
|
||||||
|
ncclUniqueId& id,
|
||||||
|
int rank,
|
||||||
|
int size,
|
||||||
|
const std::string& initMethod) {
|
||||||
|
// Parse the init method to extract the host and port
|
||||||
|
if (initMethod.rfind("tcp://", 0) != 0)
|
||||||
|
throw;
|
||||||
|
auto hostport = initMethod.substr(6);
|
||||||
|
auto colon = hostport.find(':');
|
||||||
|
std::string host = hostport.substr(0, colon);
|
||||||
|
int port = std::stoi(hostport.substr(colon + 1));
|
||||||
|
|
||||||
|
if (rank == 0) {
|
||||||
|
// create a unique id on the rank 0
|
||||||
|
CHECK_NCCL(ncclGetUniqueId(&id));
|
||||||
|
|
||||||
|
// create a socket to send the unique id to all other ranks
|
||||||
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
sockaddr_in serv = {};
|
||||||
|
serv.sin_family = AF_INET;
|
||||||
|
serv.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||||
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
|
int reuse = 1;
|
||||||
|
// Without this, if rank-0 crashes or restarts process quickly,
|
||||||
|
// the OS might refuse to let binding to the same port, so reuse
|
||||||
|
|
||||||
|
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] setsockopt() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] bind() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
if (listen(sock, size - 1) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] listen() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int peer = 1; peer < size; ++peer) {
|
||||||
|
int conn = accept(sock, nullptr, nullptr);
|
||||||
|
if (conn < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] accept() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
sendAll(conn, &id, sizeof(id));
|
||||||
|
close(conn);
|
||||||
|
}
|
||||||
|
close(sock);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Here just wanted to make show that rank 0 has enough time to bind
|
||||||
|
// so we will retry to connect until max attempts
|
||||||
|
|
||||||
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] socket() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
hostent* he = gethostbyname(host.c_str());
|
||||||
|
if (!he) {
|
||||||
|
throw std::runtime_error("[nccl] lookup failed for host: " + host);
|
||||||
|
}
|
||||||
|
sockaddr_in serv = {};
|
||||||
|
serv.sin_family = AF_INET;
|
||||||
|
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||||
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
|
const int max_retries = 30;
|
||||||
|
int attempt = 0;
|
||||||
|
bool connected = false;
|
||||||
|
|
||||||
|
for (attempt = 0; attempt < max_retries; ++attempt) {
|
||||||
|
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||||
|
0) {
|
||||||
|
connected = true;
|
||||||
|
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
|
||||||
|
<< attempt + 1 << std::endl;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (errno != ECONNREFUSED) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!connected) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
||||||
|
<< " retries: " << strerror(errno);
|
||||||
|
close(sock);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
recvAll(sock, &id, sizeof(id));
|
||||||
|
close(sock);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
class NCCLGroup : public GroupImpl {
|
||||||
|
public:
|
||||||
|
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
|
||||||
|
: rank_(worldRank),
|
||||||
|
size_(worldSize),
|
||||||
|
comm_(nullptr),
|
||||||
|
initMethod_(initMethod) {
|
||||||
|
if (initialized_)
|
||||||
|
return;
|
||||||
|
int ndev;
|
||||||
|
CHECK_CUDA(cudaGetDeviceCount(&ndev));
|
||||||
|
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
|
||||||
|
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
|
||||||
|
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
~NCCLGroup() {
|
||||||
|
ncclCommDestroy(comm_);
|
||||||
|
ncclGroupEnd();
|
||||||
|
initialized_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank() override {
|
||||||
|
return rank_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int size() override {
|
||||||
|
return size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_sum(const array& input, array& output, Stream stream) override {
|
||||||
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
|
throw std::runtime_error("[nccl] Group split not supported.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_gather(const array& input, array& output, Stream stream) override {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[nccl] All gather not supported in NCCL backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void send(const array& input, int dst, Stream stream) override {
|
||||||
|
throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void recv(array& output, int src, Stream stream) override {
|
||||||
|
throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_max(const array& input, array& output, Stream stream) override {
|
||||||
|
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_min(const array& input, array& output, Stream stream) override {
|
||||||
|
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void all_reduce_impl(
|
||||||
|
const array& input,
|
||||||
|
array& output,
|
||||||
|
Stream stream,
|
||||||
|
ncclDataType_t dt,
|
||||||
|
ncclRedOp_t op) {
|
||||||
|
auto& encoder = cu::get_command_encoder(stream);
|
||||||
|
|
||||||
|
CHECK_NCCL(ncclAllReduce(
|
||||||
|
input.data<T>(),
|
||||||
|
output.data<T>(),
|
||||||
|
input.size(),
|
||||||
|
dt,
|
||||||
|
op,
|
||||||
|
comm_,
|
||||||
|
encoder.stream()));
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank_, size_;
|
||||||
|
std::string initMethod_;
|
||||||
|
ncclUniqueId uniqueId_;
|
||||||
|
ncclComm_t comm_;
|
||||||
|
bool initialized_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
static std::string get_env_var_or_throw(const char* env_var_name) {
|
||||||
|
const char* value = std::getenv(env_var_name);
|
||||||
|
if (value == nullptr) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] Required environment variable '" << env_var_name
|
||||||
|
<< "' is not set. "
|
||||||
|
<< "Please set it before initializing the distributed backend.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
return std::string(value);
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
|
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
|
||||||
|
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
|
||||||
|
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
|
||||||
|
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
|
||||||
|
|
||||||
|
int rank = std::stoi(rank_str);
|
||||||
|
int n_nodes = std::stoi(n_nodes_str);
|
||||||
|
std::string init_method = "tcp://" + host + ":" + port;
|
||||||
|
|
||||||
|
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
|
||||||
|
}
|
||||||
|
} // namespace mlx::core::distributed::nccl
|
||||||
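For readers following the bootstrap above: rank 0 generates the ncclUniqueId, listens on NCCL_HOST_IP:NCCL_PORT, and sends the id to every other rank, while the other ranks retry their connect until rank 0 is up. A rough, purely illustrative Python sketch of the same handshake (not part of the change; the id length is an assumption):

    import socket

    ID_LEN = 128  # assumed size of ncclUniqueId in bytes

    def bootstrap(rank, size, host, port, unique_id=b"\x00" * ID_LEN):
        if rank == 0:
            srv = socket.socket()
            srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            srv.bind((host, port))
            srv.listen(size - 1)
            for _ in range(size - 1):
                conn, _ = srv.accept()
                conn.sendall(unique_id)  # rank 0 distributes the id
                conn.close()
            srv.close()
            return unique_id
        cli = socket.socket()
        cli.connect((host, port))  # the C++ code retries this with a delay
        data = b""
        while len(data) < ID_LEN:
            data += cli.recv(ID_LEN - len(data))
        cli.close()
        return data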
12
mlx/distributed/nccl/nccl.h
Normal file
@@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed::nccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);

} // namespace mlx::core::distributed::nccl
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
@@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/nccl/nccl.h"

namespace mlx::core::distributed::nccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available() {
  return false;
}

std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  if (strict) {
    throw std::runtime_error("Cannot initialize nccl distributed backend.");
  }
  return nullptr;
}

} // namespace mlx::core::distributed::nccl
@@ -2,9 +2,20 @@

 #include <sstream>

+#include "mlx/backend/cuda/cuda.h"
+#include "mlx/backend/metal/metal.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"

+inline mlx::core::Device get_device() {
+  if (mlx::core::metal::is_available()) {
+    return mlx::core::Device::cpu;
+  } else if (mlx::core::cu::is_available()) {
+    return mlx::core::Device::gpu;
+  }
+  throw std::runtime_error("No available device for distributed operations.");
+}
+
 namespace mlx::core::distributed {

 namespace {
@@ -24,6 +35,7 @@ array all_sum(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
@@ -31,8 +43,7 @@ array all_sum(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(
|
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Sum),
|
||||||
to_stream(s, Device::cpu), group, AllReduce::Sum),
|
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,6 +52,7 @@ array all_max(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
@@ -48,8 +60,7 @@ array all_max(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(
|
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Max),
|
||||||
to_stream(s, Device::cpu), group, AllReduce::Max),
|
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,6 +69,7 @@ array all_min(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
@@ -65,8 +77,7 @@ array all_min(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(
|
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Min),
|
||||||
to_stream(s, Device::cpu), group, AllReduce::Min),
|
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,6 +86,7 @@ array all_gather(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
@@ -89,7 +101,7 @@ array all_gather(
|
|||||||
return array(
|
return array(
|
||||||
std::move(result_shape),
|
std::move(result_shape),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
|
std::make_shared<AllGather>(to_stream(s, dev), group),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,6 +111,7 @@ array send(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
throw std::invalid_argument("Cannot send to a singleton group");
|
throw std::invalid_argument("Cannot send to a singleton group");
|
||||||
@@ -114,7 +127,7 @@ array send(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
|
std::make_shared<Send>(to_stream(s, dev), group, dst),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,6 +138,7 @@ array recv(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
throw std::invalid_argument("Cannot recv from a singleton group");
|
throw std::invalid_argument("Cannot recv from a singleton group");
|
||||||
@@ -139,7 +153,7 @@ array recv(
|
|||||||
return array(
|
return array(
|
||||||
std::move(shape),
|
std::move(shape),
|
||||||
std::move(dtype),
|
std::move(dtype),
|
||||||
std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
|
std::make_shared<Recv>(to_stream(s, dev), group, src),
|
||||||
std::vector<array>{});
|
std::vector<array>{});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
33
mlx/fast.h
@@ -66,10 +66,9 @@ array affine_dequantize(
     int bits = 4,
     StreamOrDevice s = {});

-using TemplateArg = std::variant<int, bool, Dtype>;
-using ScalarArg = std::variant<bool, int, float>;
+typedef std::variant<int, bool, Dtype> TemplateArg;

-using CustomKernelFunction = std::function<std::vector<array>(
+typedef std::function<std::vector<array>(
     const std::vector<array>&,
     const std::vector<Shape>&,
     const std::vector<Dtype>&,

@@ -78,9 +77,10 @@ using CustomKernelFunction = std::function<std::vector<array>(
     std::vector<std::pair<std::string, TemplateArg>>,
     std::optional<float>,
     bool,
-    StreamOrDevice)>;
+    StreamOrDevice)>
+    MetalKernelFunction;

-CustomKernelFunction metal_kernel(
+MetalKernelFunction metal_kernel(
     const std::string& name,
     const std::vector<std::string>& input_names,
     const std::vector<std::string>& output_names,

@@ -89,27 +89,4 @@ CustomKernelFunction metal_kernel(
     bool ensure_row_contiguous = true,
     bool atomic_outputs = false);

-CustomKernelFunction cuda_kernel(
-    const std::string& name,
-    const std::vector<std::string>& input_names,
-    const std::vector<std::string>& output_names,
-    const std::string& source,
-    const std::string& header = "",
-    bool ensure_row_contiguous = true,
-    int shared_memory = 0);
-
-std::vector<array> precompiled_cuda_kernel(
-    const std::string& name,
-    const std::string& compiled_source,
-    const std::vector<array>& inputs,
-    const std::vector<Shape>& output_shapes,
-    const std::vector<Dtype>& output_dtypes,
-    const std::vector<ScalarArg>& scalars,
-    std::tuple<int, int, int> grid,
-    std::tuple<int, int, int> threadgroup,
-    int shared_memory = 0,
-    std::optional<float> init_value = std::nullopt,
-    bool ensure_row_contiguous = false,
-    StreamOrDevice s = {});
-
 } // namespace mlx::core::fast
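With the CUDA custom-kernel entry points removed, the remaining custom-kernel path declared here is metal_kernel. A minimal sketch of the corresponding Python binding, mx.fast.metal_kernel, is shown for orientation only (assumes a Metal device; the kernel body and names are illustrative):

    import mlx.core as mx

    source = """
        uint i = thread_position_in_grid.x;
        out[i] = 2.0f * inp[i];
    """
    kernel = mx.fast.metal_kernel(
        name="double",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )
    a = mx.arange(8, dtype=mx.float32)
    (out,) = kernel(
        inputs=[a],
        grid=(a.size, 1, 1),
        threadgroup=(8, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )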
@@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <variant>
|
|
||||||
|
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
@@ -284,8 +283,6 @@ struct CustomKernelShapeInfo {
|
|||||||
bool ndim = false;
|
bool ndim = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
using ScalarArg = std::variant<bool, int, float>;
|
|
||||||
|
|
||||||
class CustomKernel : public Primitive {
|
class CustomKernel : public Primitive {
|
||||||
public:
|
public:
|
||||||
CustomKernel(
|
CustomKernel(
|
||||||
@@ -296,10 +293,7 @@ class CustomKernel : public Primitive {
|
|||||||
std::tuple<int, int, int> threadgroup,
|
std::tuple<int, int, int> threadgroup,
|
||||||
std::vector<CustomKernelShapeInfo> shape_infos,
|
std::vector<CustomKernelShapeInfo> shape_infos,
|
||||||
bool ensure_row_contiguous,
|
bool ensure_row_contiguous,
|
||||||
std::optional<float> init_value,
|
std::optional<float> init_value)
|
||||||
std::vector<ScalarArg> scalar_arguments,
|
|
||||||
bool is_precompiled,
|
|
||||||
int shared_memory)
|
|
||||||
: Primitive(stream),
|
: Primitive(stream),
|
||||||
source_(std::move(source)),
|
source_(std::move(source)),
|
||||||
name_(std::move(name)),
|
name_(std::move(name)),
|
||||||
@@ -307,14 +301,11 @@ class CustomKernel : public Primitive {
|
|||||||
threadgroup_(threadgroup),
|
threadgroup_(threadgroup),
|
||||||
shape_infos_(std::move(shape_infos)),
|
shape_infos_(std::move(shape_infos)),
|
||||||
ensure_row_contiguous_(ensure_row_contiguous),
|
ensure_row_contiguous_(ensure_row_contiguous),
|
||||||
init_value_(init_value),
|
init_value_(init_value) {}
|
||||||
scalar_arguments_(std::move(scalar_arguments)),
|
|
||||||
is_precompiled_(is_precompiled),
|
|
||||||
shared_memory_(shared_memory) {}
|
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override {
|
override {
|
||||||
throw std::runtime_error("Custom kernels only run on GPU.");
|
throw std::runtime_error("Custom Metal kernels only run on GPU.");
|
||||||
}
|
}
|
||||||
|
|
||||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
@@ -330,9 +321,6 @@ class CustomKernel : public Primitive {
|
|||||||
std::vector<CustomKernelShapeInfo> shape_infos_;
|
std::vector<CustomKernelShapeInfo> shape_infos_;
|
||||||
bool ensure_row_contiguous_;
|
bool ensure_row_contiguous_;
|
||||||
std::optional<float> init_value_;
|
std::optional<float> init_value_;
|
||||||
std::vector<ScalarArg> scalar_arguments_;
|
|
||||||
bool is_precompiled_;
|
|
||||||
int shared_memory_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::fast
|
} // namespace mlx::core::fast
|
||||||
|
|||||||
@@ -415,6 +415,48 @@ def launch_mpi(parser, hosts, args, command):
         pass


+def launch_nccl(parser, hosts, args, command):
+    master_host = hosts[0].ips[0]
+
+    if master_host != "127.0.0.1":
+        raise ValueError("The NCCL backend only supports localhost for now. ")
+    master_port = args.nccl_port
+    world_size = len(hosts)
+
+    base_env = os.environ.copy()
+    base_env.update(
+        {
+            "NCCL_DEBUG": "INFO",
+            "NCCL_SOCKET_IFNAME": "lo",  # Use loopback for local communication
+            "NCCL_HOST_IP": master_host,
+            "NCCL_PORT": str(master_port),
+            "MLX_WORLD_SIZE": str(world_size),
+        }
+    )
+    procs = []
+    try:
+        for rank in range(world_size):
+            env = base_env.copy()
+            env["MLX_RANK"] = str(rank)
+            env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
+            p = Popen(command, env=env)
+            procs.append(p)
+
+        for p in procs:
+            ret = p.wait()
+            if ret != 0:
+                raise RuntimeError(f"Rank process exited with {ret}")
+
+    except (RuntimeError, KeyboardInterrupt) as err:
+        for p in procs:
+            if p.poll() is None:
+                try:
+                    p.kill()
+                except Exception:
+                    pass
+        raise
+
+
def check_ssh_connections(hosts):
|
def check_ssh_connections(hosts):
|
||||||
results = [False] * len(hosts)
|
results = [False] * len(hosts)
|
||||||
|
|
||||||
@@ -665,7 +707,7 @@ def distributed_config():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi"],
|
choices=["ring", "mpi", "nccl"],
|
||||||
default="ring",
|
default="ring",
|
||||||
help="Which distributed backend to configure",
|
help="Which distributed backend to configure",
|
||||||
)
|
)
|
||||||
@@ -737,7 +779,7 @@ def main():
|
|||||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi"],
|
choices=["ring", "mpi", "nccl"],
|
||||||
default="ring",
|
default="ring",
|
||||||
help="Which distributed backend to launch",
|
help="Which distributed backend to launch",
|
||||||
)
|
)
|
||||||
@@ -769,6 +811,13 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cwd", help="Set the working directory on each node to the provided one"
|
"--cwd", help="Set the working directory on each node to the provided one"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--nccl-port",
|
||||||
|
type=int,
|
||||||
|
default=12345,
|
||||||
|
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||||
|
)
|
||||||
|
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
if rest[0] == "--":
|
if rest[0] == "--":
|
||||||
rest.pop(0)
|
rest.pop(0)
|
||||||
@@ -799,8 +848,10 @@ def main():
|
|||||||
# Launch
|
# Launch
|
||||||
if args.backend == "ring":
|
if args.backend == "ring":
|
||||||
launch_ring(parser, hosts, args, rest)
|
launch_ring(parser, hosts, args, rest)
|
||||||
elif args.backend == "mpi":
|
if args.backend == "mpi":
|
||||||
launch_mpi(parser, hosts, args, rest)
|
launch_mpi(parser, hosts, args, rest)
|
||||||
|
if args.backend == "nccl":
|
||||||
|
launch_nccl(parser, hosts, args, rest)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
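A minimal sketch of the worker script each rank runs when started by launch_nccl above, assuming the launcher has already exported MLX_RANK, MLX_WORLD_SIZE, NCCL_HOST_IP and NCCL_PORT and pinned CUDA_VISIBLE_DEVICES to rank % nproc_per_node; the script body itself is illustrative only.

# sketch: one copy of this script is spawned per rank (localhost only for now)
import mlx.core as mx

world = mx.distributed.init(strict=True, backend="nccl")
x = mx.ones(8)
# every rank ends up with world.size() * x after the all_sum
print(world.rank(), world.size(), mx.distributed.all_sum(x))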
@@ -76,6 +76,7 @@ def average_gradients(
     group: Optional[mx.distributed.Group] = None,
     all_reduce_size: int = 32 * 1024**2,
     communication_type: Optional[mx.Dtype] = None,
+    stream: mx.Stream = mx.cpu,
 ):
     """Average the gradients across the distributed processes in the passed group.

@@ -94,6 +95,7 @@ def average_gradients(
         communication_type (Optional[mlx.core.Dtype]): If provided cast to this
             type before performing the communication. Typically cast to a
             smaller float to reduce the communication size. Default: ``None``.
+        stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
     """
     group = group or mx.distributed.init()
     N = group.size()
@@ -104,7 +106,7 @@ def average_gradients(
     def _average(x):
         dt = x.dtype
         x = x.astype(communication_type) if communication_type is not None else x
-        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
+        return mx.distributed.all_sum(x, stream=stream).astype(dt) / N

     if all_reduce_size <= 0:
         return tree_map(_average, gradients)
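A short sketch of the new ``stream`` argument added above: the gradient all_sum can now run on the GPU stream instead of the default ``mx.cpu``, mirroring the usage in the NCCL test added later in this change; the toy gradient list is only for illustration.

import mlx.core as mx
from mlx.nn.utils import average_gradients

# Average a toy gradient list on the GPU stream; optionally communicate in fp16.
grads = [mx.ones(10) for _ in range(4)]
grads = average_gradients(grads, communication_type=mx.float16, stream=mx.gpu)
mx.eval(grads)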
@@ -17,7 +17,6 @@ nanobind_add_module(
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
@@ -1,19 +0,0 @@
-// Copyright © 2023-2025 Apple Inc.
-
-#include <nanobind/nanobind.h>
-
-#include "mlx/backend/cuda/cuda.h"
-
-namespace mx = mlx::core;
-namespace nb = nanobind;
-
-void init_cuda(nb::module_& m) {
-  nb::module_ cuda = m.def_submodule("cuda", "mlx.cuda");
-
-  cuda.def(
-      "is_available",
-      &mx::cu::is_available,
-      R"pbdoc(
-      Check if the CUDA back-end is available.
-      )pbdoc");
-}
@@ -79,7 +79,7 @@ void init_distributed(nb::module_& parent_module) {
           in case ``mx.distributed.is_available()`` returns False otherwise
           it throws a runtime error. Default: ``False``
         backend (str, optional): Which distributed backend to initialize.
-          Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all
+          Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
           available backends are tried and the first one that succeeds
           becomes the global group which will be returned in subsequent
           calls. Default: ``any``
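A small sketch of the backend selection the docstring above describes, assuming a distributed launcher has set up the environment: with ``backend="any"`` (the default) the first backend that initializes successfully is used, while an explicit name forces that backend.

import mlx.core as mx

# Request a specific backend, or pass "any" to take whatever is available.
group = mx.distributed.init(backend="any")
if group.size() > 1:
    print(f"rank {group.rank()} of {group.size()}")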
@@ -17,66 +17,6 @@ namespace mx = mlx::core;
 namespace nb = nanobind;
 using namespace nb::literals;

-namespace {
-
-struct PyCustomKernelFunction {
-  PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const char* tag)
-      : kernel_(std::move(kernel)), tag_(tag) {}
-
-  std::vector<mx::array> operator()(
-      const std::vector<ScalarOrArray>& inputs_,
-      const std::vector<mx::Shape>& output_shapes,
-      const std::vector<mx::Dtype>& output_dtypes,
-      std::tuple<int, int, int> grid,
-      std::tuple<int, int, int> threadgroup,
-      const std::optional<std::vector<std::pair<std::string, nb::object>>>&
-          template_args_ = std::nullopt,
-      std::optional<float> init_value = std::nullopt,
-      bool verbose = false,
-      mx::StreamOrDevice s = {}) const {
-    std::vector<mx::array> inputs;
-    for (const auto& value : inputs_) {
-      inputs.push_back(to_array(value, std::nullopt));
-    }
-    std::vector<std::pair<std::string, mx::fast::TemplateArg>> template_args;
-    if (template_args_) {
-      for (const auto& [name, value] : template_args_.value()) {
-        // Handle bool, int and dtype template args
-        if (nb::isinstance<bool>(value)) {
-          bool bool_val = nb::cast<bool>(value);
-          template_args.emplace_back(name, bool_val);
-        } else if (nb::isinstance<int>(value)) {
-          int int_val = nb::cast<int>(value);
-          template_args.emplace_back(name, int_val);
-        } else if (nb::isinstance<mx::Dtype>(value)) {
-          mx::Dtype dtype = nb::cast<mx::Dtype>(value);
-          template_args.emplace_back(name, dtype);
-        } else {
-          std::ostringstream msg;
-          msg << tag_
-              << " Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.";
-          throw std::invalid_argument(msg.str());
-        }
-      }
-    }
-    return kernel_(
-        inputs,
-        output_shapes,
-        output_dtypes,
-        grid,
-        threadgroup,
-        template_args,
-        init_value,
-        verbose,
-        s);
-  }
-
-  mx::fast::CustomKernelFunction kernel_;
-  const char* tag_;
-};
-
-} // namespace
-
 void init_fast(nb::module_& parent_module) {
   auto m =
       parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
@@ -300,7 +240,53 @@ void init_fast(nb::module_& parent_module) {
             ensure_row_contiguous,
             atomic_outputs);
         return nb::cpp_function(
-            PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"),
+            [kernel = std::move(kernel)](
+                const std::vector<ScalarOrArray>& inputs_,
+                const std::vector<mx::Shape>& output_shapes,
+                const std::vector<mx::Dtype>& output_dtypes,
+                std::tuple<int, int, int> grid,
+                std::tuple<int, int, int> threadgroup,
+                const std::optional<
+                    std::vector<std::pair<std::string, nb::object>>>&
+                    template_args_ = std::nullopt,
+                std::optional<float> init_value = std::nullopt,
+                bool verbose = false,
+                mx::StreamOrDevice s = {}) {
+              std::vector<mx::array> inputs;
+              for (const auto& value : inputs_) {
+                inputs.push_back(to_array(value, std::nullopt));
+              }
+              std::vector<std::pair<std::string, mx::fast::TemplateArg>>
+                  template_args;
+              if (template_args_) {
+                for (const auto& [name, value] : template_args_.value()) {
+                  // Handle bool, int and dtype template args
+                  if (nb::isinstance<bool>(value)) {
+                    bool bool_val = nb::cast<bool>(value);
+                    template_args.emplace_back(name, bool_val);
+                  } else if (nb::isinstance<int>(value)) {
+                    int int_val = nb::cast<int>(value);
+                    template_args.emplace_back(name, int_val);
+                  } else if (nb::isinstance<mx::Dtype>(value)) {
+                    mx::Dtype dtype = nb::cast<mx::Dtype>(value);
+                    template_args.emplace_back(name, dtype);
+                  } else {
+                    throw std::invalid_argument(
+                        "[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
+                  }
+                }
+              }
+              return kernel(
+                  inputs,
+                  output_shapes,
+                  output_dtypes,
+                  grid,
+                  threadgroup,
+                  template_args,
+                  init_value,
+                  verbose,
+                  s);
+            },
             nb::kw_only(),
             "inputs"_a,
             "output_shapes"_a,
@@ -398,216 +384,4 @@ void init_fast(nb::module_& parent_module) {
             b = exp_elementwise(a)
             assert mx.allclose(b, mx.exp(a))
       )pbdoc");
-
-  m.def(
-      "cuda_kernel",
-      [](const std::string& name,
-         const std::vector<std::string>& input_names,
-         const std::vector<std::string>& output_names,
-         const std::string& source,
-         const std::string& header,
-         bool ensure_row_contiguous,
-         int shared_mem) {
-        auto kernel = mx::fast::cuda_kernel(
-            name,
-            input_names,
-            output_names,
-            source,
-            header,
-            ensure_row_contiguous,
-            shared_mem);
-        return nb::cpp_function(
-            PyCustomKernelFunction(std::move(kernel), "[cuda_kernel]"),
-            nb::kw_only(),
-            "inputs"_a,
-            "output_shapes"_a,
-            "output_dtypes"_a,
-            "grid"_a,
-            "threadgroup"_a,
-            "template"_a = nb::none(),
-            "init_value"_a = nb::none(),
-            "verbose"_a = false,
-            "stream"_a = nb::none(),
-            nb::sig(
-                "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
-            R"pbdoc(
-            Run the kernel.
-
-            Args:
-              inputs (List[array]): The inputs passed to the CUDA kernel.
-              output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
-              output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
-              grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
-                For compatibility with :func:`metal_kernel` the grid is in threads and not in threadgroups.
-              threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
-              template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
-                These will be added as template arguments to the kernel definition. Default: ``None``.
-              init_value (float, optional): Optional value to use to initialize all of the output arrays.
-                By default, output arrays are uninitialized. Default: ``None``.
-              verbose (bool, optional): Whether to print the full generated source code of the kernel
-                when it is run. Default: ``False``.
-              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
-
-            Returns:
-              List[array]: The list of output arrays.)pbdoc");
-      },
-      "name"_a,
-      "input_names"_a,
-      "output_names"_a,
-      "source"_a,
-      "header"_a = "",
-      "ensure_row_contiguous"_a = true,
-      "shared_memory"_a = 0,
-      R"pbdoc(
-      A jit-compiled custom CUDA kernel defined from a source string.
-
-      This is the CUDA equivalent of :ref:`custom_metal_kernels`.
-
-      Args:
-        name (str): Name for the kernel.
-        input_names (List[str]): The parameter names of the inputs in the
-          function signature.
-        output_names (List[str]): The parameter names of the outputs in the
-          function signature.
-        source (str): Source code. This is the body of a function in CUDA,
-          the function signature will be automatically generated.
-        header (str): Header source code to include before the main function.
-          Useful for helper functions or includes that should live outside of
-          the main function body.
-        ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
-          before the kernel runs. Default: ``True``.
-        shared_memory (int): The dynamic shared memory to request for the
-          kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
-
-      Returns:
-        Callable ``cuda_kernel``.
-
-      Example:
-
-        .. code-block:: python
-
-          def exp_elementwise(a: mx.array):
-              source = '''
-                  auto elem = cooperative_groups::this_grid().thread_rank();
-                  T tmp = inp[elem];
-                  out[elem] = exp(tmp);
-              '''
-
-              kernel = mx.fast.cuda_kernel(
-                  name="myexp",
-                  input_names=["inp"],
-                  output_names=["out"],
-                  source=source
-              )
-              outputs = kernel(
-                  inputs=[a],
-                  template=[("T", mx.float32)],
-                  grid=(a.size, 1, 1),
-                  threadgroup=(256, 1, 1),
-                  output_shapes=[a.shape],
-                  output_dtypes=[a.dtype],
-                  verbose=True,
-              )
-              return outputs[0]
-
-          a = mx.random.normal(shape=(16, 16)).astype(mx.float16)
-          b = exp_elementwise(a)
-          assert mx.allclose(b, mx.exp(a))
-      )pbdoc");
-
-  m.def(
-      "precompiled_cuda_kernel",
-      [](const std::string& name,
-         const nb::bytes compiled_source,
-         const std::vector<ScalarOrArray>& inputs_,
-         const std::vector<mx::Shape>& output_shapes,
-         const std::vector<mx::Dtype>& output_dtypes,
-         const std::vector<nb::object>& scalars_,
-         std::tuple<int, int, int> grid,
-         std::tuple<int, int, int> threadgroup,
-         int shared_memory,
-         std::optional<float> init_value = std::nullopt,
-         bool ensure_row_contiguous = false,
-         mx::StreamOrDevice s = {}) {
-        // Collect the inputs and cast them to array
-        std::vector<mx::array> inputs;
-        for (const auto& value : inputs_) {
-          inputs.push_back(to_array(value, std::nullopt));
-        }
-
-        // Collect the scalar inputs
-        std::vector<mx::fast::ScalarArg> scalars;
-        scalars.reserve(scalars_.size());
-        for (const auto& v : scalars_) {
-          if (nb::isinstance<bool>(v)) {
-            scalars.push_back(nb::cast<bool>(v));
-          } else if (nb::isinstance<int>(v)) {
-            scalars.push_back(nb::cast<int>(v));
-          } else if (nb::isinstance<float>(v)) {
-            scalars.push_back(nb::cast<float>(v));
-          } else {
-            nb::object vtype = v.attr("__class__");
-            std::string vtype_name =
-                nb::cast<std::string>(vtype.attr("__name__"));
-            std::ostringstream msg;
-            msg << "[precompiled_cuda_kernel] Invalid scalar argument type. "
-                << "Received " << vtype_name
-                << " but must be one of bool, int or float";
-            throw std::invalid_argument(msg.str());
-          }
-        }
-
-        return mx::fast::precompiled_cuda_kernel(
-            name,
-            std::string(
-                static_cast<const char*>(compiled_source.data()),
-                compiled_source.size()),
-            inputs,
-            output_shapes,
-            output_dtypes,
-            scalars,
-            grid,
-            threadgroup,
-            shared_memory,
-            init_value,
-            ensure_row_contiguous,
-            s);
-      },
-      nb::kw_only(),
-      "name"_a,
-      "compiled_source"_a,
-      "inputs"_a,
-      "output_shapes"_a,
-      "output_dtypes"_a,
-      "scalars"_a,
-      "grid"_a,
-      "threadgroup"_a,
-      "shared_memory"_a = 0,
-      "init_value"_a = nb::none(),
-      "ensure_row_contiguous"_a = false,
-      "stream"_a = nb::none(),
-      R"pbdoc(
-      Run a precompiled CUDA kernel defined from PTX or cubin.
-
-      This op is still experimental and various parts of the API may change.
-
-      Args:
-        name (str): Name for the kernel
-        compiled_source (bytes): The precompiled kernel in raw bytes.
-        inputs (List[array]): The inputs passed to the CUDA kernel.
-        output_shapes (List[Sequence[int]]): The list of shapes for each output.
-        output_dtypes (List[Dtype]): The list of data types for each output.
-        scalars (List[Union[bool, int, float]]): A list of scalar arguments to
-          pass to the kernel.
-        grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
-          For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
-        threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
-        shared_memory (int): The dynamic shared memory to request for the
-          kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
-        init_value (float, optional): Optional value to use to initialize all of the output arrays.
-          By default, output arrays are uninitialized. Default: ``None``.
-        ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
-          before the kernel runs. Default: ``False``.
-        stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
-      )pbdoc");
 }
@@ -12,7 +12,6 @@ void init_array(nb::module_&);
 void init_device(nb::module_&);
 void init_stream(nb::module_&);
 void init_metal(nb::module_&);
-void init_cuda(nb::module_&);
 void init_memory(nb::module_&);
 void init_ops(nb::module_&);
 void init_transforms(nb::module_&);
@@ -36,7 +35,6 @@ NB_MODULE(core, m) {
   init_stream(m);
   init_array(m);
   init_metal(m);
-  init_cuda(m);
   init_memory(m);
   init_ops(m);
   init_transforms(m);
@@ -13,6 +13,8 @@ cuda_skip = {
     # Hadamard NYI
     "TestOps.test_hadamard",
    "TestOps.test_hadamard_grad_vmap",
+    # Convolutions NYI
+    "TestConv.test_1d_conv_with_2d",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",
284
python/tests/nccl_test_distributed.py
Normal file
@@ -0,0 +1,284 @@
+# Copyright © 2024 Apple Inc.
+import mlx.core as mx
+import mlx.nn as nn
+import mlx_tests
+from mlx.nn.layers.distributed import shard_inplace, shard_linear
+from mlx.nn.utils import average_gradients
+
+
+class TestNCCLDistributed(mlx_tests.MLXTestCase):
+    @classmethod
+    def setUpClass(cls):
+        world = mx.distributed.init(strict=True, backend="nccl")
+        rank = world.rank()
+        mx.set_default_device(mx.Device(mx.gpu, rank % 8))
+
+    def test_all_reduce(self):
+        world = mx.distributed.init()
+        dtypes = [
+            (mx.int8, 0),
+            (mx.uint8, 0),
+            (mx.int32, 0),
+            (mx.uint32, 0),
+            (mx.float32, 1e-6),
+            (mx.float16, 5e-3),
+            (mx.bfloat16, 1e-1),
+        ]
+        sizes = [
+            (7,),
+            (10,),
+            (1024,),
+            (1024, 1024),
+        ]
+        key = mx.random.key(0)
+
+        for dt, rtol in dtypes:
+            for sh in sizes:
+                x = (
+                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
+                ).astype(dt)
+
+                # All sum
+                y = mx.distributed.all_sum(x[world.rank()])
+                z = x.sum(0)
+                maxrelerror = (y - z).abs()
+                if rtol > 0:
+                    maxrelerror /= z.abs()
+                maxrelerror = maxrelerror.max()
+                self.assertLessEqual(maxrelerror, rtol)
+
+    def test_average_gradients(self):
+        original_all_sum = mx.distributed.all_sum
+        n_calls = 0
+        xtype = None
+
+        def new_all_sum(x, **kwargs):
+            nonlocal n_calls
+            nonlocal xtype
+
+            n_calls += 1
+            if xtype is not None:
+                self.assertEqual(xtype, x.dtype)
+
+            return original_all_sum(x, **kwargs)
+
+        mx.distributed.all_sum = new_all_sum
+        try:
+            grads = [mx.ones(10) for i in range(10)]
+            new_grads = average_gradients(grads, stream=mx.gpu)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 1)
+
+            n_calls = 0
+            new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 2)
+
+            n_calls = 0
+            new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 10)
+
+            n_calls = 0
+            xtype = mx.float16
+            new_grads = average_gradients(
+                grads,
+                all_reduce_size=2 * 50,
+                communication_type=mx.float16,
+                stream=mx.gpu,
+            )
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 2)
+
+        finally:
+            mx.distributed.all_sum = original_all_sum
+
+    def test_donation(self):
+        x = mx.random.normal((1024,))
+        mx.eval(x)
+        mx.synchronize()
+
+        mx.reset_peak_memory()
+        scale = mx.array(2.0)
+        y = mx.distributed.all_sum(x)
+        mx.eval(y)
+        mx.synchronize()
+        all_sum_only = mx.get_peak_memory()
+        y = mx.distributed.all_sum(x) * scale
+        mx.eval(y)
+        mx.synchronize()
+        all_sum_with_binary = mx.get_peak_memory()
+
+        self.assertEqual(all_sum_only, all_sum_with_binary)
+
+    def test_shard_linear(self):
+        # Seed the prng to have the same inputs and weights generated everywhere
+        mx.random.seed(0xF0F0F0F0)
+
+        # Prepare inputs
+        world = mx.distributed.init()
+        part = (
+            slice(None),
+            slice(
+                world.rank() * 1024 // world.size(),
+                (world.rank() + 1) * 1024 // world.size(),
+            ),
+        )
+        x = mx.random.normal((4, 1024))
+
+        # Create and shard some linear layers
+        lin = nn.Linear(1024, 1024, bias=True)
+        slin1 = shard_linear(lin, "all-to-sharded")
+        slin2 = shard_linear(lin, "sharded-to-all")
+        y = lin(x)
+        y1 = slin1(x)
+        y2 = slin2(x[part])
+        self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
+        self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
+
+        # Check the backward works as expected
+        def dummy_loss(model, x, y):
+            return (model(x) * y).sum()
+
+        mod = nn.Sequential(
+            nn.Linear(128, 128),
+            nn.Linear(128, 128),
+            nn.Linear(128, 128),
+            nn.Linear(128, 128),
+        )
+        smod = nn.Sequential(
+            shard_linear(mod.layers[0], "all-to-sharded"),
+            shard_linear(mod.layers[1], "sharded-to-all"),
+            shard_linear(mod.layers[2], "all-to-sharded"),
+            shard_linear(mod.layers[3], "sharded-to-all"),
+        )
+
+        grad1 = nn.value_and_grad(mod, dummy_loss)
+        grad2 = nn.value_and_grad(smod, dummy_loss)
+
+        x = mx.random.normal((4, 128))
+        y = mx.random.normal((4, 128))
+
+        l1, g1 = grad1(mod, x, y)
+        l2, g2 = grad2(smod, x, y)
+        mx.eval(l1, g1, l2, g2)
+
+        part = slice(
+            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
+        )
+
+        self.assertTrue(mx.allclose(l1, l2))
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][0]["weight"][part],
+                g2["layers"][0]["weight"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][2]["weight"][part],
+                g2["layers"][2]["weight"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][1]["weight"][:, part],
+                g2["layers"][1]["weight"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][3]["weight"][:, part],
+                g2["layers"][3]["weight"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][0]["bias"][part],
+                g2["layers"][0]["bias"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][2]["bias"][part],
+                g2["layers"][2]["bias"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
+            )
+        )
+
+    def test_shard_predicate(self):
+        mx.random.seed(0xF0F0F0F0)
+
+        class MyConv(nn.Module):
+            def __init__(self, *args, **kwargs):
+                super().__init__()
+                self.aggregate = kwargs.pop("aggregate", False)
+                self.conv = nn.Conv2d(*args, **kwargs)
+
+            def __call__(self, x):
+                x = self.conv(x)
+                if self.aggregate:
+                    x = mx.distributed.all_sum(x)
+                return x
+
+        def sharding(path, weight):
+            parts = path.split(".")
+            even = int(parts[1]) % 2 == 0
+            if even:
+                return 0
+            else:
+                return -1 if parts[-1] != "bias" else None
+
+        mod = nn.Sequential(
+            MyConv(3, 128, kernel_size=3),
+            MyConv(128, 128, kernel_size=3),
+            MyConv(128, 128, kernel_size=3),
+            MyConv(128, 3, kernel_size=3),
+        )
+        smod = nn.Sequential(
+            MyConv(3, 128, kernel_size=3),
+            MyConv(128, 128, kernel_size=3, aggregate=True),
+            MyConv(128, 128, kernel_size=3),
+            MyConv(128, 3, kernel_size=3, aggregate=True),
+        )
+        smod.update(mod.parameters())
+        shard_inplace(smod, sharding)
+
+        x = mx.random.normal((4, 16, 16, 3))
+        y1 = mod(x)
+        y2 = smod(x)
+        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
+
+
+if __name__ == "__main__":
+    mlx_tests.MLXTestRunner()
@@ -1186,13 +1186,6 @@ class TestConv(mlx_tests.MLXTestCase):
         y_hat = mx.conv2d(x, w)
         self.assertTrue(mx.allclose(y, y_hat))

-    def test_conv2d_large_filter_small_channels(self):
-        x = mx.random.normal(shape=(1, 181, 181, 1))
-        w = mx.random.normal(shape=(1, 182, 182, 1))
-        y = mx.conv2d(x, w, (1, 1), (1, 1), stream=mx.cpu)
-        y_hat = mx.conv2d(x, w, (1, 1), (1, 1))
-        self.assertTrue(mx.allclose(y, y_hat, rtol=1e-3, atol=1e-3))
-

 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
@@ -581,28 +581,18 @@ class TestFast(mlx_tests.MLXTestCase):
         )(x)
         self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))

-    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_basic(self):
-        if mx.metal.is_available():
-            source = """
-                uint elem = thread_position_in_grid.x;
-                out1[elem] = a[elem];
-            """
-            custom_kernel = mx.fast.metal_kernel
-        elif mx.cuda.is_available():
-            source = """
-                auto elem = cooperative_groups::this_grid().thread_rank();
-                out1[elem] = a[elem];
-            """
-            custom_kernel = mx.fast.cuda_kernel
-
         mx.random.seed(7)
         a = mx.random.normal(shape=(2, 2))
-        kernel = custom_kernel(
+        kernel = mx.fast.metal_kernel(
             name="basic",
             input_names=["a"],
             output_names=["out1"],
-            source=source,
+            source="""
+                uint elem = thread_position_in_grid.x;
+                out1[elem] = a[elem];
+            """,
         )
         out = kernel(
             inputs=[a],
@@ -614,9 +604,16 @@ class TestFast(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(out[0], a))

-    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_args(self):
-        if mx.metal.is_available():
+        mx.random.seed(7)
+        a = mx.random.normal(shape=(3, 6))
+        c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
+
+        kernel = mx.fast.metal_kernel(
+            name="arg_test",
+            input_names=["a", "b", "c", "d"],
+            output_names=["out1", "out2"],
             source="""
                 uint elem = thread_position_in_grid.x;
                 T tmp = a[0];
@@ -626,30 +623,7 @@ class TestFast(mlx_tests.MLXTestCase):
                     out1[elem] = 1;
                 }
                 out2[elem] = a[1] + b[2] + c[1] - d;
-            """
-            custom_kernel = mx.fast.metal_kernel
-        elif mx.cuda.is_available():
-            source = """
-                auto elem = cooperative_groups::this_grid().thread_rank();
-                T tmp = a[0];
-                if (e) {
-                    out1[elem] = a[1] + b[2] + static_cast<float>(c[3]) + d[0] + f;
-                } else {
-                    out1[elem] = 1;
-                }
-                out2[elem] = a[1] + b[2] + static_cast<float>(c[1]) - d[0];
-            """
-            custom_kernel = mx.fast.cuda_kernel
-
-        mx.random.seed(7)
-        a = mx.random.normal(shape=(3, 6))
-        c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
-
-        kernel = custom_kernel(
-            name="arg_test",
-            input_names=["a", "b", "c", "d"],
-            output_names=["out1", "out2"],
-            source=source,
+            """,
         )
         out = kernel(
             inputs=[
@@ -673,9 +647,10 @@ class TestFast(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))
         self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))

-    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_strides(self):
-        if mx.metal.is_available():
+        mx.random.seed(7)
+        a = mx.random.normal(shape=(3, 6))
         source = """
             uint elem = thread_position_in_grid.x;
             uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
@@ -687,29 +662,12 @@ class TestFast(mlx_tests.MLXTestCase):
             T tmp = inp[elem];
             out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
         """
-            custom_kernel = mx.fast.metal_kernel
-        elif mx.cuda.is_available():
-            source = """
-                auto elem = cooperative_groups::this_grid().thread_rank();
-                auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim);
-                T tmp = inp[loc];
-                out[elem] = exp(tmp) * WARP_SIZE;
-            """
-            source_contig = """
-                auto elem = cooperative_groups::this_grid().thread_rank();
-                T tmp = inp[elem];
-                out[elem] = exp(tmp) * WARP_SIZE;
-            """
-            custom_kernel = mx.fast.cuda_kernel
-
-        mx.random.seed(7)
-        a = mx.random.normal(shape=(3, 6))

         # non contiguous
         a = mx.tile(a[::2], [4, 1])

         for contig in [True, False]:
-            kernel = custom_kernel(
+            kernel = mx.fast.metal_kernel(
                 name="myexp" + str(contig),
                 input_names=["inp"],
                 output_names=["out"],
@@ -727,41 +685,24 @@ class TestFast(mlx_tests.MLXTestCase):
             )
             self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))

-    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_helper(self):
-        if mx.metal.is_available():
+        mx.random.seed(7)
+        a = mx.random.normal(shape=(2, 2))
+        kernel = mx.fast.metal_kernel(
+            name="helper",
+            input_names=["a"],
+            output_names=["out1"],
             header="""
                 template <typename T>
                 T do_exp(T x) {
                     return metal::precise::exp(x);
                 }
-            """
+            """,
             source="""
                 uint elem = thread_position_in_grid.x;
                 out1[elem] = do_exp(a[elem]);
-            """
-            custom_kernel = mx.fast.metal_kernel
-        elif mx.cuda.is_available():
-            header = """
-                template <typename T>
-                __device__ T do_exp(T x) {
-                    return exp(x);
-                }
-            """
-            source = """
-                auto elem = cooperative_groups::this_grid().thread_rank();
-                out1[elem] = do_exp(a[elem]);
-            """
-            custom_kernel = mx.fast.cuda_kernel
-
-        mx.random.seed(7)
-        a = mx.random.normal(shape=(2, 2))
-        kernel = custom_kernel(
-            name="helper",
-            input_names=["a"],
-            output_names=["out1"],
-            header=header,
-            source=source,
+            """,
         )
         out = kernel(
             inputs=[a],
@@ -773,21 +714,16 @@ class TestFast(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(out[0], mx.exp(a)))

-    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_attributes(self):
-        if mx.metal.is_available():
-            source = "out[0] = threads_per_threadgroup.x;"
-            custom_kernel = mx.fast.metal_kernel
-        elif mx.cuda.is_available():
-            source = "out[0] = blockDim.x;"
-            custom_kernel = mx.fast.cuda_kernel
-
         a = mx.zeros(shape=(1, 1))
-        kernel = custom_kernel(
+        kernel = mx.fast.metal_kernel(
             name="test_fun",
             input_names=["a"],
             output_names=["out"],
-            source=source,
+            source="""
+                out[0] = threads_per_threadgroup.x;
+            """,
         )
         out = kernel(
             inputs=[a],