faster rms norm (#2433)

This commit is contained in:
Awni Hannun
2025-07-29 13:12:00 -07:00
committed by GitHub
parent 970dbe8e25
commit ef631d63af
11 changed files with 210 additions and 112 deletions

View File

@@ -32,7 +32,7 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
AlignedVector<T, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -166,8 +166,7 @@ void ternary_op_gpu_inplace(
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
constexpr int N_READS = 16 / sizeof(DType);
auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
auto [num_blocks, block_dims] = get_launch_args(
kernel,