mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Double buffer keys for vector sdpa
This commit is contained in:
		| @@ -4,7 +4,7 @@ import math | |||||||
| import mlx.core as mx | import mlx.core as mx | ||||||
| from time_utils import time_fn | from time_utils import time_fn | ||||||
|  |  | ||||||
| L = 16384 | L = 1024 | ||||||
| H = 32 | H = 32 | ||||||
| H_k = H // 4 | H_k = H // 4 | ||||||
| D = 128 | D = 128 | ||||||
|   | |||||||
| @@ -45,7 +45,7 @@ template <typename T, int D, int V = D> | |||||||
|   typedef float U; |   typedef float U; | ||||||
|  |  | ||||||
|   thread U q[qk_per_thread]; |   thread U q[qk_per_thread]; | ||||||
|   thread U k[qk_per_thread]; |   thread U k[2][qk_per_thread]; | ||||||
|   thread U o[v_per_thread]; |   thread U o[v_per_thread]; | ||||||
|  |  | ||||||
|   threadgroup U outputs[BN * BD]; |   threadgroup U outputs[BN * BD]; | ||||||
| @@ -86,8 +86,21 @@ template <typename T, int D, int V = D> | |||||||
|   U max_score = -INFINITY; |   U max_score = -INFINITY; | ||||||
|   U sum_exp_score = 0; |   U sum_exp_score = 0; | ||||||
|  |  | ||||||
|  |   // Read the first key | ||||||
|  |   short a = 0, b = 1; | ||||||
|  |   for (int j = 0; j < qk_per_thread; j++) { | ||||||
|  |     k[a][j] = keys[j]; | ||||||
|  |   } | ||||||
|  |   keys += inner_k_stride; | ||||||
|  |  | ||||||
|   // For each key |   // For each key | ||||||
|   for (int i = simd_gid; i < N; i += BN) { |   for (int i = simd_gid; i < N; i += BN) { | ||||||
|  |     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |     // Read the next key | ||||||
|  |     for (int j = 0; j < qk_per_thread; j++) { | ||||||
|  |       k[b][j] = keys[j]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     bool use_key = true; |     bool use_key = true; | ||||||
|     if (do_causal) { |     if (do_causal) { | ||||||
|       use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); |       use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); | ||||||
| @@ -95,15 +108,10 @@ template <typename T, int D, int V = D> | |||||||
|       use_key = bmask[0]; |       use_key = bmask[0]; | ||||||
|     } |     } | ||||||
|     if (use_key) { |     if (use_key) { | ||||||
|       // Read the key |  | ||||||
|       for (int j = 0; j < qk_per_thread; j++) { |  | ||||||
|         k[j] = keys[j]; |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       // Compute the i-th score |       // Compute the i-th score | ||||||
|       U score = 0; |       U score = 0; | ||||||
|       for (int j = 0; j < qk_per_thread; j++) { |       for (int j = 0; j < qk_per_thread; j++) { | ||||||
|         score += q[j] * k[j]; |         score += q[j] * k[a][j]; | ||||||
|       } |       } | ||||||
|       score = simd_sum(score); |       score = simd_sum(score); | ||||||
|       if (float_mask) { |       if (float_mask) { | ||||||
| @@ -133,6 +141,11 @@ template <typename T, int D, int V = D> | |||||||
|     if (float_mask) { |     if (float_mask) { | ||||||
|       fmask += BN * mask_kv_seq_stride; |       fmask += BN * mask_kv_seq_stride; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // Swap read and write locations | ||||||
|  |     short tmp = a; | ||||||
|  |     b = a; | ||||||
|  |     a = tmp; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Each thread has a partial part of the output so we need to combine them. |   // Each thread has a partial part of the output so we need to combine them. | ||||||
| @@ -202,7 +215,7 @@ template <typename T, int D, int V = D> | |||||||
|   typedef float U; |   typedef float U; | ||||||
|  |  | ||||||
|   thread U q[qk_per_thread]; |   thread U q[qk_per_thread]; | ||||||
|   thread U k[qk_per_thread]; |   thread U k[2][qk_per_thread]; | ||||||
|   thread U o[v_per_thread]; |   thread U o[v_per_thread]; | ||||||
|  |  | ||||||
|   threadgroup U outputs[BN * BD]; |   threadgroup U outputs[BN * BD]; | ||||||
| @@ -248,8 +261,21 @@ template <typename T, int D, int V = D> | |||||||
|   U max_score = -1e9; |   U max_score = -1e9; | ||||||
|   U sum_exp_score = 0; |   U sum_exp_score = 0; | ||||||
|  |  | ||||||
|  |   // Read the first key | ||||||
|  |   short a = 0, b = 1; | ||||||
|  |   for (int j = 0; j < qk_per_thread; j++) { | ||||||
|  |     k[a][j] = keys[j]; | ||||||
|  |   } | ||||||
|  |   keys += inner_k_stride; | ||||||
|  |  | ||||||
|   // For each key |   // For each key | ||||||
|   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { |   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { | ||||||
|  |     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |     // Read the next key | ||||||
|  |     for (int j = 0; j < qk_per_thread; j++) { | ||||||
|  |       k[b][j] = keys[j]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     bool use_key = true; |     bool use_key = true; | ||||||
|     if (do_causal) { |     if (do_causal) { | ||||||
|       use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); |       use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); | ||||||
| @@ -257,15 +283,10 @@ template <typename T, int D, int V = D> | |||||||
|       use_key = bmask[0]; |       use_key = bmask[0]; | ||||||
|     } |     } | ||||||
|     if (use_key) { |     if (use_key) { | ||||||
|       // Read the key |  | ||||||
|       for (int i = 0; i < qk_per_thread; i++) { |  | ||||||
|         k[i] = keys[i]; |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       // Compute the i-th score |       // Compute the i-th score | ||||||
|       U score = 0; |       U score = 0; | ||||||
|       for (int i = 0; i < qk_per_thread; i++) { |       for (int i = 0; i < qk_per_thread; i++) { | ||||||
|         score += q[i] * k[i]; |         score += q[i] * k[a][i]; | ||||||
|       } |       } | ||||||
|       score = simd_sum(score); |       score = simd_sum(score); | ||||||
|       if (float_mask) { |       if (float_mask) { | ||||||
| @@ -295,6 +316,11 @@ template <typename T, int D, int V = D> | |||||||
|     if (float_mask) { |     if (float_mask) { | ||||||
|       fmask += BN * blocks * mask_kv_seq_stride; |       fmask += BN * blocks * mask_kv_seq_stride; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // Swap read and write locations | ||||||
|  |     short tmp = a; | ||||||
|  |     b = a; | ||||||
|  |     a = tmp; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Each thread has a partial part of the output so we need to combine them. |   // Each thread has a partial part of the output so we need to combine them. | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos