mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	[WIP] 2 pass sdpav
This commit is contained in:
		| @@ -4,6 +4,8 @@ | |||||||
| #include "mlx/fast_primitives.h" | #include "mlx/fast_primitives.h" | ||||||
| #include "mlx/primitives.h" | #include "mlx/primitives.h" | ||||||
|  |  | ||||||
|  | namespace mlx::core { | ||||||
|  |  | ||||||
| #define NO_GPU_MULTI(func)                                             \ | #define NO_GPU_MULTI(func)                                             \ | ||||||
|   void func::eval_gpu(                                                 \ |   void func::eval_gpu(                                                 \ | ||||||
|       const std::vector<array>& inputs, std::vector<array>& outputs) { \ |       const std::vector<array>& inputs, std::vector<array>& outputs) { \ | ||||||
|   | |||||||
| @@ -15,14 +15,632 @@ | |||||||
| #include <fmt/format.h> | #include <fmt/format.h> | ||||||
| #include <nvtx3/nvtx3.hpp> | #include <nvtx3/nvtx3.hpp> | ||||||
|  |  | ||||||
|  | #include <cooperative_groups.h> | ||||||
|  | #include <cooperative_groups/reduce.h> | ||||||
|  |  | ||||||
| namespace fe = cudnn_frontend; | namespace fe = cudnn_frontend; | ||||||
|  |  | ||||||
| namespace mlx::core { | namespace mlx::core { | ||||||
|  |  | ||||||
| namespace cu {} // namespace cu | namespace cu { | ||||||
|  |  | ||||||
|  | namespace cg = cooperative_groups; | ||||||
|  |  | ||||||
|  | #define PRAGMA_LOOP_UNROLL #pragma unroll | ||||||
|  |  | ||||||
|  | struct AttnParams { | ||||||
|  |   int B; | ||||||
|  |   int H; | ||||||
|  |   int D; | ||||||
|  |  | ||||||
|  |   int qL; | ||||||
|  |   int kL; | ||||||
|  |  | ||||||
|  |   int gqa_factor; | ||||||
|  |   float scale; | ||||||
|  |  | ||||||
|  |   int64_t Q_strides[3]; | ||||||
|  |   int64_t K_strides[3]; | ||||||
|  |   int64_t V_strides[3]; | ||||||
|  |   int64_t O_strides[3]; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template <typename T, bool do_causal, int D> | ||||||
|  | __global__ void kernel_sdpav_1pass( | ||||||
|  |     const T* Q, | ||||||
|  |     const T* K, | ||||||
|  |     const T* V, | ||||||
|  |     T* O, | ||||||
|  |     __grid_constant__ const AttnParams params) { | ||||||
|  |   constexpr int BN = 32; | ||||||
|  |   constexpr int BD = 32; | ||||||
|  |  | ||||||
|  |   constexpr int v_per_thread = D / BD; | ||||||
|  |  | ||||||
|  |   const int inner_k_stride = BN * int(params.K_strides[2]); | ||||||
|  |   const int inner_v_stride = BN * int(params.V_strides[2]); | ||||||
|  |  | ||||||
|  |   typedef float U; | ||||||
|  |  | ||||||
|  |   U q[v_per_thread]; | ||||||
|  |   U k[v_per_thread]; | ||||||
|  |   U o[v_per_thread]; | ||||||
|  |  | ||||||
|  |   __shared__ U outputs[BN][BD + 1]; | ||||||
|  |   __shared__ U max_scores[BN]; | ||||||
|  |   __shared__ U sum_exp_scores[BN]; | ||||||
|  |  | ||||||
|  |   const U scale_log2 = params.scale * 1.44269504089f; | ||||||
|  |  | ||||||
|  |   auto block = cg::this_thread_block(); | ||||||
|  |   auto warp = cg::tiled_partition<32>(block); | ||||||
|  |  | ||||||
|  |   const int lane_idx = warp.thread_rank(); | ||||||
|  |   const int warp_idx = warp.meta_group_rank(); | ||||||
|  |  | ||||||
|  |   // Adjust to thread block and thread | ||||||
|  |   const int batch_idx = blockIdx.z; | ||||||
|  |   const int head_idx = blockIdx.x; | ||||||
|  |   const int kv_head_idx = head_idx / params.gqa_factor; | ||||||
|  |  | ||||||
|  |   const int q_seq_idx = blockIdx.y; | ||||||
|  |   const int kv_seq_idx = warp_idx; | ||||||
|  |  | ||||||
|  |   Q += batch_idx * params.Q_strides[0] + // Batch | ||||||
|  |       head_idx * params.Q_strides[1] + // Head | ||||||
|  |       q_seq_idx * params.Q_strides[2]; // Sequence | ||||||
|  |  | ||||||
|  |   K += batch_idx * params.K_strides[0] + // Batch | ||||||
|  |       kv_head_idx * params.K_strides[1] + // Head | ||||||
|  |       kv_seq_idx * params.K_strides[2]; // Sequence | ||||||
|  |  | ||||||
|  |   V += batch_idx * params.V_strides[0] + // Batch | ||||||
|  |       kv_head_idx * params.V_strides[1] + // Head | ||||||
|  |       kv_seq_idx * params.V_strides[2]; // Sequence | ||||||
|  |  | ||||||
|  |   O += batch_idx * params.O_strides[0] + // Batch | ||||||
|  |       head_idx * params.O_strides[1] + // Head | ||||||
|  |       q_seq_idx * params.O_strides[2]; // Sequence | ||||||
|  |  | ||||||
|  |   // Read the query and 0 the output accumulator | ||||||
|  |   PRAGMA_LOOP_UNROLL | ||||||
|  |   for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |     q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   PRAGMA_LOOP_UNROLL | ||||||
|  |   for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |     o[i] = 0.f; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   U max_score = -INFINITY; | ||||||
|  |   U sum_exp_score = 0.f; | ||||||
|  |  | ||||||
|  |   // For each key | ||||||
|  |   for (int i = kv_seq_idx; i < params.kL; i += BN) { | ||||||
|  |     bool use_key = true; | ||||||
|  |     if constexpr (do_causal) { | ||||||
|  |       use_key = i <= (params.kL - params.qL + q_seq_idx); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (use_key) { | ||||||
|  |       // Read the key | ||||||
|  |       PRAGMA_LOOP_UNROLL | ||||||
|  |       for (int j = 0; j < v_per_thread; j++) { | ||||||
|  |         k[j] = K[v_per_thread * lane_idx + j]; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // Compute the i-th score | ||||||
|  |       U score = 0.f; | ||||||
|  |       PRAGMA_LOOP_UNROLL | ||||||
|  |       for (int j = 0; j < v_per_thread; j++) { | ||||||
|  |         score += q[j] * k[j]; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // Warp sum | ||||||
|  |       score = cg::reduce(warp, score, cg::plus<U>()); | ||||||
|  |  | ||||||
|  |       // Update the accumulators | ||||||
|  |       U new_max = max(max_score, score); | ||||||
|  |       U factor = exp2f(max_score - new_max); | ||||||
|  |       U exp_score = exp2f(score - new_max); | ||||||
|  |  | ||||||
|  |       max_score = new_max; | ||||||
|  |       sum_exp_score = sum_exp_score * factor + exp_score; | ||||||
|  |  | ||||||
|  |       // Update the output accumulator | ||||||
|  |       PRAGMA_LOOP_UNROLL | ||||||
|  |       for (int j = 0; j < v_per_thread; j++) { | ||||||
|  |         o[j] = o[j] * factor + | ||||||
|  |             exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // Move the pointers to the next kv | ||||||
|  |     K += inner_k_stride; | ||||||
|  |     V += inner_v_stride; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (lane_idx == 0) { | ||||||
|  |     max_scores[warp_idx] = max_score; | ||||||
|  |     sum_exp_scores[warp_idx] = sum_exp_score; | ||||||
|  |   } | ||||||
|  |   block.sync(); | ||||||
|  |  | ||||||
|  |   max_score = max_scores[lane_idx]; | ||||||
|  |   U new_max = cg::reduce(warp, max_score, cg::greater<U>()); | ||||||
|  |   U factor = exp2f(max_score - new_max); | ||||||
|  |   sum_exp_score = | ||||||
|  |       cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>()); | ||||||
|  |   sum_exp_score = __frcp_rn(sum_exp_score); | ||||||
|  |  | ||||||
|  |   // Now we need to aggregate all the outputs | ||||||
|  |   PRAGMA_LOOP_UNROLL | ||||||
|  |   for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |     outputs[lane_idx][warp_idx] = o[i]; | ||||||
|  |     block.sync(); | ||||||
|  |     U ot = outputs[warp_idx][lane_idx] * factor; | ||||||
|  |     o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score; | ||||||
|  |     block.sync(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // And write the output | ||||||
|  |   if (lane_idx == 0) { | ||||||
|  |     PRAGMA_LOOP_UNROLL | ||||||
|  |     for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |       O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T, bool do_causal, int D> | ||||||
|  | __global__ void kernel_sdpav_2pass_1( | ||||||
|  |     const T* Q, | ||||||
|  |     const T* K, | ||||||
|  |     const T* V, | ||||||
|  |     float* partials, | ||||||
|  |     float* sums, | ||||||
|  |     float* maxs, | ||||||
|  |     __grid_constant__ const AttnParams params) { | ||||||
|  |   constexpr int BN = 8; | ||||||
|  |   constexpr int BD = 32; | ||||||
|  |   constexpr int blocks = 32; | ||||||
|  |  | ||||||
|  |   constexpr int v_per_thread = D / BD; | ||||||
|  |  | ||||||
|  |   const int inner_k_stride = blocks * BN * int(params.K_strides[2]); | ||||||
|  |   const int inner_v_stride = blocks * BN * int(params.V_strides[2]); | ||||||
|  |  | ||||||
|  |   typedef float U; | ||||||
|  |  | ||||||
|  |   U q[v_per_thread]; | ||||||
|  |   U k[v_per_thread]; | ||||||
|  |   U o[v_per_thread]; | ||||||
|  |  | ||||||
|  |   __shared__ U outputs[BD][BN + 1]; | ||||||
|  |   __shared__ U max_scores[BN]; | ||||||
|  |   __shared__ U sum_exp_scores[BN]; | ||||||
|  |  | ||||||
|  |   const U scale_log2 = params.scale; // * 1.44269504089f; | ||||||
|  |  | ||||||
|  |   auto block = cg::this_thread_block(); | ||||||
|  |   auto warp = cg::tiled_partition<32>(block); | ||||||
|  |  | ||||||
|  |   const int lane_idx = warp.thread_rank(); | ||||||
|  |   const int warp_idx = warp.meta_group_rank(); | ||||||
|  |  | ||||||
|  |   // Adjust to thread block and thread | ||||||
|  |   const int batch_idx = 0; // blockIdx.z / blocks; | ||||||
|  |   const int block_idx = blockIdx.z % blocks; | ||||||
|  |   const int head_idx = blockIdx.x; | ||||||
|  |   const int kv_head_idx = head_idx / params.gqa_factor; | ||||||
|  |  | ||||||
|  |   const int q_seq_idx = blockIdx.y; | ||||||
|  |   const int kv_seq_idx = block_idx * BN + warp_idx; | ||||||
|  |  | ||||||
|  |   Q += batch_idx * params.Q_strides[0] + // Batch | ||||||
|  |       head_idx * params.Q_strides[1] + // Head | ||||||
|  |       q_seq_idx * params.Q_strides[2]; // Sequence | ||||||
|  |  | ||||||
|  |   K += batch_idx * params.K_strides[0] + // Batch | ||||||
|  |       kv_head_idx * params.K_strides[1] + // Head | ||||||
|  |       kv_seq_idx * params.K_strides[2]; // Sequence | ||||||
|  |  | ||||||
|  |   V += batch_idx * params.V_strides[0] + // Batch | ||||||
|  |       kv_head_idx * params.V_strides[1] + // Head | ||||||
|  |       kv_seq_idx * params.V_strides[2]; // Sequence | ||||||
|  |  | ||||||
|  |   const int p_stride_s = blocks; | ||||||
|  |   const int p_stride_h = params.qL * p_stride_s; | ||||||
|  |   const int p_stride_b = params.H * p_stride_h; | ||||||
|  |   const int p_offset = batch_idx * p_stride_b + // Batch | ||||||
|  |       head_idx * p_stride_h + // Head | ||||||
|  |       q_seq_idx * p_stride_s + // Sequence | ||||||
|  |       block_idx; // Block | ||||||
|  |  | ||||||
|  |   partials += p_offset * D; | ||||||
|  |   sums += p_offset; | ||||||
|  |   maxs += p_offset; | ||||||
|  |  | ||||||
|  |   // Read the query and 0 the output accumulator | ||||||
|  |   PRAGMA_LOOP_UNROLL | ||||||
|  |   for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |     q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   PRAGMA_LOOP_UNROLL | ||||||
|  |   for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |     o[i] = 0.f; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   U max_score = -1e9; | ||||||
|  |   U sum_exp_score = 0.f; | ||||||
|  |  | ||||||
|  |   // For each key | ||||||
|  |   for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) { | ||||||
|  |     bool use_key = true; | ||||||
|  |     if constexpr (do_causal) { | ||||||
|  |       use_key = i <= (params.kL - params.qL + q_seq_idx); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (use_key) { | ||||||
|  |       // Read the key | ||||||
|  |       PRAGMA_LOOP_UNROLL | ||||||
|  |       for (int j = 0; j < v_per_thread; j++) { | ||||||
|  |         k[j] = K[v_per_thread * lane_idx + j]; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // Compute the i-th score | ||||||
|  |       U score = 0.f; | ||||||
|  |       PRAGMA_LOOP_UNROLL | ||||||
|  |       for (int j = 0; j < v_per_thread; j++) { | ||||||
|  |         score += q[j] * k[j]; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // Warp sum | ||||||
|  |       score = cg::reduce(warp, score, cg::plus<U>()); | ||||||
|  |  | ||||||
|  |       // Update the accumulators | ||||||
|  |       U new_max = max(max_score, score); | ||||||
|  |       U factor = expf(max_score - new_max); | ||||||
|  |       U exp_score = expf(score - new_max); | ||||||
|  |  | ||||||
|  |       max_score = new_max; | ||||||
|  |       sum_exp_score = sum_exp_score * factor + exp_score; | ||||||
|  |  | ||||||
|  |       // Update the output accumulator | ||||||
|  |       PRAGMA_LOOP_UNROLL | ||||||
|  |       for (int j = 0; j < v_per_thread; j++) { | ||||||
|  |         o[j] = o[j] * factor + | ||||||
|  |             exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // Move the pointers to the next kv | ||||||
|  |     K += inner_k_stride; | ||||||
|  |     V += inner_v_stride; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (lane_idx == 0) { | ||||||
|  |     max_scores[warp_idx] = max_score; | ||||||
|  |     sum_exp_scores[warp_idx] = sum_exp_score; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   block.sync(); | ||||||
|  |  | ||||||
|  |   max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9; | ||||||
|  |   U new_max = cg::reduce(warp, max_score, cg::greater<U>()); | ||||||
|  |   U factor = expf(max_score - new_max); | ||||||
|  |   sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f; | ||||||
|  |   sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>()); | ||||||
|  |  | ||||||
|  |   // Write the sum and new max | ||||||
|  |   if (warp_idx == 0) { | ||||||
|  |     sums[0] = sum_exp_score; | ||||||
|  |     maxs[0] = new_max; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Now we need to aggregate all the outputs | ||||||
|  |   PRAGMA_LOOP_UNROLL | ||||||
|  |   for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |     outputs[lane_idx][warp_idx] = o[i] * expf(max_scores[warp_idx] - new_max); | ||||||
|  |     block.sync(); | ||||||
|  |  | ||||||
|  |     if (warp_idx == 0) { | ||||||
|  |       U ot = outputs[lane_idx][0]; | ||||||
|  |  | ||||||
|  |       PRAGMA_LOOP_UNROLL | ||||||
|  |       for (int j = 1; j < BN; j++) { | ||||||
|  |         ot += outputs[lane_idx][0]; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // o[i] = ot; | ||||||
|  |       partials[v_per_thread * lane_idx + i] = ot; | ||||||
|  |     } | ||||||
|  |     block.sync(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // if(warp_idx == 0) { | ||||||
|  |   //   PRAGMA_LOOP_UNROLL | ||||||
|  |   //   for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |   //     partials[v_per_thread * lane_idx + i] = o[i]; | ||||||
|  |   //   } | ||||||
|  |   // } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T, bool do_causal, int D> | ||||||
|  | __global__ void kernel_sdpav_2pass_2( | ||||||
|  |     const float* partials, | ||||||
|  |     const float* sums, | ||||||
|  |     const float* maxs, | ||||||
|  |     T* O, | ||||||
|  |     __grid_constant__ const AttnParams params) { | ||||||
|  |   constexpr int BN = 32; | ||||||
|  |   constexpr int BD = 32; | ||||||
|  |   constexpr int blocks = 32; | ||||||
|  |  | ||||||
|  |   constexpr int v_per_thread = D / BD; | ||||||
|  |  | ||||||
|  |   typedef float U; | ||||||
|  |  | ||||||
|  |   U o[v_per_thread]; | ||||||
|  |   __shared__ U outputs[BN][BD + 1]; | ||||||
|  |  | ||||||
|  |   auto block = cg::this_thread_block(); | ||||||
|  |   auto warp = cg::tiled_partition<32>(block); | ||||||
|  |  | ||||||
|  |   const int lane_idx = warp.thread_rank(); | ||||||
|  |   const int warp_idx = warp.meta_group_rank(); | ||||||
|  |  | ||||||
|  |   // Adjust to thread block and thread | ||||||
|  |   const int batch_idx = blockIdx.z; | ||||||
|  |   const int head_idx = blockIdx.x; | ||||||
|  |   const int q_seq_idx = blockIdx.y; | ||||||
|  |  | ||||||
|  |   const int p_stride_s = blocks; | ||||||
|  |   const int p_stride_h = params.qL * p_stride_s; | ||||||
|  |   const int p_stride_b = params.H * p_stride_h; | ||||||
|  |   const int p_offset = batch_idx * p_stride_b + // Batch | ||||||
|  |       head_idx * p_stride_h + // Head | ||||||
|  |       q_seq_idx * p_stride_s; // Sequence | ||||||
|  |  | ||||||
|  |   partials += p_offset * D + warp_idx * D; | ||||||
|  |   sums += p_offset; | ||||||
|  |   maxs += p_offset; | ||||||
|  |  | ||||||
|  |   O += batch_idx * params.O_strides[0] + // Batch | ||||||
|  |       head_idx * params.O_strides[1] + // Head | ||||||
|  |       q_seq_idx * params.O_strides[2]; // Sequence | ||||||
|  |  | ||||||
|  |   U max_score = maxs[lane_idx]; | ||||||
|  |   U new_max = cg::reduce(warp, max_score, cg::greater<U>()); | ||||||
|  |   U factor = expf(max_score - new_max); | ||||||
|  |   U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>()); | ||||||
|  |   // sum_exp_score = __frcp_rn(sum_exp_score); | ||||||
|  |  | ||||||
|  |   PRAGMA_LOOP_UNROLL | ||||||
|  |   for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |     o[i] = partials[v_per_thread * lane_idx + i]; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Now we need to aggregate all the outputs | ||||||
|  |   PRAGMA_LOOP_UNROLL | ||||||
|  |   for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |     outputs[lane_idx][warp_idx] = o[i]; | ||||||
|  |     block.sync(); | ||||||
|  |     U ot = outputs[warp_idx][lane_idx] * factor; | ||||||
|  |     o[i] = cg::reduce(warp, ot, cg::plus<U>()) / sum_exp_score; | ||||||
|  |     block.sync(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // And write the output | ||||||
|  |   if (lane_idx == 0) { | ||||||
|  |     PRAGMA_LOOP_UNROLL | ||||||
|  |     for (int i = 0; i < v_per_thread; i++) { | ||||||
|  |       O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace cu | ||||||
|  |  | ||||||
| namespace { | namespace { | ||||||
|  |  | ||||||
|  | template <typename F> | ||||||
|  | void dispatch_headdim(int n, F&& f) { | ||||||
|  |   switch (n) { | ||||||
|  |     case 64: | ||||||
|  |       f(std::integral_constant<int, 64>{}); | ||||||
|  |       break; | ||||||
|  |     case 96: | ||||||
|  |       f(std::integral_constant<int, 96>{}); | ||||||
|  |       break; | ||||||
|  |     case 128: | ||||||
|  |       f(std::integral_constant<int, 128>{}); | ||||||
|  |       break; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void sdpa_vector_1pass_fallback( | ||||||
|  |     const Stream& s, | ||||||
|  |     cu::CommandEncoder& encoder, | ||||||
|  |     const array& q, | ||||||
|  |     const array& k, | ||||||
|  |     const array& v, | ||||||
|  |     const float scale, | ||||||
|  |     array& o, | ||||||
|  |     bool do_causal_ = false) { | ||||||
|  |   encoder.set_input_array(q); | ||||||
|  |   encoder.set_input_array(k); | ||||||
|  |   encoder.set_input_array(v); | ||||||
|  |   encoder.set_output_array(o); | ||||||
|  |  | ||||||
|  |   cu::AttnParams params{ | ||||||
|  |       /* int B = */ q.shape(0), | ||||||
|  |       /* int H = */ q.shape(1), | ||||||
|  |       /* int D = */ q.shape(3), | ||||||
|  |  | ||||||
|  |       /* int qL = */ q.shape(2), | ||||||
|  |       /* int kL = */ k.shape(2), | ||||||
|  |  | ||||||
|  |       /* int gqa_factor = */ q.shape(1) / k.shape(1), | ||||||
|  |       /* float scale = */ scale, | ||||||
|  |  | ||||||
|  |       /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, | ||||||
|  |       /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, | ||||||
|  |       /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, | ||||||
|  |       /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; | ||||||
|  |  | ||||||
|  |   dim3 grid_dim(params.H, params.qL, params.B); | ||||||
|  |   dim3 block_dim(1024, 1, 1); | ||||||
|  |  | ||||||
|  |   dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) { | ||||||
|  |     dispatch_bool(do_causal_, [&](auto do_causal) { | ||||||
|  |       dispatch_headdim(params.D, [&](auto headdim) { | ||||||
|  |         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; | ||||||
|  |  | ||||||
|  |         auto kernel = cu::kernel_sdpav_1pass<DataType, do_causal(), headdim()>; | ||||||
|  |         encoder.add_kernel_node( | ||||||
|  |             kernel, | ||||||
|  |             grid_dim, | ||||||
|  |             block_dim, | ||||||
|  |             q.data<DataType>(), | ||||||
|  |             k.data<DataType>(), | ||||||
|  |             v.data<DataType>(), | ||||||
|  |             o.data<DataType>(), | ||||||
|  |             params); | ||||||
|  |       }); | ||||||
|  |     }); | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void sdpa_vector_2pass_fallback( | ||||||
|  |     const Stream& s, | ||||||
|  |     cu::CommandEncoder& encoder, | ||||||
|  |     const array& q, | ||||||
|  |     const array& k, | ||||||
|  |     const array& v, | ||||||
|  |     const float scale, | ||||||
|  |     array& o, | ||||||
|  |     bool do_causal_ = false) { | ||||||
|  |   cu::AttnParams params{ | ||||||
|  |       /* int B = */ q.shape(0), | ||||||
|  |       /* int H = */ q.shape(1), | ||||||
|  |       /* int D = */ q.shape(3), | ||||||
|  |  | ||||||
|  |       /* int qL = */ q.shape(2), | ||||||
|  |       /* int kL = */ k.shape(2), | ||||||
|  |  | ||||||
|  |       /* int gqa_factor = */ q.shape(1) / k.shape(1), | ||||||
|  |       /* float scale = */ scale, | ||||||
|  |  | ||||||
|  |       /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, | ||||||
|  |       /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, | ||||||
|  |       /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, | ||||||
|  |       /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; | ||||||
|  |  | ||||||
|  |   // Allocate the intermediates | ||||||
|  |   int blocks = 32; | ||||||
|  |  | ||||||
|  |   Shape intermediate_shape; | ||||||
|  |   intermediate_shape.reserve(o.ndim() + 1); | ||||||
|  |   intermediate_shape.insert( | ||||||
|  |       intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1); | ||||||
|  |   intermediate_shape.push_back(blocks); | ||||||
|  |   intermediate_shape.push_back(o.shape().back()); | ||||||
|  |  | ||||||
|  |   array intermediate(intermediate_shape, float32, nullptr, {}); | ||||||
|  |   intermediate_shape.pop_back(); | ||||||
|  |   array sums(intermediate_shape, float32, nullptr, {}); | ||||||
|  |   array maxs(std::move(intermediate_shape), float32, nullptr, {}); | ||||||
|  |  | ||||||
|  |   intermediate.set_data(allocator::malloc(intermediate.nbytes())); | ||||||
|  |   sums.set_data(allocator::malloc(sums.nbytes())); | ||||||
|  |   maxs.set_data(allocator::malloc(maxs.nbytes())); | ||||||
|  |  | ||||||
|  |   encoder.add_temporary(intermediate); | ||||||
|  |   encoder.add_temporary(sums); | ||||||
|  |   encoder.add_temporary(maxs); | ||||||
|  |  | ||||||
|  |   dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) { | ||||||
|  |     dispatch_bool(do_causal_, [&](auto do_causal) { | ||||||
|  |       dispatch_headdim(params.D, [&](auto headdim) { | ||||||
|  |         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; | ||||||
|  |  | ||||||
|  |         { | ||||||
|  |           auto kernel = | ||||||
|  |               cu::kernel_sdpav_2pass_1<DataType, do_causal(), headdim()>; | ||||||
|  |  | ||||||
|  |           encoder.set_input_array(q); | ||||||
|  |           encoder.set_input_array(k); | ||||||
|  |           encoder.set_input_array(v); | ||||||
|  |           encoder.set_output_array(intermediate); | ||||||
|  |           encoder.set_output_array(sums); | ||||||
|  |           encoder.set_output_array(maxs); | ||||||
|  |  | ||||||
|  |           dim3 grid_dim(params.H, params.qL, params.B * 32); | ||||||
|  |           dim3 block_dim(8 * 32, 1, 1); | ||||||
|  |  | ||||||
|  |           encoder.add_kernel_node( | ||||||
|  |               kernel, | ||||||
|  |               grid_dim, | ||||||
|  |               block_dim, | ||||||
|  |               q.data<DataType>(), | ||||||
|  |               k.data<DataType>(), | ||||||
|  |               v.data<DataType>(), | ||||||
|  |               intermediate.data<float>(), | ||||||
|  |               sums.data<float>(), | ||||||
|  |               maxs.data<float>(), | ||||||
|  |               params); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         { | ||||||
|  |           auto kernel = | ||||||
|  |               cu::kernel_sdpav_2pass_2<DataType, do_causal(), headdim()>; | ||||||
|  |  | ||||||
|  |           encoder.set_input_array(intermediate); | ||||||
|  |           encoder.set_input_array(sums); | ||||||
|  |           encoder.set_input_array(maxs); | ||||||
|  |           encoder.set_output_array(o); | ||||||
|  |  | ||||||
|  |           dim3 grid_dim(params.H, params.qL, params.B); | ||||||
|  |           dim3 block_dim(1024, 1, 1); | ||||||
|  |  | ||||||
|  |           encoder.add_kernel_node( | ||||||
|  |               kernel, | ||||||
|  |               grid_dim, | ||||||
|  |               block_dim, | ||||||
|  |               intermediate.data<float>(), | ||||||
|  |               sums.data<float>(), | ||||||
|  |               maxs.data<float>(), | ||||||
|  |               o.data<DataType>(), | ||||||
|  |               params); | ||||||
|  |         } | ||||||
|  |       }); | ||||||
|  |     }); | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void sdpa_vector_fallback( | ||||||
|  |     const Stream& s, | ||||||
|  |     cu::CommandEncoder& encoder, | ||||||
|  |     const array& q, | ||||||
|  |     const array& k, | ||||||
|  |     const array& v, | ||||||
|  |     const float scale, | ||||||
|  |     array& o, | ||||||
|  |     bool do_causal_ = false) { | ||||||
|  |   int kL = k.shape(2); | ||||||
|  |  | ||||||
|  |   if (false && kL > 1024) { | ||||||
|  |     return sdpa_vector_2pass_fallback( | ||||||
|  |         s, encoder, q, k, v, scale, o, do_causal_); | ||||||
|  |   } else { | ||||||
|  |     return sdpa_vector_1pass_fallback( | ||||||
|  |         s, encoder, q, k, v, scale, o, do_causal_); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| struct SDPACacheKey { | struct SDPACacheKey { | ||||||
|   int device_id; |   int device_id; | ||||||
|   fe::DataType_t cudnn_type; |   fe::DataType_t cudnn_type; | ||||||
| @@ -67,8 +685,6 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph( | |||||||
|     return it->second; |     return it->second; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   nvtx3::scoped_range r("get_sdpa_forward_graph"); |  | ||||||
|  |  | ||||||
|   // Set up new graph |   // Set up new graph | ||||||
|   auto graph = std::make_shared<fe::graph::Graph>(); |   auto graph = std::make_shared<fe::graph::Graph>(); | ||||||
|  |  | ||||||
| @@ -143,8 +759,6 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph( | |||||||
|  |  | ||||||
|   // cuDNN only supports native CUDA graphs for sdpa in 9.6 or above. |   // cuDNN only supports native CUDA graphs for sdpa in 9.6 or above. | ||||||
|   if (cudnnGetVersion() < 90600) { |   if (cudnnGetVersion() < 90600) { | ||||||
|     nvtx3::scoped_range r("get_sdpa_forward_graph::graph_building"); |  | ||||||
|  |  | ||||||
|     auto build_status = graph->build(handle, {fe::HeurMode_t::A}); |     auto build_status = graph->build(handle, {fe::HeurMode_t::A}); | ||||||
|     if (!build_status.is_good()) { |     if (!build_status.is_good()) { | ||||||
|       throw std::runtime_error( |       throw std::runtime_error( | ||||||
| @@ -331,11 +945,6 @@ bool ScaledDotProductAttention::use_fallback( | |||||||
|     return true; |     return true; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   auto& cu_device = cu::device(s.device); |  | ||||||
|   if (cu_device.compute_capability_major() < 8) { |  | ||||||
|     return true; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   const int value_head_dim = v.shape(-1); |   const int value_head_dim = v.shape(-1); | ||||||
|   const int query_head_dim = q.shape(-1); |   const int query_head_dim = q.shape(-1); | ||||||
|   const int query_sequence_length = q.shape(2); |   const int query_sequence_length = q.shape(2); | ||||||
| @@ -344,11 +953,7 @@ bool ScaledDotProductAttention::use_fallback( | |||||||
|   const bool sdpa_supported_head_dim = query_head_dim == value_head_dim && |   const bool sdpa_supported_head_dim = query_head_dim == value_head_dim && | ||||||
|       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128); |       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128); | ||||||
|  |  | ||||||
|   const bool supported_dtype = q.dtype() == float16 || q.dtype() == bfloat16; |   return has_arr_mask || !sdpa_supported_head_dim; | ||||||
|  |  | ||||||
|   const bool supported_config = supported_dtype && sdpa_supported_head_dim; |  | ||||||
|  |  | ||||||
|   return has_arr_mask || !supported_config; |  | ||||||
| } | } | ||||||
|  |  | ||||||
| void ScaledDotProductAttention::eval_gpu( | void ScaledDotProductAttention::eval_gpu( | ||||||
| @@ -432,7 +1037,8 @@ void ScaledDotProductAttention::eval_gpu( | |||||||
|       o.set_data(allocator::malloc(o.nbytes())); |       o.set_data(allocator::malloc(o.nbytes())); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_); |     return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_); | ||||||
|  |     // return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Full attention mode |   // Full attention mode | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani