mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
faster rms norm (#2433)
This commit is contained in:
@@ -10,8 +10,6 @@
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -57,7 +55,7 @@ __global__ void rms_norm(
|
||||
const T* w,
|
||||
T* out,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
uint32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
@@ -72,8 +70,8 @@ __global__ void rms_norm(
|
||||
float normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
|
||||
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]);
|
||||
normalizer += t * t;
|
||||
@@ -85,15 +83,14 @@ __global__ void rms_norm(
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = static_cast<float>(xn[i]) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm);
|
||||
float y = static_cast<float>(xn[i]) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(y);
|
||||
}
|
||||
cub::StoreDirectBlocked(index, out, xn, axis_size);
|
||||
store_vector<N_READS>(out, index, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,13 +122,10 @@ __global__ void rms_norm_vjp(
|
||||
// Normalizer.
|
||||
float2 factors = {};
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
|
||||
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float t = static_cast<float>(xn[i]);
|
||||
float wi = wn[i];
|
||||
@@ -148,12 +142,9 @@ __global__ void rms_norm_vjp(
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
|
||||
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = xn[i];
|
||||
float wi = wn[i];
|
||||
@@ -163,9 +154,9 @@ __global__ void rms_norm_vjp(
|
||||
wn[i] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
}
|
||||
cub::StoreDirectBlocked(index, gx, xn, axis_size);
|
||||
store_vector<N_READS>(gx, index, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
cub::StoreDirectBlocked(index, gw, wn, axis_size);
|
||||
store_vector<N_READS>(gw, index, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,9 +214,9 @@ void RMSNorm::eval_gpu(
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
|
||||
constexpr uint32_t N_READS = 4;
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 16 / sizeof(DataType);
|
||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
@@ -312,11 +303,10 @@ void RMSNormVJP::eval_gpu(
|
||||
encoder.set_output_array(gw_temp);
|
||||
dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
|
||||
dispatch_bool(has_w, [&](auto has_w_constant) {
|
||||
constexpr int N_READS = 4;
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 16 / sizeof(DataType);
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 4;
|
||||
auto kernel = cu::rms_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant.value,
|
||||
|
||||
Reference in New Issue
Block a user