[CUDA] Faster rms norm for small dimension (#2838)
@@ -22,26 +22,28 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
 }
 
 // Similar to cub::BlockReduce, but result is broadcasted to every thread.
-template <typename T, int BLOCK_DIM>
+template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
 struct BlockBroadcastReduce {
-  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
-  static_assert(BLOCK_DIM % WARP_SIZE == 0);
-  using TempStorage = T[BLOCK_DIM / WARP_SIZE];
+  using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
 
   cg::thread_block& block;
   TempStorage& temp;
 
   template <typename Op>
   __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
-    auto warp = cg::tiled_partition<WARP_SIZE>(block);
+    auto warp = cg::tiled_partition<GROUP_DIM>(block);
     T x = cg::reduce(warp, input, op);
-    if (warp.thread_rank() == 0) {
-      temp[warp.meta_group_rank()] = x;
+    if constexpr (BLOCK_DIM > GROUP_DIM) {
+      if (warp.thread_rank() == 0) {
+        temp[warp.meta_group_rank()] = x;
+      }
+      block.sync();
+      x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
+                                                      : init_value;
+      return cg::reduce(warp, x, op);
+    } else {
+      return x;
     }
-    block.sync();
-    x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
-                                                    : init_value;
-    return cg::reduce(warp, x, op);
   }
 
   __device__ T Sum(const T& input) {
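The change above generalizes `BlockBroadcastReduce`: the cooperative-groups tile width becomes a `GROUP_DIM` template parameter (defaulting to `WARP_SIZE`), `TempStorage` tolerates blocks narrower than a warp via `std::max(BLOCK_DIM / WARP_SIZE, 1)`, and when the whole block is a single group (`BLOCK_DIM == GROUP_DIM`) the `if constexpr` branch compiles away the shared-memory exchange and the `block.sync()` entirely. A minimal standalone sketch of a sub-warp tile reduction in the spirit of the new path (my own illustration, not part of the patch):

    #include <cooperative_groups.h>
    #include <cooperative_groups/reduce.h>
    namespace cg = cooperative_groups;

    // Each 8-thread tile sums its own 8 values with one shuffle-based
    // reduction: no shared memory, no block-wide synchronization.
    template <int GROUP_DIM>
    __global__ void tile_sum(const float* in, float* out, int n) {
      auto block = cg::this_thread_block();
      auto tile = cg::tiled_partition<GROUP_DIM>(block);
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      float v = i < n ? in[i] : 0.0f;
      // cg::reduce returns the result to every thread of the tile.
      float s = cg::reduce(tile, v, cg::plus<float>());
      if (tile.thread_rank() == 0) {
        out[i / GROUP_DIM] = s;
      }
    }

This is what lets the small-dimension kernels below shrink the reduction to 8 or 16 lanes and pack several rows into one block.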
@@ -49,6 +51,52 @@ struct BlockBroadcastReduce {
   }
 };
 
+template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
+__global__ void rms_norm_small(
+    const T* x,
+    const T* w,
+    T* out,
+    float eps,
+    uint32_t axis_size,
+    uint32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceT::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+  x += row * axis_size;
+  out += row * axis_size;
+
+  // Normalizer.
+  float normalizer = 0;
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float t = static_cast<float>(xn[i]);
+    normalizer += t * t;
+  }
+
+  normalizer = BlockReduceT{block, temp}.Sum(normalizer);
+  normalizer = rsqrt(normalizer / axis_size + eps);
+
+  // Outputs.
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float y = static_cast<float>(xn[i]) * normalizer;
+    xn[i] = wn[i] * static_cast<T>(y);
+  }
+  store_vector<N_READS>(out, index, xn, axis_size);
+}
+
 template <typename T, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm(
     const T* x,
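`rms_norm_small`, added above, assigns a two-dimensional block to the batch: `thread_index().x` walks one row in `N_READS`-wide vector loads, `thread_index().y` selects which of the block's rows to normalize, and `row = block_rank() * dim_threads().y + thread_index().y` spreads blocks over rows. The row stays in the `xn` registers between the reduction and the output pass, so each element is read from global memory once. Per row the math is the usual RMS norm; a scalar reference for one row (my own sketch, helper name hypothetical):

    #include <cmath>
    #include <cstddef>

    // What one row of rms_norm_small computes, written out serially:
    // out[i] = w[i] * x[i] / sqrt(mean(x^2) + eps).
    void rms_norm_row_ref(
        const float* x,
        const float* w,
        float* out,
        std::size_t axis_size,
        std::size_t w_stride,
        float eps) {
      float sum_sq = 0.0f;
      for (std::size_t i = 0; i < axis_size; ++i) {
        sum_sq += x[i] * x[i];
      }
      float normalizer = 1.0f / std::sqrt(sum_sq / axis_size + eps);
      for (std::size_t i = 0; i < axis_size; ++i) {
        out[i] = w[i * w_stride] * (x[i] * normalizer);
      }
    }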
@@ -94,6 +142,74 @@ __global__ void rms_norm(
   }
 }
 
+template <
+    typename T,
+    bool HAS_W,
+    int BLOCK_DIM,
+    int REDUCE_DIM,
+    int N_READS = 4>
+__global__ void rms_norm_vjp_small(
+    const T* x,
+    const T* w,
+    const T* g,
+    T* gx,
+    T* gw,
+    float eps,
+    int32_t axis_size,
+    int32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceF2::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+
+  x += row * axis_size;
+  g += row * axis_size;
+  gx += row * axis_size;
+  gw += row * axis_size;
+
+  // Normalizer.
+  float2 factors = {};
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+  auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+  for (int i = 0; i < N_READS; i++) {
+    float t = static_cast<float>(xn[i]);
+    float wi = wn[i];
+    float gi = gn[i];
+    float wg = wi * gi;
+    factors = plus_f2(factors, {wg * t, t * t});
+  }
+
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
+  float meangwx = factors.x / axis_size;
+  float normalizer = rsqrt(factors.y / axis_size + eps);
+  float normalizer3 = normalizer * normalizer * normalizer;
+
+  // Outputs.
+  for (int i = 0; i < N_READS; i++) {
+    float xi = xn[i];
+    float wi = wn[i];
+    float gi = gn[i];
+    xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
+    if constexpr (HAS_W) {
+      wn[i] = static_cast<T>(gi * xi * normalizer);
+    }
+  }
+  store_vector<N_READS>(gx, index, xn, axis_size);
+  if constexpr (HAS_W) {
+    store_vector<N_READS>(gw, index, wn, axis_size);
+  }
+}
+
 template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm_vjp(
     const T* x,
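`rms_norm_vjp_small` mirrors that layout for the backward pass and fuses both row statistics into a single `float2` reduction: `factors.x` accumulates the sum of w·g·x over the row and `factors.y` the sum of x². Writing the kernel's `normalizer` as r (so `normalizer3` is r³ and `meangwx` is the mean of w·g·x), the elementwise values it stores are

    \frac{\partial L}{\partial x_i} = r\,w_i g_i - x_i\,r^3\,\frac{1}{n}\sum_j w_j g_j x_j,
    \qquad
    \frac{\partial L}{\partial w_i} = g_i\,x_i\,r,
    \qquad
    r = \Big(\frac{1}{n}\sum_j x_j^2 + \epsilon\Big)^{-1/2},

with n = `axis_size`. When `HAS_W` is false the `gw` stores compile away; when true, each row's weight-gradient contribution lands in the host's `gw_temp` buffer, presumably summed across rows afterwards as in the existing non-small path.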
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
 
-  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
   using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
-  __shared__ union {
-    typename BlockReduceF::TempStorage f;
-    typename BlockReduceF2::TempStorage f2;
-  } temp;
+  __shared__ typename BlockReduceF2::TempStorage temp;
 
   x += grid.block_rank() * axis_size;
   g += grid.block_rank() * axis_size;
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
       factors = plus_f2(factors, {wg * t, t * t});
     }
   }
-  factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
   float meangwx = factors.x / axis_size;
   float normalizer = rsqrt(factors.y / axis_size + eps);
   float normalizer3 = normalizer * normalizer * normalizer;
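With the small-dimension backward kernel split out, `rms_norm_vjp` now performs only the `float2` reduction, so the `__shared__` union that let a `float` and a `float2` `TempStorage` alias the same bytes is no longer needed. The dropped idiom in isolation (illustrative fragment, not from the patch):

    // Two scratch buffers that are never live at the same time can share
    // one shared-memory allocation; the union occupies the size of the
    // larger member, not the sum.
    __global__ void union_scratch_demo() {
      __shared__ union {
        float f[32];
        float2 f2[32];
      } temp;
      temp.f2[threadIdx.x % 32] = make_float2(0.0f, 0.0f);
    }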
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
   return s.device == Device::cpu;
 }
 
+template <int n_per_thread, typename F>
+void dispatch_group_dim(int axis_size, F&& f) {
+  if (axis_size <= n_per_thread * 8) {
+    f(std::integral_constant<int, 8>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 16>());
+  } else if (axis_size <= n_per_thread * 16) {
+    f(std::integral_constant<int, 16>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 8>());
+  } else if (axis_size <= n_per_thread * 32) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 4>());
+  } else if (axis_size <= n_per_thread * 32 * 2) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 2>(),
+      std::integral_constant<int, 2>());
+  } else if (axis_size <= n_per_thread * 32 * 4) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 4>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 8) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 8>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 16) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 16>(),
+      std::integral_constant<int, 1>());
+  } else {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 32>(),
+      std::integral_constant<int, 1>());
+  }
+}
+
 // TODO: There are duplicate code with backend/metal/normalization.cpp
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
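`dispatch_group_dim` is the heart of the speedup. For a given per-thread read width it picks three compile-time constants `(group_dim, n_groups, groups_per_block)`: one row is covered by `group_dim * n_groups` threads along x (with `group_dim` becoming the kernel's `REDUCE_DIM` tile), and `groups_per_block` rows are stacked along y. The first five branches all yield 128-thread blocks, so a row of, say, 32 floats no longer occupies a mostly idle block by itself; only rows needing more than 128 lanes fall back to one row per block at 256, 512, or 1024 threads. A host-side probe of the table (my own sketch; only `dispatch_group_dim` itself is from the patch):

    #include <cstdio>

    // float32 => n_per_thread = N_READS = 16 / sizeof(float) = 4.
    void probe(int axis_size) {
      dispatch_group_dim<4>(
          axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
            constexpr int block_dim = n_groups() * group_dim();
            std::printf(
                "axis_size=%d -> block=(%d,%d), reduce tile=%d\n",
                axis_size, block_dim, int(groups_per_block()), int(group_dim()));
          });
    }

    // probe(24)  -> block=(8,16),  reduce tile=8   (16 rows per 128-thread block)
    // probe(500) -> block=(128,1), reduce tile=32  (one row per block)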
@@ -216,12 +365,33 @@ void RMSNorm::eval_gpu(
   dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     constexpr int N_READS = 16 / sizeof(DataType);
-    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
+    if (axis_size <= N_READS * 1024) {
+      dispatch_group_dim<N_READS>(
+          axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
+            constexpr int block_dim = n_groups() * group_dim();
+            auto kernel =
+                cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
+            auto n_blocks =
+                (n_rows + groups_per_block() - 1) / groups_per_block();
+            encoder.add_kernel_node(
+                kernel,
+                n_blocks,
+                {block_dim, groups_per_block()},
+                0,
+                gpu_ptr<DataType>(x),
+                gpu_ptr<DataType>(w),
+                gpu_ptr<DataType>(out),
+                eps_,
+                axis_size,
+                n_rows,
+                w_stride);
+          });
+    } else {
+      auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
       encoder.add_kernel_node(
           kernel,
           n_rows,
-          block_dim(),
+          1024,
           0,
           gpu_ptr<DataType>(x),
           gpu_ptr<DataType>(w),
@@ -229,7 +399,7 @@ void RMSNorm::eval_gpu(
           eps_,
           axis_size,
           w_stride);
-    });
+    }
   });
 }
 
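The forward dispatch now has two regimes: rows of at most `N_READS * 1024` elements go through `dispatch_group_dim` into `rms_norm_small` with a `{block_dim, groups_per_block}` block shape and `ceil(n_rows / groups_per_block)` blocks, while longer rows keep the original `cu::rms_norm` kernel at a fixed 1024 threads, one block per row. A worked instance of the arithmetic (mine, using the table above):

    // bfloat16 => N_READS = 16 / sizeof(DataType) = 8.
    // axis_size = 192, n_rows = 1000:
    //   192 <= 8 * 32 = 256       -> f(32, 1, 4): group_dim=32, groups_per_block=4
    //   block_dim = 1 * 32 = 32   -> launch shape {32, 4}, 128 threads per block
    //   n_blocks = (1000 + 3) / 4 -> 250 blocks, each normalizing 4 rows
    // axis_size = 10000 > 8 * 1024 = 8192
    //   -> falls back to cu::rms_norm: n_rows blocks of 1024 threads each.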
@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
     dispatch_bool(has_w, [&](auto has_w_constant) {
       using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       constexpr int N_READS = 16 / sizeof(DataType);
-      dispatch_block_dim(
-          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            auto kernel = cu::rms_norm_vjp<
-                DataType,
-                has_w_constant.value,
-                block_dim(),
-                N_READS>;
-            encoder.add_kernel_node(
-                kernel,
-                n_rows,
-                block_dim(),
-                0,
-                gpu_ptr<DataType>(x),
-                gpu_ptr<DataType>(w),
-                gpu_ptr<DataType>(g),
-                gpu_ptr<DataType>(gx),
-                gpu_ptr<DataType>(gw_temp),
-                eps_,
-                axis_size,
-                w_stride);
-          });
+      if (axis_size <= N_READS * 1024) {
+        dispatch_group_dim<N_READS>(
+            axis_size,
+            [&](auto group_dim, auto n_groups, auto groups_per_block) {
+              constexpr int block_dim = group_dim() * n_groups();
+              auto kernel = cu::rms_norm_vjp_small<
+                  DataType,
+                  has_w_constant.value,
+                  block_dim,
+                  group_dim(),
+                  N_READS>;
+              auto n_blocks =
+                  (n_rows + groups_per_block() - 1) / groups_per_block();
+              encoder.add_kernel_node(
+                  kernel,
+                  n_blocks,
+                  {block_dim, groups_per_block()},
+                  0,
+                  gpu_ptr<DataType>(x),
+                  gpu_ptr<DataType>(w),
+                  gpu_ptr<DataType>(g),
+                  gpu_ptr<DataType>(gx),
+                  gpu_ptr<DataType>(gw_temp),
+                  eps_,
+                  axis_size,
+                  n_rows,
+                  w_stride);
+            });
+      } else {
+        auto kernel =
+            cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
+        encoder.add_kernel_node(
+            kernel,
+            n_rows,
+            1024,
+            0,
+            gpu_ptr<DataType>(x),
+            gpu_ptr<DataType>(w),
+            gpu_ptr<DataType>(g),
+            gpu_ptr<DataType>(gx),
+            gpu_ptr<DataType>(gw_temp),
+            eps_,
+            axis_size,
+            w_stride);
+      }
     });
   });
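`RMSNormVJP::eval_gpu` applies the same split, with one extra compile-time axis: `dispatch_bool` lifts the runtime `has_w` flag into the `HAS_W` template parameter, so the weight-gradient stores vanish from the compiled kernel when they are not needed. The idiom in isolation (my reconstruction of the pattern; the real `dispatch_bool` lives elsewhere in the backend):

    #include <type_traits>

    // Turn a runtime bool into a compile-time constant the lambda can use
    // as a template argument.
    template <typename F>
    void dispatch_bool(bool b, F&& f) {
      if (b) {
        f(std::true_type{});
      } else {
        f(std::false_type{});
      }
    }

    // Usage: has_w_constant.value is constexpr inside the lambda, so it can
    // select the kernel instantiation:
    //   dispatch_bool(has_w, [&](auto has_w_constant) {
    //     auto kernel =
    //         cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
    //     ...
    //   });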