mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
faster rms norm (#2433)
This commit is contained in:
@@ -31,8 +31,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
|
||||
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < n_per_thread; ++j) {
|
||||
sum += static_cast<float>(local_mat.val[j]) *
|
||||
static_cast<float>(local_vec.val[j]);
|
||||
sum +=
|
||||
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,8 +73,7 @@ __global__ void gemv_batched(
|
||||
}
|
||||
|
||||
// Reports whether the given matmul configuration is eligible for the gemv path.
//
// Returns true only when both conditions hold:
//   1. K is a multiple of 32 (note that K == 0 also satisfies this check).
//   2. Either M == 1 with b_transposed set, or N == 1 with a_transposed unset.
//
// Checking K against 64 or 128 separately is unnecessary: every multiple of
// 64 or 128 is already a multiple of 32, so the single test below covers them.
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
  return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
}
|
||||
|
||||
template <typename F>
|
||||
|
||||
Reference in New Issue
Block a user