diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu
index a19e752bd..67a682c90 100644
--- a/mlx/backend/cuda/scaled_dot_product_attention.cu
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -2,6 +2,7 @@
 
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/config.h"
+#include "mlx/backend/cuda/device/utils.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/gpu/copy.h"
@@ -217,11 +218,11 @@ __global__ void kernel_sdpav_2pass_1(
   U k[v_per_thread];
   U o[v_per_thread];
 
-  __shared__ U outputs[BD][BN + 1];
+  __shared__ U outputs[BN][BD + 1];
   __shared__ U max_scores[BN];
   __shared__ U sum_exp_scores[BN];
 
-  const U scale_log2 = params.scale; // * 1.44269504089f;
+  const U scale_log2 = params.scale * 1.44269504089f;
 
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<32>(block);
@@ -230,7 +231,7 @@ __global__ void kernel_sdpav_2pass_1(
   const int warp_idx = warp.meta_group_rank();
 
   // Adjust to thread block and thread
-  const int batch_idx = 0; // blockIdx.z / blocks;
+  const int batch_idx = blockIdx.z / blocks;
   const int block_idx = blockIdx.z % blocks;
   const int head_idx = blockIdx.x;
   const int kv_head_idx = head_idx / params.gqa_factor;
@@ -302,8 +303,8 @@ __global__ void kernel_sdpav_2pass_1(
 
       // Update the accumulators
      U new_max = max(max_score, score);
-      U factor = expf(max_score - new_max);
-      U exp_score = expf(score - new_max);
+      U factor = exp2f(max_score - new_max);
+      U exp_score = exp2f(score - new_max);
 
       max_score = new_max;
       sum_exp_score = sum_exp_score * factor + exp_score;
@@ -330,7 +331,7 @@ __global__ void kernel_sdpav_2pass_1(
 
   max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
   U new_max = cg::reduce(warp, max_score, cg::greater<U>());
-  U factor = expf(max_score - new_max);
+  U factor = exp2f(max_score - new_max);
   sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
   sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
 
@@ -341,31 +342,30 @@ __global__ void kernel_sdpav_2pass_1(
   }
 
   // Now we need to aggregate all the outputs
+  auto ff = exp2f(max_scores[warp_idx] - new_max);
   PRAGMA_LOOP_UNROLL
   for (int i = 0; i < v_per_thread; i++) {
-    outputs[lane_idx][warp_idx] = o[i] * expf(max_scores[warp_idx] - new_max);
+    outputs[warp_idx][lane_idx] = o[i] * ff;
     block.sync();
 
     if (warp_idx == 0) {
-      U ot = outputs[lane_idx][0];
-
+      U ot = outputs[0][lane_idx];
       PRAGMA_LOOP_UNROLL
       for (int j = 1; j < BN; j++) {
-        ot += outputs[lane_idx][0];
+        ot += outputs[j][lane_idx];
+        warp.sync();
       }
-
-      // o[i] = ot;
-      partials[v_per_thread * lane_idx + i] = ot;
+      o[i] = ot;
     }
     block.sync();
   }
 
-  // if(warp_idx == 0) {
-  //   PRAGMA_LOOP_UNROLL
-  //   for (int i = 0; i < v_per_thread; i++) {
-  //     partials[v_per_thread * lane_idx + i] = o[i];
-  //   }
-  // }
+  if (warp_idx == 0) {
+    PRAGMA_LOOP_UNROLL
+    for (int i = 0; i < v_per_thread; i++) {
+      partials[v_per_thread * lane_idx + i] = o[i];
+    }
+  }
 }
 
 template
@@ -414,9 +414,9 @@ __global__ void kernel_sdpav_2pass_2(
 
   U max_score = maxs[lane_idx];
   U new_max = cg::reduce(warp, max_score, cg::greater<U>());
-  U factor = expf(max_score - new_max);
+  U factor = exp2f(max_score - new_max);
   U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
-  // sum_exp_score = __frcp_rn(sum_exp_score);
+  sum_exp_score = __frcp_rn(sum_exp_score);
 
   PRAGMA_LOOP_UNROLL
   for (int i = 0; i < v_per_thread; i++) {
@@ -429,7 +429,7 @@ __global__ void kernel_sdpav_2pass_2(
     outputs[lane_idx][warp_idx] = o[i];
     block.sync();
     U ot = outputs[warp_idx][lane_idx] * factor;
-    o[i] = cg::reduce(warp, ot, cg::plus<U>()) / sum_exp_score;
+    o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
     block.sync();
   }
 
@@ -504,6 +504,7 @@ void sdpa_vector_1pass_fallback(
       kernel,
       grid_dim,
       block_dim,
+      0,
       q.data<T>(),
       k.data<T>(),
       v.data<T>(),
@@ -585,6 +586,7 @@ void sdpa_vector_2pass_fallback(
       kernel,
       grid_dim,
       block_dim,
+      0,
       q.data<T>(),
       k.data<T>(),
       v.data<T>(),
@@ -610,6 +612,7 @@ void sdpa_vector_2pass_fallback(
       kernel,
       grid_dim,
       block_dim,
+      0,
       intermediate.data<U>(),
       sums.data<U>(),
       maxs.data<U>(),
@@ -632,7 +635,7 @@ void sdpa_vector_fallback(
     bool do_causal_ = false) {
   int kL = k.shape(2);
 
-  if (false && kL > 1024) {
+  if (kL > 1024) {
     return sdpa_vector_2pass_fallback(
         s, encoder, q, k, v, scale, o, do_causal_);
   } else {
@@ -1034,7 +1037,23 @@ void ScaledDotProductAttention::eval_gpu(
   if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
     o.copy_shared_buffer(q);
   } else {
-    o.set_data(allocator::malloc(o.nbytes()));
+    int64_t str_oD = 1;
+    int64_t str_oH = o.shape(3);
+    int64_t str_oL = o.shape(1) * str_oH;
+    int64_t str_oB = o.shape(2) * str_oL;
+    size_t data_size = o.shape(0) * str_oB;
+
+    array::Flags flags{
+        /* bool contiguous = */ 1,
+        /* bool row_contiguous = */ 0,
+        /* bool col_contiguous = */ 0,
+    };
+
+    o.set_data(
+        allocator::malloc(o.nbytes()),
+        data_size,
+        {str_oB, str_oH, str_oL, str_oD},
+        flags);
   }
 
   return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
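
Note (reviewer sketch, not part of the patch): the expf -> exp2f changes together with the new `params.scale * 1.44269504089f` rely on exp(s * x) == exp2(s * log2(e) * x), so folding log2(e) into the scale lets every exponential in the online-softmax update use the cheaper exp2f path; the running rescale factor exp2f(max_score - new_max) stays correct because the softmax numerator and denominator are rescaled by the same base-2 factor. Below is a minimal host-side sketch of that update; the struct and function names are hypothetical, while `score`, `max_score`, and `sum_exp_score` mirror the kernel's variables.

#include <cmath>

struct RunningSoftmax {
  float max_score = -1e9f;   // running maximum of the scaled scores
  float sum_exp_score = 0.f; // running sum of exp2(score - max_score)
};

// raw_score stands for one q.k dot product; scale_log2 plays the role of
// params.scale * 1.44269504089f in the patch.
inline void update(RunningSoftmax& st, float raw_score, float scale_log2) {
  float score = raw_score * scale_log2;
  float new_max = std::fmax(st.max_score, score);
  float factor = std::exp2(st.max_score - new_max); // rescales the old sum
  float exp_score = std::exp2(score - new_max);     // this key's contribution
  st.max_score = new_max;
  st.sum_exp_score = st.sum_exp_score * factor + exp_score;
}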