mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Double buffer keys for vector sdpa
This commit is contained in:
		| @@ -4,7 +4,7 @@ import math | ||||
| import mlx.core as mx | ||||
| from time_utils import time_fn | ||||
|  | ||||
| L = 16384 | ||||
| L = 1024 | ||||
| H = 32 | ||||
| H_k = H // 4 | ||||
| D = 128 | ||||
|   | ||||
| @@ -45,7 +45,7 @@ template <typename T, int D, int V = D> | ||||
|   typedef float U; | ||||
|  | ||||
|   thread U q[qk_per_thread]; | ||||
|   thread U k[qk_per_thread]; | ||||
|   thread U k[2][qk_per_thread]; | ||||
|   thread U o[v_per_thread]; | ||||
|  | ||||
|   threadgroup U outputs[BN * BD]; | ||||
| @@ -86,8 +86,21 @@ template <typename T, int D, int V = D> | ||||
|   U max_score = -INFINITY; | ||||
|   U sum_exp_score = 0; | ||||
|  | ||||
|   // Read the first key | ||||
|   short a = 0, b = 1; | ||||
|   for (int j = 0; j < qk_per_thread; j++) { | ||||
|     k[a][j] = keys[j]; | ||||
|   } | ||||
|   keys += inner_k_stride; | ||||
|  | ||||
|   // For each key | ||||
|   for (int i = simd_gid; i < N; i += BN) { | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     // Read the next key | ||||
|     for (int j = 0; j < qk_per_thread; j++) { | ||||
|       k[b][j] = keys[j]; | ||||
|     } | ||||
|  | ||||
|     bool use_key = true; | ||||
|     if (do_causal) { | ||||
|       use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); | ||||
| @@ -95,15 +108,10 @@ template <typename T, int D, int V = D> | ||||
|       use_key = bmask[0]; | ||||
|     } | ||||
|     if (use_key) { | ||||
|       // Read the key | ||||
|       for (int j = 0; j < qk_per_thread; j++) { | ||||
|         k[j] = keys[j]; | ||||
|       } | ||||
|  | ||||
|       // Compute the i-th score | ||||
|       U score = 0; | ||||
|       for (int j = 0; j < qk_per_thread; j++) { | ||||
|         score += q[j] * k[j]; | ||||
|         score += q[j] * k[a][j]; | ||||
|       } | ||||
|       score = simd_sum(score); | ||||
|       if (float_mask) { | ||||
| @@ -133,6 +141,11 @@ template <typename T, int D, int V = D> | ||||
|     if (float_mask) { | ||||
|       fmask += BN * mask_kv_seq_stride; | ||||
|     } | ||||
|  | ||||
|     // Swap read and write locations | ||||
|     short tmp = a; | ||||
|     b = a; | ||||
|     a = tmp; | ||||
|   } | ||||
|  | ||||
|   // Each thread has a partial part of the output so we need to combine them. | ||||
| @@ -202,7 +215,7 @@ template <typename T, int D, int V = D> | ||||
|   typedef float U; | ||||
|  | ||||
|   thread U q[qk_per_thread]; | ||||
|   thread U k[qk_per_thread]; | ||||
|   thread U k[2][qk_per_thread]; | ||||
|   thread U o[v_per_thread]; | ||||
|  | ||||
|   threadgroup U outputs[BN * BD]; | ||||
| @@ -248,8 +261,21 @@ template <typename T, int D, int V = D> | ||||
|   U max_score = -1e9; | ||||
|   U sum_exp_score = 0; | ||||
|  | ||||
|   // Read the first key | ||||
|   short a = 0, b = 1; | ||||
|   for (int j = 0; j < qk_per_thread; j++) { | ||||
|     k[a][j] = keys[j]; | ||||
|   } | ||||
|   keys += inner_k_stride; | ||||
|  | ||||
|   // For each key | ||||
|   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     // Read the next key | ||||
|     for (int j = 0; j < qk_per_thread; j++) { | ||||
|       k[b][j] = keys[j]; | ||||
|     } | ||||
|  | ||||
|     bool use_key = true; | ||||
|     if (do_causal) { | ||||
|       use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); | ||||
| @@ -257,15 +283,10 @@ template <typename T, int D, int V = D> | ||||
|       use_key = bmask[0]; | ||||
|     } | ||||
|     if (use_key) { | ||||
|       // Read the key | ||||
|       for (int i = 0; i < qk_per_thread; i++) { | ||||
|         k[i] = keys[i]; | ||||
|       } | ||||
|  | ||||
|       // Compute the i-th score | ||||
|       U score = 0; | ||||
|       for (int i = 0; i < qk_per_thread; i++) { | ||||
|         score += q[i] * k[i]; | ||||
|         score += q[i] * k[a][i]; | ||||
|       } | ||||
|       score = simd_sum(score); | ||||
|       if (float_mask) { | ||||
| @@ -295,6 +316,11 @@ template <typename T, int D, int V = D> | ||||
|     if (float_mask) { | ||||
|       fmask += BN * blocks * mask_kv_seq_stride; | ||||
|     } | ||||
|  | ||||
|     // Swap read and write locations | ||||
|     short tmp = a; | ||||
|     b = a; | ||||
|     a = tmp; | ||||
|   } | ||||
|  | ||||
|   // Each thread has a partial part of the output so we need to combine them. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos