mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
faster rms norm (#2433)
This commit is contained in:
@@ -22,7 +22,7 @@ __global__ void copy_s(const In* in, Out* out, IdxT size) {
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec.val[i] = cast_to<Out>(in[0]);
|
||||
out_vec[i] = cast_to<Out>(in[0]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
@@ -43,7 +43,7 @@ __global__ void copy_v(const In* in, Out* out, IdxT size) {
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec.val[i] = cast_to<Out>(in_vec.val[i]);
|
||||
out_vec[i] = cast_to<Out>(in_vec[i]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
@@ -65,8 +65,7 @@ void copy_contiguous(
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
// TODO: Choose optimized value based on type size.
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int N_READS = 16 / sizeof(InType);
|
||||
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
|
||||
|
||||
Reference in New Issue
Block a user