mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Compare commits
24 Commits
d5e3fe7f86
...
ea451af9a0
Author | SHA1 | Date | |
---|---|---|---|
![]() |
ea451af9a0 | ||
![]() |
53fa981caf | ||
![]() |
b1d95a3880 | ||
![]() |
4b02d3e738 | ||
![]() |
dd5e833068 | ||
![]() |
b3013042ca | ||
![]() |
fc2f6bc51c | ||
![]() |
9dbaa35be3 | ||
![]() |
13eccfa887 | ||
![]() |
96a7017442 | ||
![]() |
c2f1c2a338 | ||
![]() |
9fd8eb357c | ||
![]() |
0c69f10d55 | ||
![]() |
c8b4787e4e | ||
![]() |
2188199ff8 | ||
![]() |
aa07429bad | ||
![]() |
918761a25a | ||
![]() |
a4fc671d3e | ||
![]() |
f5f65ef48c | ||
![]() |
c2dd81a8aa | ||
![]() |
d7e680ffe4 | ||
![]() |
c371baf53a | ||
![]() |
ccf78f566c | ||
![]() |
c9fa68664a |
@ -1,12 +1,14 @@
|
||||
# Filename rules in cuda backend:
|
||||
#
|
||||
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
|
||||
# * Device-only kernel code should be put in kernels/ subdir.
|
||||
# * Files in kernels/ subdir should not include files outside.
|
||||
# * Device-only code should be put in device/ subdir.
|
||||
# * Files in device/ subdir should not include files outside.
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||
@ -17,18 +19,47 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||
|
||||
# Embed kernel sources in binary for JIT compilation.
|
||||
file(
|
||||
GLOB MLX_JIT_SOURCES
|
||||
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
|
||||
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
|
||||
add_custom_command(
|
||||
OUTPUT gen/cuda_jit_sources.h
|
||||
COMMAND
|
||||
${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
|
||||
-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
|
||||
DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
|
||||
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
|
||||
add_dependencies(mlx cuda_jit_sources)
|
||||
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
|
||||
|
||||
# Enable defining device lambda functions.
|
||||
target_compile_options(mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||
@ -36,12 +67,16 @@ target_compile_options(mlx
|
||||
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
|
||||
# Explicitly pass this flag to suppress the warning, it is safe to set it to
|
||||
# true but the warning wouldn't be suppressed.
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
|
||||
target_compile_options(
|
||||
mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
|
||||
endif()
|
||||
|
||||
# Suppress warning when building for compute capability 7 used by V100.
|
||||
target_compile_options(
|
||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
|
||||
|
||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||
set(MLX_CUDA_ARCHITECTURES
|
||||
@ -75,6 +110,9 @@ target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||
# Use cublasLt.
|
||||
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
||||
|
||||
# Use NVRTC and driver APIs.
|
||||
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
||||
|
189
mlx/backend/cuda/arg_reduce.cu
Normal file
189
mlx/backend/cuda/arg_reduce.cu
Normal file
@ -0,0 +1,189 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
struct IndexValPair {
|
||||
uint32_t index;
|
||||
T val;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ArgMin {
|
||||
constexpr __device__ T init() {
|
||||
return Limits<T>::max();
|
||||
}
|
||||
|
||||
__device__ IndexValPair<T> operator()(
|
||||
const IndexValPair<T>& best,
|
||||
const IndexValPair<T>& current) {
|
||||
if (best.val > current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ IndexValPair<T>
|
||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] < best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ArgMax {
|
||||
constexpr __device__ T init() {
|
||||
return Limits<T>::min();
|
||||
}
|
||||
|
||||
__device__ IndexValPair<T> operator()(
|
||||
const IndexValPair<T>& best,
|
||||
const IndexValPair<T>& current) {
|
||||
if (best.val < current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ IndexValPair<T>
|
||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] > best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void arg_reduce_general(
|
||||
const T* in,
|
||||
uint32_t* out,
|
||||
size_t size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides in_strides,
|
||||
const __grid_constant__ Strides out_strides,
|
||||
int32_t ndim,
|
||||
int64_t axis_stride,
|
||||
int32_t axis_size) {
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int64_t index = cg::this_grid().block_rank();
|
||||
if (index >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
|
||||
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
|
||||
|
||||
Op op;
|
||||
T init = op.init();
|
||||
IndexValPair<T> best{0, init};
|
||||
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T vals[N_READS];
|
||||
auto tid = r * BLOCK_DIM + block.thread_index().z;
|
||||
cub::LoadDirectBlocked(
|
||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||
best = op.reduce_many(best, vals, tid * N_READS);
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
best = BlockReduceT(temp).Reduce(best, op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = best.index;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgReduce::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
|
||||
// Prepare the shapes, strides and axis arguments.
|
||||
Shape shape = remove_index(in.shape(), axis_);
|
||||
Strides in_strides = remove_index(in.strides(), axis_);
|
||||
Strides out_strides = out.ndim() == in.ndim()
|
||||
? remove_index(out.strides(), axis_)
|
||||
: out.strides();
|
||||
int64_t axis_stride = in.strides()[axis_];
|
||||
int32_t axis_size = in.shape()[axis_];
|
||||
int32_t ndim = shape.size();
|
||||
|
||||
// ArgReduce.
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims{1, 1, BLOCK_DIM};
|
||||
auto kernel = &cu::arg_reduce_general<
|
||||
InType,
|
||||
cu::ArgMax<InType>,
|
||||
BLOCK_DIM,
|
||||
N_READS>;
|
||||
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||
kernel = &cu::arg_reduce_general<
|
||||
InType,
|
||||
cu::ArgMin<InType>,
|
||||
BLOCK_DIM,
|
||||
N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(in_strides),
|
||||
const_param(out_strides),
|
||||
ndim,
|
||||
axis_stride,
|
||||
axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
150
mlx/backend/cuda/bin2h.cmake
Normal file
150
mlx/backend/cuda/bin2h.cmake
Normal file
@ -0,0 +1,150 @@
|
||||
# Based on: https://github.com/sivachandran/cmake-bin2h
|
||||
#
|
||||
# Copyright 2020 Sivachandran Paramasivam
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
include(CMakeParseArguments)
|
||||
|
||||
# Function to wrap a given string into multiple lines at the given column
|
||||
# position.
|
||||
#
|
||||
# Parameters:
|
||||
#
|
||||
# * VARIABLE - The name of the CMake variable holding the string.
|
||||
# * AT_COLUMN - The column position at which string will be wrapped.
|
||||
function(WRAP_STRING)
|
||||
set(oneValueArgs VARIABLE AT_COLUMN)
|
||||
cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})
|
||||
|
||||
string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
|
||||
math(EXPR offset "0")
|
||||
|
||||
while(stringLength GREATER 0)
|
||||
if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
|
||||
math(EXPR length "${WRAP_STRING_AT_COLUMN}")
|
||||
else()
|
||||
math(EXPR length "${stringLength}")
|
||||
endif()
|
||||
|
||||
string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
|
||||
set(lines "${lines}\n ${line}")
|
||||
|
||||
math(EXPR stringLength "${stringLength} - ${length}")
|
||||
math(EXPR offset "${offset} + ${length}")
|
||||
endwhile()
|
||||
|
||||
set(${WRAP_STRING_VARIABLE}
|
||||
"${lines}"
|
||||
PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Function to embed contents of a file as byte array in C/C++ header file(.h).
|
||||
# The header file will contain a byte array and integer variable holding the
|
||||
# size of the array.
|
||||
#
|
||||
# Parameters:
|
||||
#
|
||||
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
|
||||
# the header file.
|
||||
# * VARIABLE_NAME - The name of the variable for the byte array. The string
|
||||
# "_SIZE" will be append to this name and will be used a variable name for
|
||||
# size variable.
|
||||
# * HEADER_FILE - The path of header file.
|
||||
# * APPEND - If specified appends to the header file instead of overwriting it
|
||||
# * HEADER_NAMESPACE - The namespace, where the array should be located in.
|
||||
# * NULL_TERMINATE - If specified a null byte(zero) will be append to the byte
|
||||
# array.
|
||||
#
|
||||
# Usage:
|
||||
#
|
||||
# bin2h(SOURCE_FILE "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
|
||||
function(BIN2H)
|
||||
set(options APPEND NULL_TERMINATE)
|
||||
set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
|
||||
set(multiValueArgs SOURCE_FILES)
|
||||
cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
|
||||
"${multiValueArgs}" ${ARGN})
|
||||
|
||||
set(arrayDefinition "")
|
||||
foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
|
||||
# get filename without extension
|
||||
get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
|
||||
# convert the filename to a valid C identifier
|
||||
string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)
|
||||
|
||||
# reads source file contents as hex string
|
||||
file(READ ${SOURCE_FILE} hexString HEX)
|
||||
|
||||
# append null
|
||||
if(BIN2H_NULL_TERMINATE)
|
||||
string(APPEND hexString "00")
|
||||
endif()
|
||||
|
||||
# wraps the hex string into multiple lines
|
||||
wrap_string(VARIABLE hexString AT_COLUMN 24)
|
||||
|
||||
# strip the © in source code
|
||||
string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})
|
||||
|
||||
string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
|
||||
${arrayValues})
|
||||
|
||||
# make a full variable name for the array
|
||||
set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")
|
||||
|
||||
# declares byte array and the length variables
|
||||
string(APPEND arrayDefinition
|
||||
"constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
|
||||
endforeach()
|
||||
|
||||
# add namespace wrapper if defined
|
||||
if(DEFINED BIN2H_HEADER_NAMESPACE)
|
||||
set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
|
||||
set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
|
||||
set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
|
||||
endif()
|
||||
|
||||
set(arrayIncludes "#pragma once")
|
||||
string(PREPEND declarations "${arrayIncludes}\n\n")
|
||||
|
||||
if(BIN2H_APPEND)
|
||||
file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
|
||||
else()
|
||||
file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# ----------------------------- CLI args -----------------------------
|
||||
|
||||
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
|
||||
foreach(source ${MLX_JIT_SOURCES_LIST})
|
||||
list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
|
||||
endforeach()
|
||||
|
||||
bin2h(
|
||||
SOURCE_FILES
|
||||
${MLX_JIT_SOURCES_ABS}
|
||||
NULL_TERMINATE
|
||||
VARIABLE_NAME
|
||||
"jit_source"
|
||||
HEADER_NAMESPACE
|
||||
"mlx::core"
|
||||
HEADER_FILE
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
|
228
mlx/backend/cuda/compiled.cpp
Normal file
228
mlx/backend/cuda/compiled.cpp
Normal file
@ -0,0 +1,228 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
struct FusedKernelBuilder {
|
||||
std::string os;
|
||||
const std::string& kernel_name;
|
||||
const std::vector<array>& inputs;
|
||||
const std::vector<array>& outputs;
|
||||
const std::vector<array>& tape;
|
||||
const std::function<bool(size_t)>& is_constant;
|
||||
|
||||
void build(const char* name, bool contiguous) {
|
||||
NodeNamer namer;
|
||||
|
||||
// Function parameters.
|
||||
std::vector<std::string> params;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
params.push_back(
|
||||
fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname));
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
params.push_back(fmt::format(
|
||||
"const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides",
|
||||
xname));
|
||||
}
|
||||
}
|
||||
for (const auto& x : outputs) {
|
||||
params.push_back(fmt::format(
|
||||
"{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));
|
||||
}
|
||||
if (!contiguous) {
|
||||
params.push_back(
|
||||
"const __grid_constant__ cuda::std::array<int32_t, NDIM> shape");
|
||||
}
|
||||
params.push_back("IdxT size");
|
||||
|
||||
// Build function signature.
|
||||
if (contiguous) {
|
||||
os += "template <typename IdxT = uint32_t>\n";
|
||||
} else {
|
||||
os += "template <int NDIM, typename IdxT = uint32_t>\n";
|
||||
}
|
||||
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||
for (size_t i = 0; i < params.size(); ++i) {
|
||||
os += " ";
|
||||
os += params[i];
|
||||
if (i != params.size() - 1) {
|
||||
os += ",\n";
|
||||
}
|
||||
}
|
||||
os += ") {\n";
|
||||
|
||||
// Index.
|
||||
os +=
|
||||
" IdxT index = cg::this_grid().thread_rank();\n"
|
||||
" if (index >= size) {\n"
|
||||
" return;\n"
|
||||
" }\n";
|
||||
|
||||
// Read inputs.
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
std::string value;
|
||||
if (is_constant(i)) {
|
||||
std::ostringstream ss;
|
||||
print_constant(ss, x);
|
||||
value = fmt::format("static_cast<{}>({})", type, ss.str());
|
||||
} else if (is_scalar(x)) {
|
||||
value = fmt::format("{}[0]", xname);
|
||||
} else if (contiguous) {
|
||||
value = fmt::format("{}[index]", xname);
|
||||
} else {
|
||||
std::string index = fmt::format(
|
||||
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
|
||||
xname);
|
||||
value = fmt::format("{}[{}]", xname, index);
|
||||
}
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
}
|
||||
|
||||
// Write tape.
|
||||
for (const auto& x : tape) {
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
std::string value;
|
||||
if (is_static_cast(x.primitive())) {
|
||||
value = fmt::format(
|
||||
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
|
||||
} else {
|
||||
std::ostringstream ss;
|
||||
x.primitive().print(ss);
|
||||
value = ss.str();
|
||||
value += "{}(";
|
||||
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
|
||||
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
|
||||
}
|
||||
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
|
||||
}
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
}
|
||||
|
||||
// Write output.
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
}
|
||||
|
||||
os += "}\n";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cu
|
||||
|
||||
constexpr const char* g_jit_includes = R"(
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
)";
|
||||
|
||||
void Compiled::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("Compiled::eval_gpu");
|
||||
auto& s = stream();
|
||||
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
|
||||
// Build source code.
|
||||
cu::FusedKernelBuilder builder{
|
||||
g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
|
||||
builder.os +=
|
||||
"namespace mlx::core::cu {\n\n"
|
||||
"namespace cg = cooperative_groups;\n\n";
|
||||
builder.build("_contiguous", true);
|
||||
builder.os += "\n";
|
||||
builder.build("_strided", false);
|
||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||
// Build kernel names.
|
||||
std::vector<std::string> kernel_names = {
|
||||
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
|
||||
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
|
||||
};
|
||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
|
||||
kernel_names.push_back(
|
||||
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
|
||||
}
|
||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||
});
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
auto [contiguous, shape, strides_vec] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Whether to use large index.
|
||||
bool large = compiled_use_large_index(inputs, outputs, contiguous);
|
||||
|
||||
// Put inputs.
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
const auto& x = inputs[i];
|
||||
mod.append_arg(x);
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
mod.append_arg(strides_vec[strides_index++]);
|
||||
}
|
||||
}
|
||||
|
||||
// Put outputs.
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
for (auto& x : outputs) {
|
||||
mod.append_arg(x);
|
||||
}
|
||||
|
||||
// Put shape and size.
|
||||
if (!contiguous) {
|
||||
mod.append_arg(shape);
|
||||
}
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(outputs[0].data_size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(outputs[0].data_size());
|
||||
}
|
||||
|
||||
// Launch kernel.
|
||||
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||
if (contiguous) {
|
||||
kernel_name += fmt::format("_contiguous<{}>", index_type);
|
||||
} else {
|
||||
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
|
||||
}
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
for (const auto& out : outputs) {
|
||||
encoder.set_output_array(out);
|
||||
}
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, outputs[0], large);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -3,8 +3,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
|
72
mlx/backend/cuda/device/atomic_ops.cuh
Normal file
72
mlx/backend/cuda/device/atomic_ops.cuh
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
|
||||
#include <cuda/atomic>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_add(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
ref += val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_prod(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
T old = ref.load();
|
||||
while (!ref.compare_exchange_strong(old, old * val)) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_max(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
ref.fetch_max(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_min(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
ref.fetch_min(val);
|
||||
}
|
||||
|
||||
// Somehow cuda::atomic_ref does not provide atomic add for following types.
|
||||
template <typename T>
|
||||
inline __device__ void atomic_add_general(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
T old = ref.load();
|
||||
while (!ref.compare_exchange_strong(old, old + val)) {
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(__half* out, __half val) {
|
||||
atomicAdd(out, val);
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
|
||||
#if __CUDA_ARCH__ < 900
|
||||
atomic_add_general(out, val);
|
||||
#else
|
||||
atomicAdd(out, val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
||||
#if __CUDA_ARCH__ < 800
|
||||
#if CCCL_VERSION >= 2008000
|
||||
atomic_add_general(out, val);
|
||||
#else
|
||||
bool cccl_version_too_old_for_bfloat16_atomic_add = false;
|
||||
assert(cccl_version_too_old_for_bfloat16_atomic_add);
|
||||
#endif
|
||||
#else
|
||||
atomicAdd(out, val);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -1,6 +1,6 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/array>
|
12
mlx/backend/cuda/device/config.h
Normal file
12
mlx/backend/cuda/device/config.h
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file is used by both CUDA kernel code and host-only C++ code.
|
||||
|
||||
#pragma once
|
||||
|
||||
// The maximum dimensions of shape/strides passed as kernel parameters.
|
||||
#define MAX_NDIM 8
|
||||
|
||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||
#define WARP_SIZE 32
|
53
mlx/backend/cuda/device/gather.cuh
Normal file
53
mlx/backend/cuda/device/gather.cuh
Normal file
@ -0,0 +1,53 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
|
||||
__global__ void gather(
|
||||
const T* src,
|
||||
T* out,
|
||||
LocT size,
|
||||
const __grid_constant__ Shape src_shape,
|
||||
const __grid_constant__ Strides src_strides,
|
||||
int32_t src_ndim,
|
||||
const __grid_constant__ Shape slice_sizes,
|
||||
uint32_t slice_size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
|
||||
const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
|
||||
indices_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
|
||||
indices_strides) {
|
||||
LocT out_idx = cg::this_grid().thread_rank();
|
||||
if (out_idx >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
LocT src_elem = out_idx % slice_size;
|
||||
LocT idx_elem = out_idx / slice_size;
|
||||
|
||||
LocT src_loc =
|
||||
elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
|
||||
idx_elem,
|
||||
indices_shape.data() + i * IDX_NDIM,
|
||||
indices_strides.data() + i * IDX_NDIM);
|
||||
int32_t axis = axes[i];
|
||||
LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]);
|
||||
src_loc += idx_val * src_strides[axis];
|
||||
}
|
||||
|
||||
out[out_idx] = src[src_loc];
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
65
mlx/backend/cuda/device/gather_axis.cuh
Normal file
65
mlx/backend/cuda/device/gather_axis.cuh
Normal file
@ -0,0 +1,65 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
int NDIM,
|
||||
bool SrcC,
|
||||
bool IdxC,
|
||||
typename LocT>
|
||||
__global__ void gather_axis(
|
||||
const T* src,
|
||||
const IdxT* indices,
|
||||
T* out,
|
||||
LocT idx_size_pre,
|
||||
LocT idx_size_axis,
|
||||
LocT idx_size_post,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> src_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
|
||||
int32_t axis,
|
||||
int32_t axis_size,
|
||||
int64_t src_stride_axis,
|
||||
int64_t idx_stride_axis) {
|
||||
LocT index = cg::this_grid().thread_rank();
|
||||
if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);
|
||||
|
||||
LocT elem_idx = z * idx_size_post;
|
||||
|
||||
LocT idx_loc = y * idx_stride_axis;
|
||||
if constexpr (IdxC) {
|
||||
idx_loc += elem_idx * idx_size_axis + x;
|
||||
} else {
|
||||
idx_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
|
||||
}
|
||||
|
||||
auto idx_val = absolute_index(indices[idx_loc], axis_size);
|
||||
|
||||
LocT src_loc = idx_val * src_stride_axis;
|
||||
if constexpr (SrcC) {
|
||||
src_loc += elem_idx * axis_size + x;
|
||||
} else {
|
||||
src_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), src_strides.data());
|
||||
}
|
||||
|
||||
LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x;
|
||||
|
||||
out[out_idx] = src[src_loc];
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
30
mlx/backend/cuda/device/indexing.cuh
Normal file
30
mlx/backend/cuda/device/indexing.cuh
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <cuda/std/tuple>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Convert an absolute index to positions in a 3d grid, assuming the index is
|
||||
// calculated with:
|
||||
// index = x * dim1 * dim2 + y * dim2 + z
|
||||
template <typename T>
|
||||
inline __host__ __device__ cuda::std::tuple<T, T, T>
|
||||
index_to_dims(T index, T dim1, T dim2) {
|
||||
T x = index / (dim1 * dim2);
|
||||
T y = (index % (dim1 * dim2)) / dim2;
|
||||
T z = index % dim2;
|
||||
return cuda::std::make_tuple(x, y, z);
|
||||
}
|
||||
|
||||
// Get absolute index from possible negative index.
|
||||
template <typename IdxT>
|
||||
inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) {
|
||||
if constexpr (cuda::std::is_unsigned_v<IdxT>) {
|
||||
return idx;
|
||||
} else {
|
||||
return static_cast<int32_t>(idx < 0 ? idx + size : idx);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
68
mlx/backend/cuda/device/scatter.cuh
Normal file
68
mlx/backend/cuda/device/scatter.cuh
Normal file
@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/scatter_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename Op,
|
||||
int NIDX,
|
||||
int IDX_NDIM,
|
||||
typename LocT>
|
||||
__global__ void scatter(
|
||||
const T* upd,
|
||||
T* out,
|
||||
LocT size,
|
||||
const __grid_constant__ Shape upd_shape,
|
||||
const __grid_constant__ Strides upd_strides,
|
||||
int32_t upd_ndim,
|
||||
LocT upd_post_idx_size,
|
||||
const __grid_constant__ Shape out_shape,
|
||||
const __grid_constant__ Strides out_strides,
|
||||
int32_t out_ndim,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
|
||||
const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
|
||||
indices_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
|
||||
indices_strides) {
|
||||
LocT upd_idx = cg::this_grid().thread_rank();
|
||||
if (upd_idx >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
LocT out_elem = upd_idx % upd_post_idx_size;
|
||||
LocT idx_elem = upd_idx / upd_post_idx_size;
|
||||
|
||||
LocT out_idx = elem_to_loc(
|
||||
out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
|
||||
idx_elem,
|
||||
indices_shape.data() + i * IDX_NDIM,
|
||||
indices_strides.data() + i * IDX_NDIM);
|
||||
int32_t axis = axes[i];
|
||||
LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]);
|
||||
out_idx += idx_val * out_strides[axis];
|
||||
}
|
||||
|
||||
LocT upd_loc = elem_to_loc(
|
||||
out_elem + idx_elem * upd_post_idx_size,
|
||||
upd_shape.data(),
|
||||
upd_strides.data(),
|
||||
upd_ndim);
|
||||
|
||||
Op{}(out + out_idx, upd[upd_loc]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
67
mlx/backend/cuda/device/scatter_axis.cuh
Normal file
67
mlx/backend/cuda/device/scatter_axis.cuh
Normal file
@ -0,0 +1,67 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/scatter_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
bool UpdC,
|
||||
bool IdxC,
|
||||
typename LocT>
|
||||
__global__ void scatter_axis(
|
||||
const T* upd,
|
||||
const IdxT* indices,
|
||||
T* out,
|
||||
LocT idx_size_pre,
|
||||
LocT idx_size_axis,
|
||||
LocT idx_size_post,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> upd_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
|
||||
int32_t axis,
|
||||
int32_t axis_size,
|
||||
int64_t upd_stride_axis,
|
||||
int64_t idx_stride_axis) {
|
||||
LocT index = cg::this_grid().thread_rank();
|
||||
if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);
|
||||
|
||||
LocT elem_idx = z * idx_size_post;
|
||||
|
||||
LocT idx_loc = y * idx_stride_axis;
|
||||
if constexpr (IdxC) {
|
||||
idx_loc += elem_idx * idx_size_axis + x;
|
||||
} else {
|
||||
idx_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
|
||||
}
|
||||
|
||||
auto idx_val = absolute_index(indices[idx_loc], axis_size);
|
||||
|
||||
LocT upd_loc = y * upd_stride_axis;
|
||||
if constexpr (UpdC) {
|
||||
upd_loc += elem_idx * idx_size_axis + x;
|
||||
} else {
|
||||
upd_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), upd_strides.data());
|
||||
}
|
||||
|
||||
LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x;
|
||||
|
||||
Op{}(out + out_idx, upd[upd_loc]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
44
mlx/backend/cuda/device/scatter_ops.cuh
Normal file
44
mlx/backend/cuda/device/scatter_ops.cuh
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/atomic_ops.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct ScatterAssign {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
*out = val;
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterSum {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_add(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterProd {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_prod(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterMax {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_max(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterMin {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_min(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
12
mlx/backend/cuda/device/ternary_ops.cuh
Normal file
12
mlx/backend/cuda/device/ternary_ops.cuh
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Select {
|
||||
template <typename T>
|
||||
__device__ T operator()(bool condition, T x, T y) {
|
||||
return condition ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -2,8 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
339
mlx/backend/cuda/device/utils.cuh
Normal file
339
mlx/backend/cuda/device/utils.cuh
Normal file
@ -0,0 +1,339 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file must not include any host-only code, utilies that work under both
|
||||
// host and device can be put here.
|
||||
//
|
||||
// See more about the requirements at:
|
||||
// https://docs.nvidia.com/cuda/nvrtc/#language
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/array>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/tuple>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CUDA kernel utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// To pass shape/strides to kernels via constant memory, their size must be
|
||||
// known at compile time.
|
||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct Limits {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
return cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
return cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
cuda::std::enable_if_t<
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
}
|
||||
};
|
||||
|
||||
// CUDA 11 does not have host side arithmatic operators for half types.
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
cuda::std::enable_if_t<
|
||||
cuda::std::is_same_v<T, __half> ||
|
||||
cuda::std::is_same_v<T, __nv_bfloat16>>> {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
#else
|
||||
return -cuda::std::numeric_limits<float>::infinity();
|
||||
#endif
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
#else
|
||||
return cuda::std::numeric_limits<float>::lowest();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr __host__ __device__ bool max() {
|
||||
return true;
|
||||
}
|
||||
static constexpr __host__ __device__ bool min() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<cuComplex> {
|
||||
static constexpr __host__ __device__ cuComplex max() {
|
||||
return {Limits<float>::max(), Limits<float>::max()};
|
||||
}
|
||||
static constexpr __host__ __device__ cuComplex min() {
|
||||
return {Limits<float>::min(), Limits<float>::min()};
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
// Optimize when the ndim is known at compile time.
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
|
||||
IdxT loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
const int64_t* c_strides) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
IdxT c_loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
c_loc += dim_idx * c_strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
}
|
||||
|
||||
// Optimized version when ndim is larger than 4.
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
int ndim) {
|
||||
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
const int64_t* c_strides,
|
||||
int ndim) {
|
||||
auto [a_loc, b_loc, c_loc] =
|
||||
elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
c_loc += dim_idx * c_strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Elem to loc in a loop utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int DIM, bool General = true, typename OffsetT = size_t>
|
||||
struct LoopedElemToLoc {
|
||||
int dim;
|
||||
LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
|
||||
|
||||
__device__ void next(const int* shape, const int64_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index++;
|
||||
offset += OffsetT(strides[dim - 1]);
|
||||
if (index >= shape[dim - 1]) {
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
offset = inner_looper.offset;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index += n;
|
||||
offset += n * OffsetT(strides[dim - 1]);
|
||||
|
||||
if (index >= shape[dim - 1]) {
|
||||
int extra = index - shape[dim - 1];
|
||||
if (extra >= shape[dim - 1]) {
|
||||
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
|
||||
extra = extra % shape[dim - 1];
|
||||
} else {
|
||||
inner_looper.next(shape, strides);
|
||||
}
|
||||
index = 0;
|
||||
offset = inner_looper.offset;
|
||||
if (extra > 0) {
|
||||
next(extra, shape, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, true, OffsetT> {
|
||||
int dim;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int dim) : dim(dim) {}
|
||||
|
||||
__device__ void next(const int* shape, const int64_t* strides) {
|
||||
index++;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||
index += n;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset = index * OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, false, OffsetT> {
|
||||
OffsetT offset{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int) {}
|
||||
|
||||
__device__ void next(const int*, const int64_t* strides) {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int*, const int64_t* strides) {
|
||||
offset += n * OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
420
mlx/backend/cuda/indexing.cpp
Normal file
420
mlx/backend/cuda/indexing.cpp
Normal file
@ -0,0 +1,420 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include "cuda_jit_sources.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};
|
||||
|
||||
void append_indices_arg(
|
||||
cu::JitModule& mod,
|
||||
const std::vector<array>& inputs,
|
||||
int nidx,
|
||||
int idx_ndim) {
|
||||
std::vector<const void*> indices(nidx);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
indices[i] = inputs[i + 1].data<void>();
|
||||
}
|
||||
mod.append_arg(std::move(indices));
|
||||
std::vector<int32_t> indices_shape(nidx * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy_n(
|
||||
inputs[i + 1].shape().begin(),
|
||||
idx_ndim,
|
||||
indices_shape.data() + i * idx_ndim);
|
||||
}
|
||||
mod.append_arg(std::move(indices_shape));
|
||||
std::vector<int64_t> indices_strides(nidx * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy_n(
|
||||
inputs[i + 1].strides().begin(),
|
||||
idx_ndim,
|
||||
indices_strides.data() + i * idx_ndim);
|
||||
}
|
||||
mod.append_arg(std::move(indices_strides));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Gather::eval_gpu");
|
||||
assert(inputs.size() > 0);
|
||||
const auto& src = inputs[0];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int nidx = inputs.size() - 1;
|
||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||
|
||||
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
||||
(src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
||||
|
||||
uint32_t slice_size = std::accumulate(
|
||||
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
|
||||
|
||||
std::string module_name = fmt::format(
|
||||
"gather_{}_{}_{}",
|
||||
dtype_to_string(out.dtype()),
|
||||
dtype_to_string(idx_dtype),
|
||||
nidx);
|
||||
|
||||
auto& s = stream();
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||
for (int large = 0; large <= 1; ++large) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::gather<{}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
nidx,
|
||||
ndim,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
||||
});
|
||||
|
||||
mod.append_arg(src);
|
||||
mod.append_arg(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(out.size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(out.size());
|
||||
}
|
||||
mod.append_ndim_arg(src.shape());
|
||||
mod.append_ndim_arg(src.strides());
|
||||
mod.append_arg<int32_t>(src.ndim());
|
||||
mod.append_ndim_arg(slice_sizes_);
|
||||
mod.append_arg(slice_size);
|
||||
mod.append_arg(axes_);
|
||||
append_indices_arg(mod, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::gather<{}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, out, large);
|
||||
});
|
||||
}
|
||||
|
||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Gather::eval_gpu");
|
||||
assert(inputs.size() > 1);
|
||||
auto& upd = inputs.back();
|
||||
|
||||
// Copy src into out.
|
||||
CopyType copy_type;
|
||||
if (inputs[0].data_size() == 1) {
|
||||
copy_type = CopyType::Scalar;
|
||||
} else if (inputs[0].flags().row_contiguous) {
|
||||
copy_type = CopyType::Vector;
|
||||
} else {
|
||||
copy_type = CopyType::General;
|
||||
}
|
||||
copy_gpu(inputs[0], out, copy_type);
|
||||
|
||||
// Empty update.
|
||||
if (upd.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int nidx = axes_.size();
|
||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||
|
||||
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
||||
(upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
||||
|
||||
uint32_t upd_post_idx_size = std::accumulate(
|
||||
upd.shape().begin() + idx_ndim,
|
||||
upd.shape().end(),
|
||||
1,
|
||||
std::multiplies<uint32_t>());
|
||||
|
||||
const char* op = g_scatter_ops[reduce_type_];
|
||||
std::string module_name = fmt::format(
|
||||
"scatter_{}_{}_{}_{}",
|
||||
dtype_to_string(out.dtype()),
|
||||
dtype_to_string(idx_dtype),
|
||||
op,
|
||||
nidx);
|
||||
|
||||
auto& s = stream();
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||
for (int large = 0; large <= 1; ++large) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
op,
|
||||
nidx,
|
||||
ndim,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
||||
});
|
||||
|
||||
mod.append_arg(upd);
|
||||
mod.append_arg(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd.size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(upd.size());
|
||||
}
|
||||
mod.append_ndim_arg(upd.shape());
|
||||
mod.append_ndim_arg(upd.strides());
|
||||
mod.append_arg<int32_t>(upd.ndim());
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd_post_idx_size);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(upd_post_idx_size);
|
||||
}
|
||||
mod.append_ndim_arg(out.shape());
|
||||
mod.append_ndim_arg(out.strides());
|
||||
mod.append_arg<int32_t>(out.ndim());
|
||||
mod.append_arg(axes_);
|
||||
append_indices_arg(mod, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
op,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, upd, large);
|
||||
});
|
||||
}
|
||||
|
||||
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("GatherAxis::eval_gpu");
|
||||
assert(inputs.size() > 1);
|
||||
const auto& src = inputs[0];
|
||||
const auto& idx = inputs[1];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
||||
|
||||
std::string module_name = fmt::format(
|
||||
"gather_axis_{}_{}",
|
||||
dtype_to_string(out.dtype()),
|
||||
dtype_to_string(idx.dtype()));
|
||||
|
||||
auto& s = stream();
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||
for (int contiguous = 0; contiguous < 4; ++contiguous) {
|
||||
for (int large = 0; large <= 1; ++large) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx.dtype()),
|
||||
ndim,
|
||||
contiguous & 1 ? true : false,
|
||||
contiguous & 2 ? true : false,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
|
||||
});
|
||||
|
||||
size_t idx_size_pre = 1;
|
||||
size_t idx_size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
idx_size_pre *= idx.shape(i);
|
||||
}
|
||||
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
|
||||
idx_size_post *= idx.shape(i);
|
||||
}
|
||||
size_t idx_size_axis = idx.shape(axis_);
|
||||
|
||||
mod.append_arg(src);
|
||||
mod.append_arg(idx);
|
||||
mod.append_arg(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(idx_size_pre);
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(idx_size_pre);
|
||||
mod.append_arg<uint32_t>(idx_size_axis);
|
||||
mod.append_arg<uint32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(src.strides(), axis_));
|
||||
mod.append_arg(remove_index(idx.strides(), axis_));
|
||||
mod.append_arg<int32_t>(axis_);
|
||||
mod.append_arg(src.shape(axis_));
|
||||
mod.append_arg(src.strides(axis_));
|
||||
mod.append_arg(idx.strides(axis_));
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx.dtype()),
|
||||
src.ndim() - 1,
|
||||
src.flags().row_contiguous,
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, idx, large);
|
||||
});
|
||||
}
|
||||
|
||||
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ScatterAxis::eval_gpu");
|
||||
assert(inputs.size() > 2);
|
||||
const auto& src = inputs[0];
|
||||
const auto& idx = inputs[1];
|
||||
const auto& upd = inputs[2];
|
||||
|
||||
// Copy src into out.
|
||||
CopyType copy_type;
|
||||
if (src.data_size() == 1) {
|
||||
copy_type = CopyType::Scalar;
|
||||
} else if (src.flags().row_contiguous) {
|
||||
copy_type = CopyType::Vector;
|
||||
} else {
|
||||
copy_type = CopyType::General;
|
||||
}
|
||||
copy_gpu(src, out, copy_type);
|
||||
|
||||
// Empty update.
|
||||
if (upd.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
||||
|
||||
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
|
||||
std::string module_name = fmt::format(
|
||||
"scatter_axis_{}_{}_{}",
|
||||
dtype_to_string(out.dtype()),
|
||||
dtype_to_string(idx.dtype()),
|
||||
op);
|
||||
|
||||
auto& s = stream();
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||
for (int contiguous = 0; contiguous < 4; ++contiguous) {
|
||||
for (int large = 0; large <= 1; ++large) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx.dtype()),
|
||||
op,
|
||||
ndim,
|
||||
contiguous & 1 ? true : false,
|
||||
contiguous & 2 ? true : false,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
|
||||
});
|
||||
|
||||
size_t idx_size_pre = 1;
|
||||
size_t idx_size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
idx_size_pre *= idx.shape(i);
|
||||
}
|
||||
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
|
||||
idx_size_post *= idx.shape(i);
|
||||
}
|
||||
size_t idx_size_axis = idx.shape(axis_);
|
||||
|
||||
mod.append_arg(upd);
|
||||
mod.append_arg(idx);
|
||||
mod.append_arg(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(idx_size_pre);
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(idx_size_pre);
|
||||
mod.append_arg<uint32_t>(idx_size_axis);
|
||||
mod.append_arg<uint32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(upd.strides(), axis_));
|
||||
mod.append_arg(remove_index(idx.strides(), axis_));
|
||||
mod.append_arg<int32_t>(axis_);
|
||||
mod.append_arg(out.shape(axis_));
|
||||
mod.append_arg(upd.strides(axis_));
|
||||
mod.append_arg(idx.strides(axis_));
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx.dtype()),
|
||||
op,
|
||||
idx.ndim() - 1,
|
||||
upd.flags().row_contiguous,
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, idx, large);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thrust/iterator/iterator_adaptor.h>
|
||||
#include <thrust/iterator/iterator_facade.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// RandomAccessIterator for strided access to array entries.
|
||||
template <typename Iterator, typename Stride = int64_t>
|
||||
class strided_iterator
|
||||
: public thrust::
|
||||
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
|
||||
public:
|
||||
using super_t =
|
||||
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
|
||||
|
||||
using reference = typename super_t::reference;
|
||||
using difference_type = typename super_t::difference_type;
|
||||
|
||||
__host__ __device__ strided_iterator(Iterator it, Stride stride)
|
||||
: super_t(it), stride_(stride) {}
|
||||
|
||||
__host__ __device__ Stride stride() const {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class thrust::iterator_core_access;
|
||||
|
||||
__host__ __device__ bool equal(const strided_iterator& other) const {
|
||||
return this->base() == other.base();
|
||||
}
|
||||
|
||||
__host__ __device__ void advance(difference_type n) {
|
||||
this->base_reference() += n * stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ void increment() {
|
||||
this->base_reference() += stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ void decrement() {
|
||||
this->base_reference() -= stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ difference_type
|
||||
distance_to(const strided_iterator& other) const {
|
||||
const difference_type dist = other.base() - this->base();
|
||||
_CCCL_ASSERT(
|
||||
dist % stride() == 0,
|
||||
"Underlying iterator difference must be divisible by the stride");
|
||||
return dist / stride();
|
||||
}
|
||||
|
||||
Stride stride_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
348
mlx/backend/cuda/jit_module.cpp
Normal file
348
mlx/backend/cuda/jit_module.cpp
Normal file
@ -0,0 +1,348 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
#include "cuda_jit_sources.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvrtc.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace {
|
||||
|
||||
#define CHECK_NVRTC_ERROR(cmd) check_nvrtc_error(#cmd, (cmd))
|
||||
|
||||
void check_nvrtc_error(const char* name, nvrtcResult err) {
|
||||
if (err != NVRTC_SUCCESS) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed: {}", name, nvrtcGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd))
|
||||
|
||||
void check_cu_error(const char* name, CUresult err) {
|
||||
if (err != CUDA_SUCCESS) {
|
||||
const char* err_str = "Unknown error";
|
||||
cuGetErrorString(err, &err_str);
|
||||
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
|
||||
}
|
||||
}
|
||||
|
||||
// Return the location of the CUDA toolkit.
|
||||
const char* cuda_home() {
|
||||
const char* home = std::getenv("CUDA_HOME");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
home = std::getenv("CUDA_PATH");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
#if defined(__linux__)
|
||||
home = "/usr/local/cuda";
|
||||
if (std::filesystem::exists(home)) {
|
||||
return home;
|
||||
}
|
||||
#endif
|
||||
throw std::runtime_error(
|
||||
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
||||
}
|
||||
|
||||
// Get the cache directory for storing compiled results.
|
||||
bool get_ptx_cache_dir(std::filesystem::path* result) {
|
||||
auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
||||
if (!std::filesystem::is_directory(path)) {
|
||||
std::error_code error;
|
||||
if (!std::filesystem::create_directories(path, error)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
*result = path;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
||||
bool read_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
std::vector<char>* ptx,
|
||||
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
||||
auto ptx_path = cache_dir / (module_name + ".ptx");
|
||||
std::error_code error;
|
||||
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
||||
if (error) {
|
||||
return false;
|
||||
}
|
||||
std::ifstream ptx_file(ptx_path, std::ios::binary);
|
||||
if (!ptx_file.good()) {
|
||||
return false;
|
||||
}
|
||||
ptx->resize(ptx_size);
|
||||
ptx_file.read(ptx->data(), ptx_size);
|
||||
|
||||
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||
std::string line;
|
||||
while (std::getline(txt_file, line)) {
|
||||
auto tab = line.find('\t');
|
||||
if (tab != std::string::npos) {
|
||||
ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Write the |ptx| and |ptx_kernels| to |cache_dir| with |name|.
|
||||
void write_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
const std::vector<char>& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
||||
if (!ptx.empty()) {
|
||||
ptx_file.write(&ptx.front(), ptx.size());
|
||||
}
|
||||
std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
txt_file << name << "\t" << mangled << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Return if |device|'s version is not newer than |major|.|minor| version.
|
||||
inline bool version_lower_equal(Device& device, int major, int minor) {
|
||||
if (device.compute_capability_major() < major) {
|
||||
return true;
|
||||
} else if (device.compute_capability_major() == major) {
|
||||
return device.compute_capability_minor() <= minor;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Return whether NVRTC supports compiling to |device|'s SASS code.
|
||||
bool compiler_supports_device_sass(Device& device) {
|
||||
int nvrtc_major, nvrtc_minor;
|
||||
CHECK_NVRTC_ERROR(nvrtcVersion(&nvrtc_major, &nvrtc_minor));
|
||||
if (nvrtc_major < 9) {
|
||||
return false;
|
||||
} else if (nvrtc_major == 9) {
|
||||
return version_lower_equal(device, 7, 2);
|
||||
} else if (nvrtc_major == 10) {
|
||||
return version_lower_equal(device, 7, 5);
|
||||
} else if (nvrtc_major == 11 && nvrtc_minor == 0) {
|
||||
return version_lower_equal(device, 8, 0);
|
||||
} else if (nvrtc_major == 11 && nvrtc_minor < 8) {
|
||||
return version_lower_equal(device, 8, 6);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"
|
||||
|
||||
constexpr const char* g_include_names[] = {
|
||||
INCLUDE_PREFIX "atomic_ops.cuh",
|
||||
INCLUDE_PREFIX "binary_ops.cuh",
|
||||
INCLUDE_PREFIX "cast_op.cuh",
|
||||
INCLUDE_PREFIX "config.h",
|
||||
INCLUDE_PREFIX "cucomplex_math.cuh",
|
||||
INCLUDE_PREFIX "fp16_math.cuh",
|
||||
INCLUDE_PREFIX "indexing.cuh",
|
||||
INCLUDE_PREFIX "scatter_ops.cuh",
|
||||
INCLUDE_PREFIX "unary_ops.cuh",
|
||||
INCLUDE_PREFIX "ternary_ops.cuh",
|
||||
INCLUDE_PREFIX "utils.cuh",
|
||||
};
|
||||
|
||||
#undef INCLUDE_PREFIX
|
||||
|
||||
constexpr const char* g_headers[] = {
|
||||
jit_source_atomic_ops,
|
||||
jit_source_binary_ops,
|
||||
jit_source_cast_op,
|
||||
jit_source_config,
|
||||
jit_source_cucomplex_math,
|
||||
jit_source_fp16_math,
|
||||
jit_source_indexing,
|
||||
jit_source_scatter_ops,
|
||||
jit_source_unary_ops,
|
||||
jit_source_ternary_ops,
|
||||
jit_source_utils,
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
JitModule::JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder) {
|
||||
// Check cache.
|
||||
std::filesystem::path cache_dir;
|
||||
std::vector<char> ptx;
|
||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||
if (!get_ptx_cache_dir(&cache_dir) ||
|
||||
!read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
|
||||
// Create program.
|
||||
auto [source_code, kernel_names] = builder();
|
||||
nvrtcProgram prog;
|
||||
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
|
||||
&prog,
|
||||
source_code.c_str(),
|
||||
(module_name + ".cu").c_str(),
|
||||
std::size(g_headers),
|
||||
g_headers,
|
||||
g_include_names));
|
||||
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
|
||||
&prog,
|
||||
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
|
||||
for (const auto& name : kernel_names) {
|
||||
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
|
||||
}
|
||||
|
||||
// Compile program.
|
||||
bool use_sass = compiler_supports_device_sass(device);
|
||||
std::string compute = fmt::format(
|
||||
"--gpu-architecture={}_{}{}",
|
||||
use_sass ? "sm" : "compute",
|
||||
device.compute_capability_major(),
|
||||
device.compute_capability_minor());
|
||||
std::string include = fmt::format("--include-path={}/include", cuda_home());
|
||||
const char* args[] = {compute.c_str(), include.c_str()};
|
||||
nvrtcResult compile_result =
|
||||
nvrtcCompileProgram(prog, std::size(args), args);
|
||||
if (compile_result != NVRTC_SUCCESS) {
|
||||
size_t log_size;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
|
||||
std::vector<char> log(log_size + 1, 0);
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
|
||||
throw std::runtime_error(
|
||||
fmt::format("Failed to compile kernel: {}.", log.data()));
|
||||
}
|
||||
|
||||
// Get mangled names of kernel names.
|
||||
for (const auto& name : kernel_names) {
|
||||
const char* mangled;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
|
||||
ptx_kernels.emplace_back(name, mangled);
|
||||
}
|
||||
|
||||
// Get ptx data.
|
||||
size_t ptx_size;
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
|
||||
}
|
||||
ptx.resize(ptx_size, 0);
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
|
||||
}
|
||||
|
||||
// Load module.
|
||||
char jit_log[4089] = {};
|
||||
CUjit_option options[] = {
|
||||
CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
|
||||
void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
|
||||
CUresult jit_result = cuModuleLoadDataEx(
|
||||
&module_, ptx.data(), std::size(options), options, values);
|
||||
if (jit_result != CUDA_SUCCESS) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Failed to load compiled {} kernel: {}.", module_name, jit_log));
|
||||
}
|
||||
|
||||
// Load kernels.
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
CUfunction kernel;
|
||||
CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
||||
kernels_[name] = kernel;
|
||||
}
|
||||
}
|
||||
|
||||
JitModule::~JitModule() {
|
||||
CHECK_CU_ERROR(cuModuleUnload(module_));
|
||||
}
|
||||
|
||||
void JitModule::launch_kernel(
|
||||
CUstream stream,
|
||||
const std::string& kernel_name,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread) {
|
||||
CUfunction kernel = get_kernel(kernel_name);
|
||||
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
|
||||
int _, block_dim;
|
||||
CHECK_CU_ERROR(
|
||||
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
|
||||
if (block_dim > nthreads) {
|
||||
block_dim = nthreads;
|
||||
}
|
||||
Dims num_blocks{1, 1, 1};
|
||||
if (large) {
|
||||
num_blocks =
|
||||
get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread);
|
||||
std::get<0>(num_blocks) =
|
||||
(std::get<0>(num_blocks) + block_dim - 1) / block_dim;
|
||||
} else {
|
||||
std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim;
|
||||
}
|
||||
launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1});
|
||||
}
|
||||
|
||||
void JitModule::launch_kernel(
|
||||
CUstream stream,
|
||||
CUfunction kernel,
|
||||
Dims num_blocks,
|
||||
Dims block_dims) {
|
||||
CHECK_CU_ERROR(cuLaunchKernel(
|
||||
kernel,
|
||||
std::get<0>(num_blocks),
|
||||
std::get<1>(num_blocks),
|
||||
std::get<2>(num_blocks),
|
||||
std::get<0>(block_dims),
|
||||
std::get<1>(block_dims),
|
||||
std::get<2>(block_dims),
|
||||
0,
|
||||
stream,
|
||||
args_.data(),
|
||||
nullptr));
|
||||
args_.clear();
|
||||
storage_.clear();
|
||||
}
|
||||
|
||||
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
|
||||
auto it = kernels_.find(kernel_name);
|
||||
if (it == kernels_.end()) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("There is no kernel named {}.", kernel_name));
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void JitModule::append_ptr_arg(const void* v) {
|
||||
args_.push_back(const_cast<void*>(v));
|
||||
}
|
||||
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
const KernelBuilder& builder) {
|
||||
static std::unordered_map<std::string, JitModule> map;
|
||||
auto it = map.find(name);
|
||||
if (it == map.end()) {
|
||||
it = map.try_emplace(name, cu::device(device), name, builder).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
113
mlx/backend/cuda/jit_module.h
Normal file
113
mlx/backend/cuda/jit_module.h
Normal file
@ -0,0 +1,113 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
|
||||
#include <deque>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
|
||||
using KernelBuilderResult = std::pair<
|
||||
/* source code */ std::string,
|
||||
/* kernel names */ std::vector<std::string>>;
|
||||
using KernelBuilder = std::function<KernelBuilderResult()>;
|
||||
|
||||
class JitModule {
|
||||
public:
|
||||
JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder);
|
||||
~JitModule();
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
|
||||
void append_arg(const array& a) {
|
||||
append_arg(reinterpret_cast<CUdeviceptr>(a.data<void>()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append_arg(T val) {
|
||||
storage_.emplace_back(val);
|
||||
append_ptr_arg(&storage_.back());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append_arg(std::vector<T> vec) {
|
||||
if (vec.empty()) {
|
||||
// The nullptr can not be used as arg, pass something not null.
|
||||
append_arg(std::monostate{});
|
||||
} else {
|
||||
append_ptr_arg(vec.data());
|
||||
storage_.emplace_back(std::move(vec));
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure the arg is copied to an array with size of NDIM.
|
||||
template <size_t NDIM = MAX_NDIM, typename T>
|
||||
void append_ndim_arg(const std::vector<T>& vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
}
|
||||
std::vector<T> copied(NDIM);
|
||||
std::copy(vec.begin(), vec.end(), copied.data());
|
||||
append_arg(std::move(copied));
|
||||
}
|
||||
|
||||
// Launch kernel with |kernel_name| that each thread works on
|
||||
// |work_per_thread| elements of |arr|.
|
||||
void launch_kernel(
|
||||
CUstream stream,
|
||||
const std::string& kernel_name,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1);
|
||||
|
||||
void launch_kernel(
|
||||
CUstream stream,
|
||||
CUfunction kernel,
|
||||
Dims num_blocks,
|
||||
Dims block_dims);
|
||||
|
||||
CUfunction get_kernel(const std::string& kernel_name);
|
||||
|
||||
private:
|
||||
void append_ptr_arg(const void* v);
|
||||
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, CUfunction> kernels_;
|
||||
std::vector<void*> args_;
|
||||
|
||||
// The cuLaunchKernel API requires passing pointers to arguments so store
|
||||
// temporary values untill kernel is launched.
|
||||
using Arg = std::variant<
|
||||
std::monostate,
|
||||
CUdeviceptr,
|
||||
int32_t,
|
||||
uint32_t,
|
||||
int64_t,
|
||||
std::vector<const void*>,
|
||||
std::vector<int32_t>,
|
||||
std::vector<int64_t>>;
|
||||
std::deque<Arg> storage_;
|
||||
};
|
||||
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
const KernelBuilder& builder);
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -1,13 +1,13 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file includes host-only utilies for writing CUDA kernels, the difference
|
||||
// from backend/cuda/kernels/utils.cuh is that the latter file only include
|
||||
// from backend/cuda/device/utils.cuh is that the latter file only include
|
||||
// device-only code.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
@ -47,6 +47,31 @@ namespace mlx::core {
|
||||
__VA_ARGS__; \
|
||||
}
|
||||
|
||||
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
|
||||
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
|
||||
{ \
|
||||
uint32_t _num_threads = NUM_THREADS; \
|
||||
if (_num_threads <= WARP_SIZE) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 2) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 4) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 8) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 16) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
}
|
||||
|
||||
// Maps CPU types to CUDA types.
|
||||
template <typename T>
|
||||
struct CTypeToCudaType {
|
||||
|
@ -1,104 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file must not include any host-only code, utilies that work under both
|
||||
// host and device can be put here.
|
||||
//
|
||||
// See more about the requirements at:
|
||||
// https://docs.nvidia.com/cuda/nvrtc/#language
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/array>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/tuple>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CUDA kernel utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// To pass shape/strides to kernels via constant memory, their size must be
|
||||
// known at compile time.
|
||||
#define MAX_NDIM 8
|
||||
|
||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
// Optimize when the ndim is known at compile time.
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
|
||||
IdxT loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
// Optimized version when ndim is larger than 4.
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
int ndim) {
|
||||
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
389
mlx/backend/cuda/layer_norm.cu
Normal file
389
mlx/backend/cuda/layer_norm.cu
Normal file
@ -0,0 +1,389 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
inline __device__ float3 plus_f3(const float3& a, const float3& b) {
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||
}
|
||||
|
||||
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
|
||||
template <typename T, int BLOCK_DIM>
|
||||
struct BlockBroadcastReduce {
|
||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||
|
||||
cg::thread_block& block;
|
||||
TempStorage& temp;
|
||||
|
||||
template <typename Op>
|
||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
T x = cg::reduce(warp, input, op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
}
|
||||
|
||||
__device__ T Sum(const T& input) {
|
||||
return Reduce(input, cg::plus<T>{}, T{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void layer_norm(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* b,
|
||||
T* out,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride,
|
||||
int64_t b_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum.
|
||||
float sum = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
|
||||
}
|
||||
sum = BlockReduceT{block, temp}.Sum(sum);
|
||||
|
||||
// Mean.
|
||||
float mean = sum / axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
normalizer += t * t;
|
||||
}
|
||||
}
|
||||
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T bn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
||||
}
|
||||
cub::StoreDirectBlocked(index, out, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void layer_norm_vjp(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* g,
|
||||
T* gx,
|
||||
T* gw,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;
|
||||
__shared__ union {
|
||||
typename BlockReduceF::TempStorage f;
|
||||
typename BlockReduceF3::TempStorage f3;
|
||||
} temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
g += grid.block_rank() * axis_size;
|
||||
gx += grid.block_rank() * axis_size;
|
||||
gw += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum.
|
||||
float sum = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
|
||||
}
|
||||
sum = BlockReduceF{block, temp.f}.Sum(sum);
|
||||
|
||||
// Mean.
|
||||
float mean = sum / axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float3 factors = {};
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f3(factors, {wg, wg * t, t * t});
|
||||
}
|
||||
}
|
||||
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
|
||||
float meanwg = factors.x / axis_size;
|
||||
float meanwgxc = factors.y / axis_size;
|
||||
float normalizer2 = 1 / (factors.z / axis_size + eps);
|
||||
float normalizer = sqrt(normalizer2);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;
|
||||
if constexpr (HAS_W) {
|
||||
wn[i] = gi * xi;
|
||||
}
|
||||
}
|
||||
cub::StoreDirectBlocked(index, gx, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
cub::StoreDirectBlocked(index, gw, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool LayerNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||
void LayerNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("LayerNorm::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride,
|
||||
b_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void LayerNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("LayerNormVJP::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[3].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
auto [g, g_copied] = check_input(inputs[3]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
array& gb = outputs[2];
|
||||
|
||||
// Check whether we had a weight.
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs.
|
||||
bool g_in_gx = false;
|
||||
if (donate_x) {
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (donate_g) {
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
}
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b.
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (has_w) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
159
mlx/backend/cuda/logsumexp.cu
Normal file
159
mlx/backend/cuda/logsumexp.cu
Normal file
@ -0,0 +1,159 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||
return __expf(x);
|
||||
}
|
||||
|
||||
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void logsumexp(const T* in, T* out, int axis_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
in += grid.block_rank() * axis_size;
|
||||
|
||||
cg::greater<AccT> max_op;
|
||||
cg::plus<AccT> plus_op;
|
||||
|
||||
// Thread reduce.
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min();
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
AccT vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM + block.thread_rank(),
|
||||
make_cast_iterator<AccT>(in),
|
||||
vals,
|
||||
axis_size,
|
||||
Limits<AccT>::min());
|
||||
prevmax = maxval;
|
||||
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||
// Online normalizer calculation for softmax:
|
||||
// https://github.com/NVIDIA/online-softmax
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
|
||||
// First warp reduce.
|
||||
prevmax = maxval;
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
|
||||
__shared__ AccT local_max[WARP_SIZE];
|
||||
__shared__ AccT local_normalizer[WARP_SIZE];
|
||||
|
||||
// Write to shared memory and do second warp reduce.
|
||||
prevmax = maxval;
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_max[warp.meta_group_rank()] = maxval;
|
||||
}
|
||||
block.sync();
|
||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_max[warp.thread_rank()]
|
||||
: Limits<AccT>::finite_min();
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_normalizer[warp.meta_group_rank()] = normalizer;
|
||||
}
|
||||
block.sync();
|
||||
normalizer = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_normalizer[warp.thread_rank()]
|
||||
: AccT{};
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
|
||||
// Write output.
|
||||
if (block.thread_rank() == 0) {
|
||||
out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("LogSumExp::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
auto strides = in.strides();
|
||||
for (auto& s : strides) {
|
||||
s /= n;
|
||||
}
|
||||
bool col_contig = strides[0] == 1;
|
||||
for (int i = 1; col_contig && i < strides.size(); ++i) {
|
||||
col_contig &=
|
||||
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
allocator::malloc(in.nbytes() / n),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
}
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,9 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/arange.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/arange.cuh"
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
@ -72,32 +72,21 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
}
|
||||
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(ArgReduce)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Gather)
|
||||
NO_GPU(GatherAxis)
|
||||
NO_GPU(GatherMM)
|
||||
NO_GPU(GatherQMM)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Load)
|
||||
NO_GPU(LogSumExp)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
NO_GPU(Select)
|
||||
NO_GPU(SliceUpdate)
|
||||
NO_GPU(Softmax)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
@ -105,10 +94,6 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_USE_FALLBACK(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_USE_FALLBACK(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
|
82
mlx/backend/cuda/reduce.cu
Normal file
82
mlx/backend/cuda/reduce.cu
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/fill.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Reduce::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
array in = inputs[0];
|
||||
|
||||
// Make sure no identity reductions trickle down here.
|
||||
assert(!axes_.empty());
|
||||
assert(out.size() != in.size());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// Fill out with init value.
|
||||
if (in.size() == 0) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
thrust::fill_n(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
out.data_size(),
|
||||
cu::ReduceInit<OP, InType>::value());
|
||||
});
|
||||
});
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Reduce.
|
||||
ReductionPlan plan = get_reduction_plan(in, axes_);
|
||||
|
||||
// If it is a general reduce then copy the input to a contiguous array and
|
||||
// recompute the plan.
|
||||
if (plan.type == GeneralReduce) {
|
||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, in_copy, CopyType::General, s);
|
||||
encoder.add_temporary(in_copy);
|
||||
in = in_copy;
|
||||
plan = get_reduction_plan(in, axes_);
|
||||
}
|
||||
|
||||
if ((plan.type == ContiguousAllReduce) ||
|
||||
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
|
||||
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
|
||||
row_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousStridedReduce ||
|
||||
plan.type == GeneralStridedReduce) {
|
||||
col_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
throw std::runtime_error("No plan reached in reduce.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
@ -0,0 +1,278 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct ColReduceArgs {
|
||||
// The size of the contiguous column reduction.
|
||||
size_t reduction_size;
|
||||
int64_t reduction_stride;
|
||||
|
||||
// Input shape and strides excluding the reduction axes.
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
int ndim;
|
||||
|
||||
// Input shape and strides of the reduction axes (including last dimension).
|
||||
Shape reduce_shape;
|
||||
Strides reduce_strides;
|
||||
int reduce_ndim;
|
||||
|
||||
// The number of column we are reducing. Namely prod(reduce_shape).
|
||||
size_t non_col_reductions;
|
||||
|
||||
ColReduceArgs(
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
assert(!plan.shape.empty());
|
||||
reduction_size = plan.shape.back();
|
||||
reduction_stride = plan.strides.back();
|
||||
|
||||
int64_t stride_back = 1;
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
||||
stride_back *= shape_vec.back();
|
||||
shape_vec.pop_back();
|
||||
strides_vec.pop_back();
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
|
||||
reduce_shape = const_param(plan.shape);
|
||||
reduce_strides = const_param(plan.strides);
|
||||
reduce_ndim = plan.shape.size();
|
||||
|
||||
non_col_reductions = 1;
|
||||
for (int i = 0; i < reduce_ndim - 1; i++) {
|
||||
non_col_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void col_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int column =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
if (column * N_READS >= args.reduction_stride) {
|
||||
return;
|
||||
}
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(
|
||||
block.thread_index().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
for (size_t r = block.thread_index().y;
|
||||
r < args.non_col_reductions * args.reduction_size;
|
||||
r += block.dim_threads().y) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.reduction_stride,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(
|
||||
block.dim_threads().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do block reduce when each column has more than 1 element to reduce.
|
||||
if (block.dim_threads().y > 1) {
|
||||
__shared__ U shared_vals[32 * 8 * N_READS];
|
||||
size_t col =
|
||||
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col * N_READS + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
if (block.thread_index().y == 0) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
|
||||
}
|
||||
for (int j = 1; j < block.dim_threads().y; j++) {
|
||||
col = j * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (block.thread_index().y == 0) {
|
||||
cub::StoreDirectBlocked(
|
||||
column,
|
||||
out + out_idx * args.reduction_stride,
|
||||
totals,
|
||||
args.reduction_stride);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BM,
|
||||
int BN,
|
||||
int N_READS = 4>
|
||||
__global__ void col_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
constexpr int n_warps = BN / N_READS;
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
int r = block.thread_rank() / n_warps;
|
||||
int column = block.thread_rank() % n_warps;
|
||||
int in_offset = grid.block_index().x * BN;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location() + in_offset),
|
||||
vals,
|
||||
args.reduction_stride - in_offset,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do warp reduce for each output.
|
||||
constexpr int n_outputs = BN / n_warps;
|
||||
static_assert(BM == 32 && n_outputs == N_READS);
|
||||
__shared__ U shared_vals[BM * BN];
|
||||
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (warp.thread_rank() == 0) {
|
||||
size_t out_offset = grid.block_index().x * BN;
|
||||
cub::StoreDirectBlocked(
|
||||
warp.meta_group_rank(),
|
||||
out + out_idx * args.reduction_stride + out_offset,
|
||||
totals,
|
||||
args.reduction_stride - out_offset);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const cu::ColReduceArgs& args) {
|
||||
auto out_shape = out.shape();
|
||||
auto out_strides = out.strides();
|
||||
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
||||
out_shape.pop_back();
|
||||
out_strides.pop_back();
|
||||
}
|
||||
return get_2d_grid_dims(out_shape, out_strides);
|
||||
}
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
cu::ColReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr int N_READS = 4;
|
||||
dim3 block_dims;
|
||||
dim3 num_blocks = output_grid_for_col_reduce(out, args);
|
||||
num_blocks.z = num_blocks.y;
|
||||
num_blocks.y = num_blocks.x;
|
||||
auto kernel =
|
||||
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
if (total < 32) {
|
||||
size_t stride_blocks =
|
||||
cuda::ceil_div(args.reduction_stride, N_READS);
|
||||
block_dims.x = std::min(stride_blocks, 32ul);
|
||||
block_dims.y = std::min(total, 8ul);
|
||||
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
|
||||
} else {
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
block_dims.x = BM * BN / N_READS;
|
||||
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
|
||||
kernel = cu::
|
||||
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
74
mlx/backend/cuda/reduce/reduce.cuh
Normal file
74
mlx/backend/cuda/reduce/reduce.cuh
Normal file
@ -0,0 +1,74 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Dispatch dynamic ndim to constexpr.
|
||||
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
|
||||
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
|
||||
if (ndim == 1) { \
|
||||
constexpr uint32_t NDIM = 1; \
|
||||
__VA_ARGS__; \
|
||||
} else if (ndim == 2) { \
|
||||
constexpr uint32_t NDIM = 2; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t NDIM = 5; \
|
||||
__VA_ARGS__; \
|
||||
}
|
||||
|
||||
// Dispatch reduce ops to constexpr.
|
||||
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \
|
||||
if (REDUCE == Reduce::ReduceType::And) { \
|
||||
using OP = cu::And; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Or) { \
|
||||
using OP = cu::Or; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Sum) { \
|
||||
using OP = cu::Sum; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Prod) { \
|
||||
using OP = cu::Prod; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Max) { \
|
||||
using OP = cu::Max; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Min) { \
|
||||
using OP = cu::Min; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::invalid_argument("Unknown reduce type."); \
|
||||
}
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
} // namespace mlx::core
|
144
mlx/backend/cuda/reduce/reduce_ops.cuh
Normal file
144
mlx/backend/cuda/reduce/reduce_ops.cuh
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Reduce ops.
|
||||
struct And {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a && b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Or {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a || b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Prod {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Min {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Max {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
// Traits to get the result type of reduce op.
|
||||
template <typename Op, typename T>
|
||||
struct ReduceResult;
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<And, T> {
|
||||
using type = bool;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Or, T> {
|
||||
using type = bool;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Sum, T> {
|
||||
using type = cuda::std::conditional_t<
|
||||
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
|
||||
int32_t,
|
||||
T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Prod, T> {
|
||||
using type = cuda::std::conditional_t<
|
||||
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
|
||||
int32_t,
|
||||
T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Min, T> {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Max, T> {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
// Traits to get the init value of reduce op.
|
||||
template <typename Op, typename T>
|
||||
struct ReduceInit;
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<And, T> {
|
||||
static constexpr __host__ __device__ bool value() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Or, T> {
|
||||
static constexpr __host__ __device__ bool value() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Sum, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{0, 0};
|
||||
} else {
|
||||
return typename ReduceResult<Sum, T>::type{0};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Prod, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{1, 1};
|
||||
} else {
|
||||
return typename ReduceResult<Prod, T>::type{1};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Min, T> {
|
||||
static constexpr __host__ __device__ T value() {
|
||||
return Limits<T>::max();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Max, T> {
|
||||
static constexpr __host__ __device__ T value() {
|
||||
return Limits<T>::min();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
250
mlx/backend/cuda/reduce/row_reduce.cu
Normal file
250
mlx/backend/cuda/reduce/row_reduce.cu
Normal file
@ -0,0 +1,250 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct RowReduceArgs {
|
||||
// The size of the row being reduced, i.e. the size of last dimension.
|
||||
int row_size;
|
||||
|
||||
// Input shape and strides excluding the reduction axes.
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
int ndim;
|
||||
|
||||
// Input shape and strides of the reduction axes excluding last dimension.
|
||||
Shape reduce_shape;
|
||||
Strides reduce_strides;
|
||||
int reduce_ndim;
|
||||
|
||||
// The number of rows we are reducing. Namely prod(reduce_shape).
|
||||
size_t non_row_reductions;
|
||||
|
||||
RowReduceArgs(
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
assert(!plan.shape.empty());
|
||||
row_size = plan.shape.back();
|
||||
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
|
||||
reduce_shape = const_param(plan.shape);
|
||||
reduce_strides = const_param(plan.strides);
|
||||
reduce_ndim = plan.shape.size() - 1;
|
||||
|
||||
non_row_reductions = 1;
|
||||
for (int i = 0; i < reduce_ndim; i++) {
|
||||
non_row_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
size_t out_idx = cg::this_grid().thread_rank();
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_small_warp(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
size_t out_idx = grid.thread_rank() / WARP_SIZE;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
|
||||
n += WARP_SIZE) {
|
||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
total_val = cg::reduce(warp, total_val, op);
|
||||
|
||||
if (warp.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BLOCK_DIM_X,
|
||||
int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
|
||||
r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM_X + block.thread_index().x,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
total_val = BlockReduceT(temp).Reduce(total_val, op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
cu::RowReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr size_t N_READS = 4;
|
||||
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims, num_blocks;
|
||||
auto kernel =
|
||||
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
if (args.row_size <= 64) {
|
||||
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
|
||||
(args.non_row_reductions <= 8)) {
|
||||
block_dims.x = std::min(out_dims.x, 1024u);
|
||||
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
|
||||
num_blocks.y = out_dims.y;
|
||||
} else {
|
||||
block_dims.x = WARP_SIZE;
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
kernel =
|
||||
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
|
||||
}
|
||||
} else {
|
||||
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
|
||||
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
|
||||
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
block_dims.x = BLOCK_DIM_X;
|
||||
kernel = cu::row_reduce_looped<
|
||||
InType,
|
||||
OutType,
|
||||
OP,
|
||||
NDIM,
|
||||
BLOCK_DIM_X,
|
||||
N_READS>;
|
||||
});
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
@ -0,0 +1,84 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <cub/device/device_reduce.cuh>
|
||||
#include <cub/device/device_segmented_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename... Args>
|
||||
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
struct MultiplyOp {
|
||||
int factor;
|
||||
__device__ int operator()(int i) {
|
||||
return i * factor;
|
||||
}
|
||||
};
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
auto in_iter = cu::make_cast_iterator<OutType>(
|
||||
thrust::device_pointer_cast(in.data<InType>()));
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
auto init = cu::ReduceInit<OP, InType>::value();
|
||||
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
cub_all_reduce(
|
||||
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
||||
} else if (plan.type == ContiguousReduce) {
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
||||
cub_segmented_reduce(
|
||||
encoder,
|
||||
in_iter,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
offsets,
|
||||
offsets + 1,
|
||||
OP(),
|
||||
init,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
343
mlx/backend/cuda/rms_norm.cu
Normal file
343
mlx/backend/cuda/rms_norm.cu
Normal file
@ -0,0 +1,343 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
inline __device__ float2 plus_f2(const float2& a, const float2& b) {
|
||||
return {a.x + b.x, a.y + b.y};
|
||||
}
|
||||
|
||||
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
|
||||
template <typename T, int BLOCK_DIM>
|
||||
struct BlockBroadcastReduce {
|
||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||
|
||||
cg::thread_block& block;
|
||||
TempStorage& temp;
|
||||
|
||||
template <typename Op>
|
||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
T x = cg::reduce(warp, input, op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
}
|
||||
|
||||
__device__ T Sum(const T& input) {
|
||||
return Reduce(input, cg::plus<T>{}, T{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm(
|
||||
const T* x,
|
||||
const T* w,
|
||||
T* out,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]);
|
||||
normalizer += t * t;
|
||||
}
|
||||
}
|
||||
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = static_cast<float>(xn[i]) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm);
|
||||
}
|
||||
cub::StoreDirectBlocked(index, out, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm_vjp(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* g,
|
||||
T* gx,
|
||||
T* gw,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
|
||||
__shared__ union {
|
||||
typename BlockReduceF::TempStorage f;
|
||||
typename BlockReduceF2::TempStorage f2;
|
||||
} temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
g += grid.block_rank() * axis_size;
|
||||
gx += grid.block_rank() * axis_size;
|
||||
gw += grid.block_rank() * axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float2 factors = {};
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float t = static_cast<float>(xn[i]);
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f2(factors, {wg * t, t * t});
|
||||
}
|
||||
}
|
||||
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
|
||||
float meangwx = factors.x / axis_size;
|
||||
float normalizer = rsqrt(factors.y / axis_size + eps);
|
||||
float normalizer3 = normalizer * normalizer * normalizer;
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = xn[i];
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
|
||||
if constexpr (HAS_W) {
|
||||
wn[i] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
}
|
||||
cub::StoreDirectBlocked(index, gx, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
cub::StoreDirectBlocked(index, gw, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool RMSNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||
void RMSNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("RMSNorm::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void RMSNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("RMSNormVJP::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[2].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
auto [g, g_copied] = check_input(inputs[2]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
|
||||
// Check whether we had a weight.
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs.
|
||||
bool g_in_gx = false;
|
||||
if (donate_x) {
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (donate_g) {
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
}
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (has_w) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
160
mlx/backend/cuda/softmax.cu
Normal file
160
mlx/backend/cuda/softmax.cu
Normal file
@ -0,0 +1,160 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||
return __expf(x);
|
||||
}
|
||||
|
||||
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void softmax(const T* in, T* out, int axis_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
in += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
cg::greater<AccT> max_op;
|
||||
cg::plus<AccT> plus_op;
|
||||
|
||||
// Thread reduce.
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min();
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
AccT vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM + block.thread_rank(),
|
||||
make_cast_iterator<AccT>(in),
|
||||
vals,
|
||||
axis_size,
|
||||
Limits<AccT>::finite_min());
|
||||
prevmax = maxval;
|
||||
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||
// Online normalizer calculation for softmax:
|
||||
// https://github.com/NVIDIA/online-softmax
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
|
||||
// First warp reduce.
|
||||
prevmax = maxval;
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
|
||||
__shared__ AccT local_max[WARP_SIZE];
|
||||
__shared__ AccT local_normalizer[WARP_SIZE];
|
||||
|
||||
// Write to shared memory and do second warp reduce.
|
||||
prevmax = maxval;
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_max[warp.meta_group_rank()] = maxval;
|
||||
}
|
||||
block.sync();
|
||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_max[warp.thread_rank()]
|
||||
: Limits<AccT>::finite_min();
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_normalizer[warp.meta_group_rank()] = normalizer;
|
||||
}
|
||||
block.sync();
|
||||
normalizer = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_normalizer[warp.thread_rank()]
|
||||
: AccT{};
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Write output.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(index, in, vals, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
|
||||
}
|
||||
cub::StoreDirectBlocked(index, out, vals, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Softmax::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& s = stream();
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
array in = set_output(inputs[0]);
|
||||
bool precise = in.dtype() != float32 && precise_;
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
|
||||
if (precise) {
|
||||
kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
|
||||
}
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
177
mlx/backend/cuda/ternary.cu
Normal file
177
mlx/backend/cuda/ternary.cu
Normal file
@ -0,0 +1,177 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename T, typename IdxT>
|
||||
__global__ void
|
||||
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[index], b[index], c[index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename T, typename IdxT, int NDIM>
|
||||
__global__ void ternary_g_nd(
|
||||
const bool* a,
|
||||
const T* b,
|
||||
const T* c,
|
||||
T* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
|
||||
index,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
c_strides.data());
|
||||
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename T, typename IdxT>
|
||||
__global__ void ternary_g(
|
||||
const bool* a,
|
||||
const T* b,
|
||||
const T* c,
|
||||
T* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
const __grid_constant__ Strides c_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
|
||||
index,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
c_strides.data(),
|
||||
ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void ternary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const Stream& s) {
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
const auto& c = inputs[2];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
|
||||
using DType = cuda_type_t<CTYPE>;
|
||||
|
||||
auto topt = get_ternary_op_type(a, b, c);
|
||||
if (topt == TernaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
auto& c_strides = strides[2];
|
||||
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides),
|
||||
const_param<NDIM>(c_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::ternary_g<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
const_param(c_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto& c = inputs[2];
|
||||
auto topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
ternary_op_gpu_inplace<Op>(inputs, out, s);
|
||||
}
|
||||
|
||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("select::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
ternary_op_gpu<cu::Select>(inputs, out, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -2,10 +2,10 @@
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernels/unary_ops.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
@ -23,4 +24,20 @@ void check_cuda_error(const char* name, cudaError_t err) {
|
||||
}
|
||||
}
|
||||
|
||||
const char* dtype_to_cuda_type(const Dtype& dtype) {
|
||||
if (dtype == float16) {
|
||||
return "__half";
|
||||
}
|
||||
if (dtype == bfloat16) {
|
||||
return "__nv_bfloat16";
|
||||
}
|
||||
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
||||
if (dtype == DTYPE) { \
|
||||
return #CPP_TYPE; \
|
||||
}
|
||||
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
|
||||
#undef SPECIALIZE_DtypeToString
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -12,6 +12,8 @@ namespace cu {
|
||||
class Device;
|
||||
}
|
||||
|
||||
struct Dtype;
|
||||
|
||||
// Cuda stream managed with RAII.
|
||||
class CudaStream {
|
||||
public:
|
||||
@ -35,4 +37,7 @@ void check_cuda_error(const char* name, cudaError_t err);
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
|
||||
// Convert Dtype to CUDA C++ types.
|
||||
const char* dtype_to_cuda_type(const Dtype& dtype);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
@ -170,6 +171,41 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
slice_gpu(in, out, start_indices_, strides_, stream());
|
||||
}
|
||||
|
||||
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& upd = inputs[1];
|
||||
|
||||
if (upd.size() == 0) {
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
auto [data_offset, out_strides] =
|
||||
prepare_slice(out, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
copy_gpu_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ upd.shape(),
|
||||
/* const Strides& i_strides = */ upd.strides(),
|
||||
/* const Strides& o_strides = */ out_strides,
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ data_offset,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
/* const Stream& s = */ stream());
|
||||
}
|
||||
|
||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
|
||||
eval(inputs, out);
|
||||
|
@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_transpose};
|
||||
return steel_matmul_regular(
|
||||
s,
|
||||
d,
|
||||
/* a = */ in_unfolded,
|
||||
/* b = */ wt_transpose,
|
||||
/* c = */ out,
|
||||
/* M = */ implicit_M,
|
||||
/* N = */ implicit_N,
|
||||
/* K = */ implicit_K,
|
||||
/* batch_size_out = */ groups,
|
||||
/* a_cols = */ implicit_K * groups,
|
||||
/* b_cols = */ implicit_K,
|
||||
/* out_cols = */ implicit_N * groups,
|
||||
/* a_transposed = */ false,
|
||||
/* b_transposed = */ true,
|
||||
/* batch_shape = */ {1},
|
||||
/* batch_strides = */ {0},
|
||||
/* A_batch_strides = */ size_t(implicit_K),
|
||||
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
|
||||
/* matrix_stride_out = */ size_t(implicit_N),
|
||||
/*copies = */ copies);
|
||||
/* const Stream& s = */ s,
|
||||
/* Device& d = */ d,
|
||||
/* const array& a = */ in_unfolded,
|
||||
/* const array& b = */ wt_transpose,
|
||||
/* array& c = */ out,
|
||||
/* int M = */ implicit_M,
|
||||
/* int N = */ implicit_N,
|
||||
/* int K = */ implicit_K,
|
||||
/* int batch_size_out = */ groups,
|
||||
/* int lda = */ implicit_K * groups,
|
||||
/* int ldb = */ implicit_K,
|
||||
/* int ldd = */ implicit_N * groups,
|
||||
/* bool transpose_a = */ false,
|
||||
/* bool transpose_b = */ true,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ {1},
|
||||
/* Strides batch_strides = */ {0},
|
||||
/* int64_t A_batch_strides = */ int64_t(implicit_K),
|
||||
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
|
||||
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_gpu(
|
||||
|
@ -297,6 +297,9 @@ Device::Device() {
|
||||
device_ = load_device();
|
||||
default_library_ = load_default_library(device_);
|
||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||
int ag_tens = arch_[arch_.size() - 3] - '0';
|
||||
int ag_ones = arch_[arch_.size() - 2] - '0';
|
||||
arch_gen_ = ag_tens * 10 + ag_ones;
|
||||
auto arch = arch_.back();
|
||||
switch (arch) {
|
||||
case 'p': // phone
|
||||
|
@ -177,6 +177,10 @@ class Device {
|
||||
return arch_;
|
||||
}
|
||||
|
||||
int get_architecture_gen() const {
|
||||
return arch_gen_;
|
||||
}
|
||||
|
||||
void new_queue(int index);
|
||||
|
||||
MTL::CommandQueue* get_queue(Stream stream);
|
||||
@ -268,6 +272,7 @@ class Device {
|
||||
library_kernels_;
|
||||
const MTL::ResidencySet* residency_set_{nullptr};
|
||||
std::string arch_;
|
||||
int arch_gen_;
|
||||
int max_ops_per_buffer_;
|
||||
int max_mb_per_buffer_;
|
||||
};
|
||||
|
@ -33,8 +33,8 @@ template <
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
|
||||
const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -6,7 +6,34 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void steel_matmul_regular(
|
||||
template <bool CHECK_AB = true>
|
||||
void steel_matmul_regular_axpby(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldd,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies,
|
||||
Shape batch_shape,
|
||||
Strides batch_strides,
|
||||
int64_t A_batch_stride,
|
||||
int64_t B_batch_stride,
|
||||
int64_t matrix_stride_out,
|
||||
int64_t C_batch_stride = 0,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f);
|
||||
|
||||
inline void steel_matmul_regular(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
@ -21,14 +48,61 @@ void steel_matmul_regular(
|
||||
int ldd,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies,
|
||||
Shape batch_shape,
|
||||
Strides batch_strides,
|
||||
int64_t A_batch_stride,
|
||||
int64_t B_batch_stride,
|
||||
int64_t matrix_stride_out,
|
||||
std::vector<array>& copies);
|
||||
int64_t matrix_stride_out) {
|
||||
return steel_matmul_regular_axpby<false>(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* const array& c = */ b,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* int ldd = */ ldd,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ batch_shape,
|
||||
/* Strides batch_strides = */ batch_strides,
|
||||
/* int64_t A_batch_stride = */ A_batch_stride,
|
||||
/* int64_t B_batch_stride = */ B_batch_stride,
|
||||
/* int64_t matrix_stride_out = */ matrix_stride_out);
|
||||
}
|
||||
|
||||
void steel_matmul(
|
||||
template <bool CHECK_AB = true>
|
||||
void steel_matmul_axpby(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies,
|
||||
Shape batch_shape = {},
|
||||
Strides A_batch_stride = {},
|
||||
Strides B_batch_stride = {},
|
||||
Strides C_batch_stride = {},
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f);
|
||||
|
||||
inline void steel_matmul(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
@ -45,6 +119,26 @@ void steel_matmul(
|
||||
std::vector<array>& copies,
|
||||
Shape batch_shape = {},
|
||||
Strides A_batch_stride = {},
|
||||
Strides B_batch_stride = {});
|
||||
Strides B_batch_stride = {}) {
|
||||
return steel_matmul_axpby<false>(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* const array& c = */ b,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ batch_shape,
|
||||
/* Strides A_batch_stride = */ A_batch_stride,
|
||||
/* Strides B_batch_stride = */ B_batch_stride);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
|
@ -322,41 +322,6 @@ void DynamicSliceUpdate::eval_gpu(
|
||||
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
|
||||
}
|
||||
|
||||
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& upd = inputs[1];
|
||||
|
||||
if (upd.size() == 0) {
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
auto [data_offset, out_strides] =
|
||||
prepare_slice(out, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
copy_gpu_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ upd.shape(),
|
||||
/* const Strides& i_strides = */ upd.strides(),
|
||||
/* const Strides& o_strides = */ out_strides,
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ data_offset,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
/* const Stream& s = */ stream());
|
||||
}
|
||||
|
||||
void QRF::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
|
@ -225,6 +225,8 @@ struct MPIWrapper {
|
||||
return mpi_bfloat16_;
|
||||
case float64:
|
||||
return mpi_double_;
|
||||
default:
|
||||
throw std::runtime_error("Invalid type");
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user