diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu
index 7fc4022de..a9fc79f19 100644
--- a/mlx/backend/cuda/rms_norm.cu
+++ b/mlx/backend/cuda/rms_norm.cu
@@ -22,26 +22,28 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
 }
 
 // Similar to cub::BlockReduce, but result is broadcasted to every thread.
-template <typename T, int BLOCK_DIM>
+template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
 struct BlockBroadcastReduce {
-  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
-  static_assert(BLOCK_DIM % WARP_SIZE == 0);
-  using TempStorage = T[BLOCK_DIM / WARP_SIZE];
+  using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
 
   cg::thread_block& block;
   TempStorage& temp;
 
   template <typename Op>
   __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
-    auto warp = cg::tiled_partition<WARP_SIZE>(block);
+    auto warp = cg::tiled_partition<GROUP_DIM>(block);
     T x = cg::reduce(warp, input, op);
-    if (warp.thread_rank() == 0) {
-      temp[warp.meta_group_rank()] = x;
+    if constexpr (BLOCK_DIM > GROUP_DIM) {
+      if (warp.thread_rank() == 0) {
+        temp[warp.meta_group_rank()] = x;
+      }
+      block.sync();
+      x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
+                                                      : init_value;
+      return cg::reduce(warp, x, op);
+    } else {
+      return x;
     }
-    block.sync();
-    x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
-                                                    : init_value;
-    return cg::reduce(warp, x, op);
   }
 
   __device__ T Sum(const T& input) {
@@ -49,6 +51,52 @@ struct BlockBroadcastReduce {
   }
 };
 
+template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
+__global__ void rms_norm_small(
+    const T* x,
+    const T* w,
+    T* out,
+    float eps,
+    uint32_t axis_size,
+    uint32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceT::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+  x += row * axis_size;
+  out += row * axis_size;
+
+  // Normalizer.
+  float normalizer = 0;
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float t = static_cast<float>(xn[i]);
+    normalizer += t * t;
+  }
+
+  normalizer = BlockReduceT{block, temp}.Sum(normalizer);
+  normalizer = rsqrt(normalizer / axis_size + eps);
+
+  // Outputs.
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float y = static_cast<float>(xn[i]) * normalizer;
+    xn[i] = wn[i] * static_cast<T>(y);
+  }
+  store_vector(out, index, xn, axis_size);
+}
+
 template <typename T, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm(
     const T* x,
@@ -94,6 +142,74 @@ __global__ void rms_norm(
   }
 }
 
+template <
+    typename T,
+    bool HAS_W,
+    int BLOCK_DIM,
+    int REDUCE_DIM,
+    int N_READS = 4>
+__global__ void rms_norm_vjp_small(
+    const T* x,
+    const T* w,
+    const T* g,
+    T* gx,
+    T* gw,
+    float eps,
+    int32_t axis_size,
+    int32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceF2::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+
+  x += row * axis_size;
+  g += row * axis_size;
+  gx += row * axis_size;
+  gw += row * axis_size;
+
+  // Normalizer.
+  float2 factors = {};
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+  auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+  for (int i = 0; i < N_READS; i++) {
+    float t = static_cast<float>(xn[i]);
+    float wi = wn[i];
+    float gi = gn[i];
+    float wg = wi * gi;
+    factors = plus_f2(factors, {wg * t, t * t});
+  }
+
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
+  float meangwx = factors.x / axis_size;
+  float normalizer = rsqrt(factors.y / axis_size + eps);
+  float normalizer3 = normalizer * normalizer * normalizer;
+
+  // Outputs.
+  for (int i = 0; i < N_READS; i++) {
+    float xi = xn[i];
+    float wi = wn[i];
+    float gi = gn[i];
+    xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
+    if constexpr (HAS_W) {
+      wn[i] = static_cast<T>(gi * xi * normalizer);
+    }
+  }
+  store_vector(gx, index, xn, axis_size);
+  if constexpr (HAS_W) {
+    store_vector(gw, index, wn, axis_size);
+  }
+}
+
 template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm_vjp(
     const T* x,
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
 
-  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
   using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
-  __shared__ union {
-    typename BlockReduceF::TempStorage f;
-    typename BlockReduceF2::TempStorage f2;
-  } temp;
+  __shared__ typename BlockReduceF2::TempStorage temp;
 
   x += grid.block_rank() * axis_size;
   g += grid.block_rank() * axis_size;
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
       factors = plus_f2(factors, {wg * t, t * t});
     }
   }
