mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 08:38:12 +08:00 
			
		
		
		
	Complete 2 pass sdpav
This commit is contained in:
		| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| #include "mlx/backend/cuda/device.h" | ||||
| #include "mlx/backend/cuda/device/config.h" | ||||
| #include "mlx/backend/cuda/device/utils.cuh" | ||||
| #include "mlx/backend/cuda/kernel_utils.cuh" | ||||
| #include "mlx/backend/cuda/lru_cache.h" | ||||
| #include "mlx/backend/gpu/copy.h" | ||||
| @@ -217,11 +218,11 @@ __global__ void kernel_sdpav_2pass_1( | ||||
|   U k[v_per_thread]; | ||||
|   U o[v_per_thread]; | ||||
|  | ||||
|   __shared__ U outputs[BD][BN + 1]; | ||||
|   __shared__ U outputs[BN][BD + 1]; | ||||
|   __shared__ U max_scores[BN]; | ||||
|   __shared__ U sum_exp_scores[BN]; | ||||
|  | ||||
|   const U scale_log2 = params.scale; // * 1.44269504089f; | ||||
|   const U scale_log2 = params.scale * 1.44269504089f; | ||||
|  | ||||
|   auto block = cg::this_thread_block(); | ||||
|   auto warp = cg::tiled_partition<32>(block); | ||||
| @@ -230,7 +231,7 @@ __global__ void kernel_sdpav_2pass_1( | ||||
|   const int warp_idx = warp.meta_group_rank(); | ||||
|  | ||||
|   // Adjust to thread block and thread | ||||
|   const int batch_idx = 0; // blockIdx.z / blocks; | ||||
|   const int batch_idx = blockIdx.z / blocks; | ||||
|   const int block_idx = blockIdx.z % blocks; | ||||
|   const int head_idx = blockIdx.x; | ||||
|   const int kv_head_idx = head_idx / params.gqa_factor; | ||||
| @@ -302,8 +303,8 @@ __global__ void kernel_sdpav_2pass_1( | ||||
|  | ||||
|       // Update the accumulators | ||||
|       U new_max = max(max_score, score); | ||||
|       U factor = expf(max_score - new_max); | ||||
|       U exp_score = expf(score - new_max); | ||||
|       U factor = exp2f(max_score - new_max); | ||||
|       U exp_score = exp2f(score - new_max); | ||||
|  | ||||
|       max_score = new_max; | ||||
|       sum_exp_score = sum_exp_score * factor + exp_score; | ||||
| @@ -330,7 +331,7 @@ __global__ void kernel_sdpav_2pass_1( | ||||
|  | ||||
|   max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9; | ||||
|   U new_max = cg::reduce(warp, max_score, cg::greater<U>()); | ||||
|   U factor = expf(max_score - new_max); | ||||
|   U factor = exp2f(max_score - new_max); | ||||
|   sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f; | ||||
|   sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>()); | ||||
|  | ||||
| @@ -341,31 +342,30 @@ __global__ void kernel_sdpav_2pass_1( | ||||
|   } | ||||
|  | ||||
|   // Now we need to aggregate all the outputs | ||||
|   auto ff = exp2f(max_scores[warp_idx] - new_max); | ||||
|   PRAGMA_LOOP_UNROLL | ||||
|   for (int i = 0; i < v_per_thread; i++) { | ||||
|     outputs[lane_idx][warp_idx] = o[i] * expf(max_scores[warp_idx] - new_max); | ||||
|     outputs[warp_idx][lane_idx] = o[i] * ff; | ||||
|     block.sync(); | ||||
|  | ||||
|     if (warp_idx == 0) { | ||||
|       U ot = outputs[lane_idx][0]; | ||||
|  | ||||
|       U ot = outputs[0][lane_idx]; | ||||
|       PRAGMA_LOOP_UNROLL | ||||
|       for (int j = 1; j < BN; j++) { | ||||
|         ot += outputs[lane_idx][0]; | ||||
|         ot += outputs[j][lane_idx]; | ||||
|         warp.sync(); | ||||
|       } | ||||
|  | ||||
|       // o[i] = ot; | ||||
|       partials[v_per_thread * lane_idx + i] = ot; | ||||
|       o[i] = ot; | ||||
|     } | ||||
|     block.sync(); | ||||
|   } | ||||
|  | ||||
|   // if(warp_idx == 0) { | ||||
|   //   PRAGMA_LOOP_UNROLL | ||||
|   //   for (int i = 0; i < v_per_thread; i++) { | ||||
|   //     partials[v_per_thread * lane_idx + i] = o[i]; | ||||
|   //   } | ||||
|   // } | ||||
|   if (warp_idx == 0) { | ||||
|     PRAGMA_LOOP_UNROLL | ||||
|     for (int i = 0; i < v_per_thread; i++) { | ||||
|       partials[v_per_thread * lane_idx + i] = o[i]; | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T, bool do_causal, int D> | ||||
| @@ -414,9 +414,9 @@ __global__ void kernel_sdpav_2pass_2( | ||||
|  | ||||
|   U max_score = maxs[lane_idx]; | ||||
|   U new_max = cg::reduce(warp, max_score, cg::greater<U>()); | ||||
|   U factor = expf(max_score - new_max); | ||||
|   U factor = exp2f(max_score - new_max); | ||||
|   U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>()); | ||||
|   // sum_exp_score = __frcp_rn(sum_exp_score); | ||||
|   sum_exp_score = __frcp_rn(sum_exp_score); | ||||
|  | ||||
|   PRAGMA_LOOP_UNROLL | ||||
|   for (int i = 0; i < v_per_thread; i++) { | ||||
| @@ -429,7 +429,7 @@ __global__ void kernel_sdpav_2pass_2( | ||||
|     outputs[lane_idx][warp_idx] = o[i]; | ||||
|     block.sync(); | ||||
|     U ot = outputs[warp_idx][lane_idx] * factor; | ||||
|     o[i] = cg::reduce(warp, ot, cg::plus<U>()) / sum_exp_score; | ||||
|     o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score; | ||||
|     block.sync(); | ||||
|   } | ||||
|  | ||||
| @@ -504,6 +504,7 @@ void sdpa_vector_1pass_fallback( | ||||
|             kernel, | ||||
|             grid_dim, | ||||
|             block_dim, | ||||
|             0, | ||||
|             q.data<DataType>(), | ||||
|             k.data<DataType>(), | ||||
|             v.data<DataType>(), | ||||
| @@ -585,6 +586,7 @@ void sdpa_vector_2pass_fallback( | ||||
|               kernel, | ||||
|               grid_dim, | ||||
|               block_dim, | ||||
|               0, | ||||
|               q.data<DataType>(), | ||||
|               k.data<DataType>(), | ||||
|               v.data<DataType>(), | ||||
| @@ -610,6 +612,7 @@ void sdpa_vector_2pass_fallback( | ||||
|               kernel, | ||||
|               grid_dim, | ||||
|               block_dim, | ||||
|               0, | ||||
|               intermediate.data<float>(), | ||||
|               sums.data<float>(), | ||||
|               maxs.data<float>(), | ||||
| @@ -632,7 +635,7 @@ void sdpa_vector_fallback( | ||||
|     bool do_causal_ = false) { | ||||
|   int kL = k.shape(2); | ||||
|  | ||||
|   if (false && kL > 1024) { | ||||
|   if (kL > 1024) { | ||||
|     return sdpa_vector_2pass_fallback( | ||||
|         s, encoder, q, k, v, scale, o, do_causal_); | ||||
|   } else { | ||||
| @@ -1034,7 +1037,23 @@ void ScaledDotProductAttention::eval_gpu( | ||||
|     if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { | ||||
|       o.copy_shared_buffer(q); | ||||
|     } else { | ||||
|       o.set_data(allocator::malloc(o.nbytes())); | ||||
|       int64_t str_oD = 1; | ||||
|       int64_t str_oH = o.shape(3); | ||||
|       int64_t str_oL = o.shape(1) * str_oH; | ||||
|       int64_t str_oB = o.shape(2) * str_oL; | ||||
|       size_t data_size = o.shape(0) * str_oB; | ||||
|  | ||||
|       array::Flags flags{ | ||||
|           /* bool contiguous = */ 1, | ||||
|           /* bool row_contiguous = */ 0, | ||||
|           /* bool col_contiguous = */ 0, | ||||
|       }; | ||||
|  | ||||
|       o.set_data( | ||||
|           allocator::malloc(o.nbytes()), | ||||
|           data_size, | ||||
|           {str_oB, str_oH, str_oL, str_oD}, | ||||
|           flags); | ||||
|     } | ||||
|  | ||||
|     return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani