mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	revert sdpa
This commit is contained in:
		@@ -16,11 +16,9 @@ template <typename T, int D>
 | 
			
		||||
    const constant float& scale,
 | 
			
		||||
    uint3 tid [[threadgroup_position_in_grid]],
 | 
			
		||||
    uint simd_gid [[simdgroup_index_in_threadgroup]],
 | 
			
		||||
    uint simd_lid [[thread_index_in_simdgroup]],
 | 
			
		||||
    uint quad_gid [[quadgroup_index_in_threadgroup]],
 | 
			
		||||
    uint quad_lid [[thread_index_in_quadgroup]]) {
 | 
			
		||||
    uint simd_lid [[thread_index_in_simdgroup]]) {
 | 
			
		||||
  constexpr int BN = 32;
 | 
			
		||||
  constexpr int BD = 4;
 | 
			
		||||
  constexpr int BD = 32;
 | 
			
		||||
  constexpr int elem_per_thread = D / BD;
 | 
			
		||||
 | 
			
		||||
  const int stride = BN * D;
 | 
			
		||||
@@ -38,9 +36,9 @@ template <typename T, int D>
 | 
			
		||||
  // Adjust positions
 | 
			
		||||
  const int head_idx = tid.y;
 | 
			
		||||
  const int kv_head_idx = head_idx / gqa_factor;
 | 
			
		||||
  queries += head_idx * D + quad_lid * elem_per_thread;
 | 
			
		||||
  keys += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread;
 | 
			
		||||
  values += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread;
 | 
			
		||||
  queries += head_idx * D + simd_lid * elem_per_thread;
 | 
			
		||||
  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
 | 
			
		||||
  values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
 | 
			
		||||
  out += head_idx * D + simd_gid * elem_per_thread;
 | 
			
		||||
 | 
			
		||||
  // Read the query and 0 the output accumulator
 | 
			
		||||
@@ -55,7 +53,7 @@ template <typename T, int D>
 | 
			
		||||
  U sum_exp_score = 0;
 | 
			
		||||
 | 
			
		||||
  // For each key
 | 
			
		||||
  for (int i = quad_gid; i < N; i += BN) {
 | 
			
		||||
  for (int i = simd_gid; i < N; i += BN) {
 | 
			
		||||
    // Read the key
 | 
			
		||||
    for (int i = 0; i < elem_per_thread; i++) {
 | 
			
		||||
      k[i] = keys[i];
 | 
			
		||||
@@ -66,7 +64,7 @@ template <typename T, int D>
 | 
			
		||||
    for (int i = 0; i < elem_per_thread; i++) {
 | 
			
		||||
      score += q[i] * k[i];
 | 
			
		||||
    }
 | 
			
		||||
    score = quad_sum(score);
 | 
			
		||||
    score = simd_sum(score);
 | 
			
		||||
 | 
			
		||||
    // Update the accumulators
 | 
			
		||||
    U new_max = max(max_score, score);
 | 
			
		||||
@@ -90,10 +88,9 @@ template <typename T, int D>
 | 
			
		||||
  // Each thread has a partial part of the output so we need to combine them.
 | 
			
		||||
 | 
			
		||||
  // First let's communicate the max and sum_exp
 | 
			
		||||
  // Each quadgroup communicates its max score
 | 
			
		||||
  if (quad_lid == 0) {
 | 
			
		||||
    max_scores[quad_gid] = max_score;
 | 
			
		||||
    sum_exp_scores[quad_gid] = sum_exp_score;
 | 
			
		||||
  if (simd_lid == 0) {
 | 
			
		||||
    max_scores[simd_gid] = max_score;
 | 
			
		||||
    sum_exp_scores[simd_gid] = sum_exp_score;
 | 
			
		||||
  }
 | 
			
		||||
  threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
  max_score = max_scores[simd_lid];
 | 
			
		||||
@@ -103,10 +100,9 @@ template <typename T, int D>
 | 
			
		||||
 | 
			
		||||
  // Now we need to aggregate all the outputs
 | 
			
		||||
  for (int i = 0; i < elem_per_thread; i++) {
 | 
			
		||||
    // 128 threads with 32 values per thread
 | 
			
		||||
    outputs[simd_gid * BN + simd_lid] = o[i];
 | 
			
		||||
    outputs[simd_lid * BD + simd_gid] = o[i];
 | 
			
		||||
    threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
    o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score;
 | 
			
		||||
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
 | 
			
		||||
    threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -165,7 +165,7 @@ void sdpa_vector(
 | 
			
		||||
  int N = k.shape(2);
 | 
			
		||||
  int B = q.shape(0) * q.shape(1);
 | 
			
		||||
  size_t stride = k.strides()[1];
 | 
			
		||||
  MTL::Size group_dims(128, 1, 1);
 | 
			
		||||
  MTL::Size group_dims(1024, 1, 1);
 | 
			
		||||
  MTL::Size grid_dims(1, B, 1);
 | 
			
		||||
 | 
			
		||||
  // Get the kernel
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user