-  factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
   float meangwx = factors.x / axis_size;
   float normalizer = rsqrt(factors.y / axis_size + eps);
   float normalizer3 = normalizer * normalizer * normalizer;
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
   return s.device == Device::cpu;
 }
 
+template <int n_per_thread, typename F>
+void dispatch_group_dim(int axis_size, F&& f) {
+  if (axis_size <= n_per_thread * 8) {
+    f(std::integral_constant<int, 8>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 32>());
+  } else if (axis_size <= n_per_thread * 16) {
+    f(std::integral_constant<int, 16>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 16>());
+  } else if (axis_size <= n_per_thread * 32) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 8>());
+  } else if (axis_size <= n_per_thread * 32 * 2) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 2>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 4) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 4>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 8) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 8>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 16) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 16>(),
+      std::integral_constant<int, 1>());
+  } else {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 32>(),
+      std::integral_constant<int, 1>());
+  }
+}
+
 // TODO: There are duplicate code with backend/metal/normalization.cpp
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
@@ -216,12 +365,33 @@ void RMSNorm::eval_gpu(
   dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     constexpr int N_READS = 16 / sizeof(DataType);
-    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
+    if (axis_size <= N_READS * 1024) {
+      dispatch_group_dim<N_READS>(
+          axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
+            constexpr int block_dim = n_groups() * group_dim();
+            auto kernel =
+                cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
+            auto n_blocks =
+                (n_rows + groups_per_block() - 1) / groups_per_block();
+            encoder.add_kernel_node(
+                kernel,
+                n_blocks,
+                {block_dim, groups_per_block()},
+                0,
+                gpu_ptr<DataType>(x),
+                gpu_ptr<DataType>(w),
+                gpu_ptr<DataType>(out),
+                eps_,
+                axis_size,
+                n_rows,
+                w_stride);
+          });
+    } else {
+      auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
       encoder.add_kernel_node(
           kernel,
           n_rows,
-          block_dim(),
+          1024,
           0,
           gpu_ptr<DataType>(x),
           gpu_ptr<DataType>(w),
@@ -229,7 +399,7 @@ void RMSNorm::eval_gpu(
           eps_,
           axis_size,
           w_stride);
-    });
+    }
   });
 }
 
@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
     dispatch_bool(has_w, [&](auto has_w_constant) {
       using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       constexpr int N_READS = 16 / sizeof(DataType);
-      dispatch_block_dim(
-          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            auto kernel = cu::rms_norm_vjp<
-                DataType,
-                has_w_constant.value,
-                block_dim(),
-                N_READS>;
-            encoder.add_kernel_node(
-                kernel,
-                n_rows,
-                block_dim(),
-                0,
-                gpu_ptr<DataType>(x),
-                gpu_ptr<DataType>(w),
-                gpu_ptr<DataType>(g),
-                gpu_ptr<DataType>(gx),
-                gpu_ptr<DataType>(gw_temp),
-                eps_,
-                axis_size,
-                w_stride);
-          });
+      if (axis_size <= N_READS * 1024) {
+        dispatch_group_dim<N_READS>(
+            axis_size,
+            [&](auto group_dim, auto n_groups, auto groups_per_block) {
+              constexpr int block_dim = group_dim() * n_groups();
+              auto kernel = cu::rms_norm_vjp_small<
+                  DataType,
+                  has_w_constant.value,
+                  block_dim,
+                  group_dim(),
+                  N_READS>;
+              auto n_blocks =
+                  (n_rows + groups_per_block() - 1) / groups_per_block();
+              encoder.add_kernel_node(
+                  kernel,
+                  n_blocks,
+                  {block_dim, groups_per_block()},
+                  0,
+                  gpu_ptr<DataType>(x),
+                  gpu_ptr<DataType>(w),
+                  gpu_ptr<DataType>(g),
+                  gpu_ptr<DataType>(gx),
+                  gpu_ptr<DataType>(gw_temp),
+                  eps_,
+                  axis_size,
+                  n_rows,
+                  w_stride);
+            });
+      } else {
+        auto kernel =
+            cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
+        encoder.add_kernel_node(
+            kernel,
+            n_rows,
+            1024,
+            0,
+            gpu_ptr<DataType>(x),
+            gpu_ptr<DataType>(w),
+            gpu_ptr<DataType>(g),
+            gpu_ptr<DataType>(gx),
+            gpu_ptr<DataType>(gw_temp),
+            eps_,
+            axis_size,
+            w_stride);
+      }
     });
   });