[CUDA] RMSNorm and VJP (#2280)

* rms norm start

* nit
This commit is contained in:
Awni Hannun 2025-06-12 17:09:49 -07:00 committed by GitHub
parent a4fc671d3e
commit 918761a25a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 345 additions and 4 deletions

View File

@ -30,6 +30,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu

View File

@ -244,8 +244,7 @@ void LayerNorm::eval_gpu(
}
};
array o = set_output(inputs[0]);
const array& x = o.data_shared_ptr() ? o : out;
const array x = set_output(inputs[0]);
const array& w = inputs[1];
const array& b = inputs[2];

View File

@ -99,8 +99,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)

View File

@ -0,0 +1,343 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
inline __device__ float2 plus_f2(const float2& a, const float2& b) {
return {a.x + b.x, a.y + b.y};
}
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM>
struct BlockBroadcastReduce {
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
static_assert(BLOCK_DIM % WARP_SIZE == 0);
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
cg::thread_block& block;
TempStorage& temp;
template <typename Op>
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
auto warp = cg::tiled_partition<WARP_SIZE>(block);
T x = cg::reduce(warp, input, op);
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
}
__device__ T Sum(const T& input) {
return Reduce(input, cg::plus<T>{}, T{});
}
};
template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm(
const T* x,
const T* w,
T* out,
float eps,
int32_t axis_size,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
__shared__ typename BlockReduceT::TempStorage temp;
x += grid.block_rank() * axis_size;
out += grid.block_rank() * axis_size;
// Normalizer.
float normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]);
normalizer += t * t;
}
}
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
normalizer = rsqrt(normalizer / axis_size + eps);
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; ++i) {
float norm = static_cast<float>(xn[i]) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm);
}
cub::StoreDirectBlocked(index, out, xn, axis_size);
}
}
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm_vjp(
const T* x,
const T* w,
const T* g,
T* gx,
T* gw,
float eps,
int32_t axis_size,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
__shared__ union {
typename BlockReduceF::TempStorage f;
typename BlockReduceF2::TempStorage f2;
} temp;
x += grid.block_rank() * axis_size;
g += grid.block_rank() * axis_size;
gx += grid.block_rank() * axis_size;
gw += grid.block_rank() * axis_size;
// Normalizer.
float2 factors = {};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]);
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f2(factors, {wg * t, t * t});
}
}
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
float meangwx = factors.x / axis_size;
float normalizer = rsqrt(factors.y / axis_size + eps);
float normalizer3 = normalizer * normalizer * normalizer;
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float xi = xn[i];
float wi = wn[i];
float gi = gn[i];
xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
if constexpr (HAS_W) {
wn[i] = static_cast<T>(gi * xi * normalizer);
}
}
cub::StoreDirectBlocked(index, gx, xn, axis_size);
if constexpr (HAS_W) {
cub::StoreDirectBlocked(index, gw, wn, axis_size);
}
}
}
} // namespace cu
namespace fast {
bool RMSNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
// TODO: There are duplicate code with backend/metal/normalization.cpp
void RMSNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("RMSNorm::eval_gpu");
auto& s = stream();
auto& out = outputs[0];
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
const array x = set_output(inputs[0]);
const array& w = inputs[1];
int32_t axis_size = x.shape().back();
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
x.data<DataType>(),
w.data<DataType>(),
out.data<DataType>(),
eps_,
axis_size,
w_stride);
});
});
});
}
void RMSNormVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("RMSNormVJP::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler.
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable();
auto [x, copied] = check_input(inputs[0]);
donate_x |= copied;
const array& w = inputs[1];
auto [g, g_copied] = check_input(inputs[2]);
donate_g |= g_copied;
array& gx = outputs[0];
array& gw = outputs[1];
// Check whether we had a weight.
bool has_w = w.ndim() != 0;
// Allocate space for the outputs.
bool g_in_gx = false;
if (donate_x) {
gx.copy_shared_buffer(x);
} else if (donate_g) {
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
}
int32_t axis_size = x.shape().back();
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
// Allocate a temporary to store the gradients for w and allocate the output
// gradient accumulators.
array gw_temp =
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
if (has_w) {
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
encoder.add_temporary(gw_temp);
}
}
gw.set_data(allocator::malloc(gw.nbytes()));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(g);
encoder.set_output_array(gx);
encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, {
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
eps_,
axis_size,
w_stride);
});
});
});
});
if (has_w) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
}
}
} // namespace fast
} // namespace mlx::core