diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 854c6c116..3e7b859a6 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -30,6 +30,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu
index 5aa287603..c71795fad 100644
--- a/mlx/backend/cuda/layer_norm.cu
+++ b/mlx/backend/cuda/layer_norm.cu
@@ -244,8 +244,7 @@ void LayerNorm::eval_gpu(
     }
   };
 
-  array o = set_output(inputs[0]);
-  const array& x = o.data_shared_ptr() ? o : out;
+  const array x = set_output(inputs[0]);
   const array& w = inputs[1];
   const array& b = inputs[2];
 
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index 5fef9e150..95ea44f94 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -99,8 +99,6 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace fast {
-NO_GPU_USE_FALLBACK(RMSNorm)
-NO_GPU_MULTI(RMSNormVJP)
 NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu
new file mode 100644
index 000000000..3c521b90d
--- /dev/null
+++ b/mlx/backend/cuda/rms_norm.cu
@@ -0,0 +1,343 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/reduce/reduce.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/fast_primitives.h"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_store.cuh>
+#include <nvtx3/nvtx3.hpp>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+inline __device__ float2 plus_f2(const float2& a, const float2& b) {
+  return {a.x + b.x, a.y + b.y};
+}
+
+// Similar to cub::BlockReduce, but the result is broadcast to every thread.
+template <typename T, int BLOCK_DIM>
+struct BlockBroadcastReduce {
+  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
+  static_assert(BLOCK_DIM % WARP_SIZE == 0);
+  using TempStorage = T[BLOCK_DIM / WARP_SIZE];
+
+  cg::thread_block& block;
+  TempStorage& temp;
+
+  template <typename Op>
+  __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
+    auto warp = cg::tiled_partition<WARP_SIZE>(block);
+    T x = cg::reduce(warp, input, op);
+    if (warp.thread_rank() == 0) {
+      temp[warp.meta_group_rank()] = x;
+    }
+    block.sync();
+    x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
+                                                    : init_value;
+    return cg::reduce(warp, x, op);
+  }
+
+  __device__ T Sum(const T& input) {
+    return Reduce(input, cg::plus<T>{}, T{});
+  }
+};
+
+template <typename T, int BLOCK_DIM, int N_READS = 4>
+__global__ void rms_norm(
+    const T* x,
+    const T* w,
+    T* out,
+    float eps,
+    int32_t axis_size,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
+  __shared__ typename BlockReduceT::TempStorage temp;
+
+  x += grid.block_rank() * axis_size;
+  out += grid.block_rank() * axis_size;
+
+  // Normalizer.
+  float normalizer = 0;
+  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
+    auto index = r * BLOCK_DIM + block.thread_rank();
+    T xn[N_READS];
+    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
+    for (int i = 0; i < N_READS; ++i) {
+      float t = static_cast<float>(xn[i]);
+      normalizer += t * t;
+    }
+  }
+  normalizer = BlockReduceT{block, temp}.Sum(normalizer);
+  normalizer = rsqrt(normalizer / axis_size + eps);
+
+  // Outputs.
+  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
+    auto index = r * BLOCK_DIM + block.thread_rank();
+    T xn[N_READS];
+    T wn[N_READS];
+    cub::LoadDirectBlocked(index, x, xn, axis_size);
+    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
+    for (int i = 0; i < N_READS; ++i) {
+      float norm = static_cast<float>(xn[i]) * normalizer;
+      xn[i] = wn[i] * static_cast<T>(norm);
+    }
+    cub::StoreDirectBlocked(index, out, xn, axis_size);
+  }
+}
+
+template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
+__global__ void rms_norm_vjp(
+    const T* x,
+    const T* w,
+    const T* g,
+    T* gx,
+    T* gw,
+    float eps,
+    int32_t axis_size,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
+  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
+  __shared__ union {
+    typename BlockReduceF::TempStorage f;
+    typename BlockReduceF2::TempStorage f2;
+  } temp;
+
+  x += grid.block_rank() * axis_size;
+  g += grid.block_rank() * axis_size;
+  gx += grid.block_rank() * axis_size;
+  gw += grid.block_rank() * axis_size;
+
+  // Normalizer.
+  float2 factors = {};
+  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
+    T xn[N_READS];
+    T wn[N_READS] = {};
+    T gn[N_READS] = {};
+    auto index = r * BLOCK_DIM + block.thread_rank();
+    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
+    cub::LoadDirectBlocked(index, g, gn, axis_size);
+    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
+    for (int i = 0; i < N_READS; i++) {
+      float t = static_cast<float>(xn[i]);
+      float wi = wn[i];
+      float gi = gn[i];
+      float wg = wi * gi;
+      factors = plus_f2(factors, {wg * t, t * t});
+    }
+  }
+  factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
+  float meangwx = factors.x / axis_size;
+  float normalizer = rsqrt(factors.y / axis_size + eps);
+  float normalizer3 = normalizer * normalizer * normalizer;
+
+  // Outputs.
+  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
+    auto index = r * BLOCK_DIM + block.thread_rank();
+    T xn[N_READS];
+    T wn[N_READS];
+    T gn[N_READS];
+    cub::LoadDirectBlocked(index, x, xn, axis_size);
+    cub::LoadDirectBlocked(index, g, gn, axis_size);
+    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
+    for (int i = 0; i < N_READS; i++) {
+      float xi = xn[i];
+      float wi = wn[i];
+      float gi = gn[i];
+      xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
+      if constexpr (HAS_W) {
+        wn[i] = static_cast<T>(gi * xi * normalizer);
+      }
+    }
+    cub::StoreDirectBlocked(index, gx, xn, axis_size);
+    if constexpr (HAS_W) {
+      cub::StoreDirectBlocked(index, gw, wn, axis_size);
+    }
+  }
+}
+
+} // namespace cu
+
+namespace fast {
+
+bool RMSNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
+// TODO: There is duplicate code with backend/metal/normalization.cpp.
+void RMSNorm::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("RMSNorm::eval_gpu");
+  auto& s = stream();
+  auto& out = outputs[0];
+
+  // Make sure that the last dimension is contiguous.
+  auto set_output = [&s, &out](const array& x) {
+    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
+    if (no_copy && x.ndim() > 1) {
+      auto s = x.strides()[x.ndim() - 2];
+      no_copy &= (s == 0 || s == x.shape().back());
+    }
+    if (no_copy) {
+      if (x.is_donatable()) {
+        out.copy_shared_buffer(x);
+      } else {
+        out.set_data(
+            allocator::malloc(x.data_size() * x.itemsize()),
+            x.data_size(),
+            x.strides(),
+            x.flags());
+      }
+      return x;
+    } else {
+      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
+      copy_gpu(x, x_copy, CopyType::General, s);
+      out.copy_shared_buffer(x_copy);
+      return x_copy;
+    }
+  };
+
+  const array x = set_output(inputs[0]);
+  const array& w = inputs[1];
+
+  int32_t axis_size = x.shape().back();
+  int32_t n_rows = x.data_size() / axis_size;
+  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(x);
+  encoder.set_input_array(w);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, {
+      using DataType = cuda_type_t<CTYPE>;
+      constexpr uint32_t N_READS = 4;
+      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
+        auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>;
+        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+            x.data<DataType>(),
+            w.data<DataType>(),
+            out.data<DataType>(),
+            eps_,
+            axis_size,
+            w_stride);
+      });
+    });
+  });
+}
+
+void RMSNormVJP::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("RMSNormVJP::eval_gpu");
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  // Ensure row contiguity. We could relax this step by checking that the array
+  // is contiguous (no broadcasts or holes) and that the input strides are the
+  // same as the cotangent strides, but for now this is simpler.
+  auto check_input = [&s](const array& x) -> std::pair<array, bool> {
+    if (x.flags().row_contiguous) {
+      return {x, false};
+    }
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
+    return {x_copy, true};
+  };
+  bool donate_x = inputs[0].is_donatable();
+  bool donate_g = inputs[2].is_donatable();
+  auto [x, copied] = check_input(inputs[0]);
+  donate_x |= copied;
+  const array& w = inputs[1];
+  auto [g, g_copied] = check_input(inputs[2]);
+  donate_g |= g_copied;
+  array& gx = outputs[0];
+  array& gw = outputs[1];
+
+  // Check whether we had a weight.
+  bool has_w = w.ndim() != 0;
+
+  // Allocate space for the outputs.
+  bool g_in_gx = false;
+  if (donate_x) {
+    gx.copy_shared_buffer(x);
+  } else if (donate_g) {
+    gx.copy_shared_buffer(g);
+    g_in_gx = true;
+  } else {
+    gx.set_data(allocator::malloc(gx.nbytes()));
+  }
+  if (g_copied && !g_in_gx) {
+    encoder.add_temporary(g);
+  }
+
+  int32_t axis_size = x.shape().back();
+  int32_t n_rows = x.data_size() / axis_size;
+  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
+
+  // Allocate a temporary to store the gradients for w and allocate the output
+  // gradient accumulators.
+  array gw_temp =
+      (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
+  if (has_w) {
+    if (!g_in_gx && donate_g) {
+      gw_temp.copy_shared_buffer(g);
+    } else {
+      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
+      encoder.add_temporary(gw_temp);
+    }
+  }
+  gw.set_data(allocator::malloc(gw.nbytes()));
+
+  encoder.set_input_array(x);
+  encoder.set_input_array(w);
+  encoder.set_input_array(g);
+  encoder.set_output_array(gx);
+  encoder.set_output_array(gw_temp);
+  encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
+    MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, {
+      using DataType = cuda_type_t<CTYPE>;
+      constexpr int N_READS = 4;
+      MLX_SWITCH_BOOL(has_w, HAS_W, {
+        MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
+          auto kernel = cu::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
+          kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+              x.data<DataType>(),
+              w.data<DataType>(),
+              g.data<DataType>(),
+              gx.data<DataType>(),
+              gw_temp.data<DataType>(),
+              eps_,
+              axis_size,
+              w_stride);
+        });
+      });
+    });
+  });
+
+  if (has_w) {
+    ReductionPlan plan(
+        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
+    col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
+  }
+}
+
+} // namespace fast
+
+} // namespace mlx::core
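
Note (not part of the diff): below is a minimal, standalone sketch of the block-broadcast reduction pattern that cu::BlockBroadcastReduce wraps above: each warp reduces its values with cg::reduce, warp leaders stash partial sums in shared memory, and a second warp-level reduce leaves the block-wide sum in every thread. The kernel name block_sum_demo, the constants kWarpSize/kBlockDim, and the host harness are hypothetical and exist only for this illustration; the cooperative-groups calls themselves are the same ones the new kernel uses.

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cstdio>

namespace cg = cooperative_groups;

constexpr int kWarpSize = 32;
constexpr int kBlockDim = 128;

__global__ void block_sum_demo(const float* in, float* out) {
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<kWarpSize>(block);
  __shared__ float partials[kBlockDim / kWarpSize];

  // Per-warp reduction; lane 0 of each warp writes its partial sum.
  float x = cg::reduce(warp, in[block.thread_rank()], cg::plus<float>{});
  if (warp.thread_rank() == 0) {
    partials[warp.meta_group_rank()] = x;
  }
  block.sync();

  // Every warp re-reduces the shared partials, so after this line every
  // thread in the block holds the full block sum (the "broadcast" part).
  x = warp.thread_rank() < warp.meta_group_size() ? partials[warp.thread_rank()]
                                                  : 0.f;
  x = cg::reduce(warp, x, cg::plus<float>{});
  if (block.thread_rank() == 0) {
    out[0] = x;
  }
}

int main() {
  float h_in[kBlockDim];
  for (int i = 0; i < kBlockDim; ++i) {
    h_in[i] = 1.f;  // expected block sum: 128
  }
  float *d_in, *d_out, h_out = 0.f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  block_sum_demo<<<1, kBlockDim>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("block sum = %f\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}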