mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
faster rms norm (#2433)
This commit is contained in:
@@ -10,8 +10,6 @@
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -74,9 +72,11 @@ __global__ void layer_norm(
|
||||
float sum = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
|
||||
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
sum += static_cast<float>(xn[i]);
|
||||
}
|
||||
}
|
||||
sum = BlockReduceT{block, temp}.Sum(sum);
|
||||
|
||||
@@ -87,11 +87,18 @@ __global__ void layer_norm(
|
||||
float normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
normalizer += t * t;
|
||||
if ((index + 1) * N_READS <= axis_size) {
|
||||
auto xn = load_vector<N_READS>(x, index);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
normalizer += t * t;
|
||||
}
|
||||
} else {
|
||||
for (int i = index * N_READS; i < axis_size; ++i) {
|
||||
float t = static_cast<float>(x[i]) - mean;
|
||||
normalizer += t * t;
|
||||
}
|
||||
}
|
||||
}
|
||||
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||
@@ -100,17 +107,15 @@ __global__ void layer_norm(
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T bn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||
cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
|
||||
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
|
||||
auto bn = load_vector<N_READS>(b, index, axis_size, b_stride, T(0));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
||||
}
|
||||
cub::StoreDirectBlocked(index, out, xn, axis_size);
|
||||
store_vector<N_READS>(out, index, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,9 +148,11 @@ __global__ void layer_norm_vjp(
|
||||
float sum = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
|
||||
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
sum += static_cast<float>(xn[i]);
|
||||
}
|
||||
}
|
||||
sum = BlockReduceF{block, temp.f}.Sum(sum);
|
||||
|
||||
@@ -155,19 +162,28 @@ __global__ void layer_norm_vjp(
|
||||
// Normalizer.
|
||||
float3 factors = {};
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f3(factors, {wg, wg * t, t * t});
|
||||
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
|
||||
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
|
||||
|
||||
if ((index + 1) * N_READS <= axis_size) {
|
||||
auto xn = load_vector<N_READS>(x, index);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f3(factors, {wg, wg * t, t * t});
|
||||
}
|
||||
} else {
|
||||
for (int i = index * N_READS; i < axis_size; ++i) {
|
||||
float t = static_cast<float>(x[i]) - mean;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f3(factors, {wg, wg * t, t * t});
|
||||
}
|
||||
}
|
||||
}
|
||||
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
|
||||
@@ -179,12 +195,10 @@ __global__ void layer_norm_vjp(
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
|
||||
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
|
||||
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
float wi = wn[i];
|
||||
@@ -194,9 +208,9 @@ __global__ void layer_norm_vjp(
|
||||
wn[i] = gi * xi;
|
||||
}
|
||||
}
|
||||
cub::StoreDirectBlocked(index, gx, xn, axis_size);
|
||||
store_vector<N_READS>(gx, index, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
cub::StoreDirectBlocked(index, gw, wn, axis_size);
|
||||
store_vector<N_READS>(gw, index, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -257,9 +271,9 @@ void LayerNorm::eval_gpu(
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
|
||||
constexpr uint32_t N_READS = 4;
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 16 / sizeof(DataType);
|
||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
@@ -364,10 +378,10 @@ void LayerNormVJP::eval_gpu(
|
||||
encoder.set_output_array(gw_temp);
|
||||
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
|
||||
dispatch_bool(has_w, [&](auto has_w_constant) {
|
||||
constexpr int N_READS = 4;
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 16 / sizeof(DataType);
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::layer_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant.value,
|
||||
|
||||
Reference in New Issue
Block a user