Mirror of https://github.com/ml-explore/mlx.git
Fix compilation with CUDA 11 (#2331)
@@ -74,7 +74,7 @@ __global__ void rms_norm(
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
     T xn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
+    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
     for (int i = 0; i < N_READS; ++i) {
       float t = static_cast<float>(xn[i]);
       normalizer += t * t;
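
Note on the change: both hunks make the same one-line fix, replacing the bare integer literal 0 passed to cub::LoadDirectBlocked as the out-of-bounds default with cast_to<T>(0), so the default value already has the element type T before cub sees it. With the cub that ships with CUDA 11, half-precision element types such as __half do not accept a plain int in that position, which is presumably what broke compilation. Below is a minimal, hypothetical sketch of a cast_to-style helper, for illustration only; mlx defines its own cast_to elsewhere in the tree and it may be implemented differently.

// Hypothetical sketch, not mlx's cast_to: convert an arithmetic source value
// to the element type T by routing through float, so that an int literal can
// become __half (or another 16-bit float type) without relying on an implicit
// int -> T conversion at the cub call site.
#include <cuda_fp16.h>

template <typename T, typename U>
__host__ __device__ inline T cast_to(U x) {
  return static_cast<T>(static_cast<float>(x));
}
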
@@ -130,7 +130,7 @@ __global__ void rms_norm_vjp(
     T wn[N_READS] = {};
     T gn[N_READS] = {};
     auto index = r * BLOCK_DIM + block.thread_rank();
-    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
+    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
     cub::LoadDirectBlocked(index, g, gn, axis_size);
     cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
     for (int i = 0; i < N_READS; i++) {
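
For context on the call being patched: cub::LoadDirectBlocked(linear_tid, block_itr, items, valid_items, oob_default) gives each thread N_READS consecutive elements of a blocked arrangement and fills the slots that fall past valid_items with oob_default. The following self-contained sketch shows that pattern with an element-typed default. It is illustrative only, not mlx's rms_norm kernel; the kernel name sum_squares and the launch shape are made up for the example.

// Illustrative sketch (not mlx's kernel): a padded blocked load with cub where
// the out-of-bounds default is constructed as T up front instead of passing a
// bare int 0, which is the pattern the commit switches to.
#include <cub/cub.cuh>
#include <cuda_fp16.h>

template <typename T, int BLOCK_DIM, int N_READS>
__global__ void sum_squares(const T* x, float* out, int axis_size) {
  float acc = 0.0f;
  int num_tiles = (axis_size + BLOCK_DIM * N_READS - 1) / (BLOCK_DIM * N_READS);
  const T* row = x + blockIdx.x * axis_size;  // one block per row
  for (int r = 0; r < num_tiles; ++r) {
    int index = r * BLOCK_DIM + threadIdx.x;
    T xn[N_READS];
    // Out-of-range elements of the last partial tile are filled with a T-typed
    // zero; building it here avoids an implicit int -> T conversion inside
    // cub, which is presumably what failed to compile with CUDA 11.
    cub::LoadDirectBlocked(index, row, xn, axis_size, static_cast<T>(0.0f));
    for (int i = 0; i < N_READS; ++i) {
      float t = static_cast<float>(xn[i]);
      acc += t * t;
    }
  }
  // Block-wide sum of the per-thread partial sums of squares.
  using BlockReduce = cub::BlockReduce<float, BLOCK_DIM>;
  __shared__ typename BlockReduce::TempStorage tmp;
  float total = BlockReduce(tmp).Sum(acc);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = total;
  }
}

A launch such as sum_squares<__half, 256, 4><<<n_rows, 256>>>(x, out, axis_size) computes one sum of squares per row; with a bare-int default in place of static_cast<T>(0.0f), the same instantiation is the shape of call the commit works around.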