diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index 852cf43af..9a9fbcb37 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -364,7 +364,7 @@ void LayerNormVJP::eval_gpu( using DataType = cuda_type_t; auto kernel = cu::layer_norm_vjp< DataType, - has_w_constant(), + has_w_constant.value, block_dim(), N_READS>; kernel<<>>( diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 4578dbad0..deb4a2f91 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -319,10 +319,10 @@ void row_reduce_looped( T, U, OP, - reduce_ndim(), - threads_constant(), + reduce_ndim.value, + threads_constant.value, N_READS>; - block.x = threads_constant(); + block.x = threads_constant.value; }); }); diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index 7f5f9630d..fc8f4f490 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -320,7 +320,7 @@ void RMSNormVJP::eval_gpu( constexpr int N_READS = 4; auto kernel = cu::rms_norm_vjp< DataType, - has_w_constant(), + has_w_constant.value, block_dim(), N_READS>; kernel<<>>( diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index a7d7b27ce..bb9618fc4 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -315,7 +315,8 @@ void RoPE::eval_gpu( dispatch_bool(forward_, [&](auto forward) { using DataType = cuda_type_t; if (single && !with_freqs) { - auto kernel = cu::rope_single; + auto kernel = + cu::rope_single; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); kernel<<>>( @@ -327,8 +328,8 @@ void RoPE::eval_gpu( mat_size, dims); } else if (single) { - auto kernel = - cu::rope_single_freqs; + auto kernel = cu:: + rope_single_freqs; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); kernel<<>>( @@ -341,7 +342,8 @@ void RoPE::eval_gpu( dims, inputs[2].strides(0)); } else if (with_freqs) { - auto kernel = cu::rope_freqs; + auto kernel = + cu::rope_freqs; uint3 dims = make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); dims.z = (dims.z + 3) / 4; @@ -359,7 +361,7 @@ void RoPE::eval_gpu( dims, inputs[2].strides(0)); } else { - auto kernel = cu::rope; + auto kernel = cu::rope; uint3 dims = make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); dims.z = (dims.z + 3) / 4;