mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	working qsdpa
This commit is contained in:
		| @@ -1,58 +1,94 @@ | ||||
| import argparse | ||||
| import math | ||||
|  | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
| from mlx.utils import tree_map | ||||
| from time_utils import time_fn | ||||
|  | ||||
| L = 16384 | ||||
| L = 32768 | ||||
| H = 32 | ||||
| H_k = H // 4 | ||||
| D = 128 | ||||
| dtype = mx.float16 | ||||
| loops = 10 | ||||
| bits = 8 | ||||
|  | ||||
| loops = 20 | ||||
|  | ||||
|  | ||||
| def attention(q, k, v): | ||||
|     def _sdpa(q, k, v): | ||||
|     for _ in range(loops): | ||||
|         B, Hq, L, D = q.shape | ||||
|         _, Hk, S, _ = k.shape | ||||
|         q = q.reshape(B, Hk, Hq // Hk, L, D) | ||||
|         k = k[:, :, None, :, :] | ||||
|         v = v[:, :, None, :, :] | ||||
|         s = q @ k.transpose(0, 1, 2, 4, 3) | ||||
|         ke = k[:, :, None, :, :] | ||||
|         ve = v[:, :, None, :, :] | ||||
|         s = q @ ke.transpose(0, 1, 2, 4, 3) | ||||
|         p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) | ||||
|         o = p @ v | ||||
|         return o.reshape(B, Hq, L, D) | ||||
|  | ||||
|     for i in range(loops): | ||||
|         q = _sdpa(q, k, v) | ||||
|         q = p @ ve | ||||
|         q = q.reshape(B, Hq, L, D) | ||||
|     return q | ||||
|  | ||||
|  | ||||
| def sdpa(q, k, v): | ||||
|     for i in range(loops): | ||||
|         q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) | ||||
|     for _ in range(loops): | ||||
|         q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None) | ||||
|     return q | ||||
|  | ||||
|  | ||||
| def time_self_attention_primitives(): | ||||
|     mx.random.seed(3) | ||||
|     q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) | ||||
|     k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) | ||||
|     v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) | ||||
|     mx.eval(q, k, v) | ||||
def quant_sdpa(q, k, v, bits=4):
    """Chain ``loops`` fused quantized-attention calls.

    The output of each call is fed back in as the next query. ``k`` and ``v``
    are unpacked positionally into the fused op, so each must be an iterable
    of arrays (in this script: the results of ``mx.quantize``).
    """
    fused_attention = mx.fast.quantized_scaled_dot_product_attention
    out = q
    for _step in range(loops):
        out = fused_attention(out, *k, *v, scale=1.0, mask=None, bits=bits)
    return out
|  | ||||
|  | ||||
def quant_attention(q, k, v, bits=4):
    """Unfused quantized attention built from ``mx.quantized_matmul``.

    Runs ``loops`` rounds; the output of each round becomes the next query.
    ``k`` and ``v`` are trees of arrays that are unpacked positionally into
    ``mx.quantized_matmul``; the first leaf of ``k`` carries the KV-head count
    on axis 1.
    """

    def add_group_axis(x):
        # Singleton axis at position 2 so K/V line up with the 5-D grouped query.
        return mx.expand_dims(x, axis=2)

    for _step in range(loops):
        batch, n_q_heads, q_len, head_dim = q.shape
        n_kv_heads = k[0].shape[1]
        grouped_shape = (batch, n_kv_heads, n_q_heads // n_kv_heads, q_len, head_dim)
        grouped_q = q.reshape(grouped_shape)

        k_grouped = tree_map(add_group_axis, k)
        v_grouped = tree_map(add_group_axis, v)

        weights = mx.softmax(
            mx.quantized_matmul(grouped_q, *k_grouped, transpose=True, bits=bits),
            axis=-1,
        )
        mixed = mx.quantized_matmul(weights, *v_grouped, transpose=False, bits=bits)
        q = mixed.reshape((batch, n_q_heads, q_len, head_dim))
    return q
|  | ||||
|  | ||||
| def time_self_attention_primitives(q, k, v): | ||||
|     time_fn(attention, q, k, v) | ||||
|  | ||||
|  | ||||
| def time_self_attention_sdpa(): | ||||
|     mx.random.seed(3) | ||||
|     q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) | ||||
|     k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) | ||||
|     v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) | ||||
|     mx.eval(q, k, v) | ||||
| def time_self_attention_sdpa(q, k, v): | ||||
|     time_fn(sdpa, q, k, v) | ||||
|  | ||||
|  | ||||
def time_self_attention_quant_sdpa(q, k, v, bits=4):
    """Time the fused quantized attention loop (``quant_sdpa``) with ``time_fn``."""
    benchmark_args = (q, k, v, bits)
    time_fn(quant_sdpa, *benchmark_args)
|  | ||||
|  | ||||
def time_self_attention_quant_primitives(q, k, v, bits=4):
    """Time the unfused quantized attention loop (``quant_attention``) with ``time_fn``."""
    benchmark_args = (q, k, v, bits)
    time_fn(quant_attention, *benchmark_args)
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     time_self_attention_sdpa() | ||||
|     time_self_attention_primitives() | ||||
|     mx.random.seed(3) | ||||
|     q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype) | ||||
|     k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) | ||||
|     v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) | ||||
|     mx.eval(q, k, v) | ||||
|  | ||||
|     k_quant = mx.quantize(k, bits=bits) | ||||
|     v_quant = mx.quantize(v, bits=bits) | ||||
|     mx.eval(k_quant, v_quant) | ||||
|  | ||||
|     k = mx.dequantize(*k_quant, bits=bits) | ||||
|     v = mx.dequantize(*v_quant, bits=bits) | ||||
|  | ||||
|     time_self_attention_sdpa(q, k, v) | ||||
|     time_self_attention_quant_sdpa(q, k_quant, v_quant, bits) | ||||
|     time_self_attention_primitives(q, k, v) | ||||
|     time_self_attention_quant_primitives(q, k_quant, v_quant, bits) | ||||
|   | ||||
| @@ -20,4 +20,33 @@ using namespace metal; | ||||
| instantiate_sdpa_vector_heads(float) | ||||
| instantiate_sdpa_vector_heads(bfloat16_t) | ||||
| instantiate_sdpa_vector_heads(float16_t) | ||||
|  | ||||
| // Quantized SDPA vector instantiations | ||||
// Each macro level below fans out one template parameter: the kernel
// (quant_sdpa_vector and quant_sdpa_vector_2pass_1), bits (4, 8), group size
// (32, 64, 128), head dim (64, 96, 128) and finally the element type. Every
// instantiation is registered under the name
// "<name>_<type>_<head_dim>_<group_size>_<bits>".
// NOTE(review): this also instantiates combinations where head_dim is not a
// multiple of group_size (e.g. 96 with 64, or 64 with 128) - confirm callers
// never request them.
| #define instantiate_quant_sdpa_vector(name, type, head_dim, group_size, bits) \ | ||||
|   instantiate_kernel(                                                   \ | ||||
|     #name "_" #type "_" #head_dim "_" #group_size "_" #bits, \ | ||||
|     name, type, head_dim, group_size, bits) | ||||
|  | ||||
| #define instantiate_quant_sdpa_vector_passes(type, heads, group_size, bits) \ | ||||
|   instantiate_quant_sdpa_vector(quant_sdpa_vector, type, heads, group_size, bits)         \ | ||||
|   instantiate_quant_sdpa_vector(quant_sdpa_vector_2pass_1, type, heads, group_size, bits) | ||||
|  | ||||
| #define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \ | ||||
|   instantiate_quant_sdpa_vector_passes(type, heads, group_size, 4)         \ | ||||
|   instantiate_quant_sdpa_vector_passes(type, heads, group_size, 8) | ||||
|  | ||||
| #define instantiate_quant_sdpa_vector_group_size(type, heads) \ | ||||
|   instantiate_quant_sdpa_vector_bits(type, heads, 32)         \ | ||||
|   instantiate_quant_sdpa_vector_bits(type, heads, 64)         \ | ||||
|   instantiate_quant_sdpa_vector_bits(type, heads, 128) | ||||
|  | ||||
| #define instantiate_quant_sdpa_vector_heads(type) \ | ||||
|   instantiate_quant_sdpa_vector_group_size(type, 64)         \ | ||||
|   instantiate_quant_sdpa_vector_group_size(type, 96)         \ | ||||
|   instantiate_quant_sdpa_vector_group_size(type, 128) | ||||
|  | ||||
| instantiate_quant_sdpa_vector_heads(float) | ||||
| instantiate_quant_sdpa_vector_heads(bfloat16_t) | ||||
| instantiate_quant_sdpa_vector_heads(float16_t) | ||||
|  | ||||
|     // clang-format on | ||||
|   | ||||
| @@ -113,6 +113,208 @@ template <typename T, int D> | ||||
|   } | ||||
| } | ||||
|  | ||||
// Loads elem_per_thread query elements into q, multiplied by scale, and returns
// the sum of the scaled elements (taken before any per-lane adjustment).
//
// For bits == 4 the elements are handled four at a time and lanes 1, 2 and 3 of
// each group are additionally divided by 16, 256 and 4096. load_keys leaves the
// 4-bit key nibbles unshifted (masked with 0x00f0 / 0x0f00 / 0xf000), so these
// divisions cancel the nibbles' positional weight in the later dot product.
// For any other bit width than 4 or 8 nothing is loaded and 0 is returned.
template <typename T, typename U, int elem_per_thread, int bits>
METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
  U total = 0;
  if (bits == 8) {
    for (int j = 0; j < elem_per_thread; ++j) {
      const U scaled = scale * queries[j];
      q[j] = scaled;
      total += scaled;
    }
  } else if (bits == 4) {
    for (int base = 0; base < elem_per_thread; base += 4) {
      const U q0 = scale * queries[base];
      const U q1 = scale * queries[base + 1];
      const U q2 = scale * queries[base + 2];
      const U q3 = scale * queries[base + 3];
      // Accumulate the true (undivided) values first.
      total += q0 + q1 + q2 + q3;
      q[base] = q0;
      q[base + 1] = q1 / 16.0f;
      q[base + 2] = q2 / 256.0f;
      q[base + 3] = q3 / 4096.0f;
    }
  }
  return total;
}
|  | ||||
// Unpacks elem_per_thread quantized key values from `keys` into k.
//
// bits == 8: one byte per element, copied as-is.
// bits == 4: four elements per 16-bit word; each nibble is masked but NOT
// shifted down, so lanes 1..3 still carry a factor of 16 / 256 / 4096
// (load_queries pre-divides the matching query lanes to compensate).
// Other bit widths leave k untouched.
template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
  if (bits == 8) {
    const device uint8_t* packed = (const device uint8_t*)keys;
    for (int j = 0; j < elem_per_thread; ++j) {
      k[j] = packed[j];
    }
  } else if (bits == 4) {
    const device uint16_t* packed = (const device uint16_t*)keys;
    constexpr int num_words = elem_per_thread / 4;
    for (int w = 0; w < num_words; ++w) {
      const uint16_t word = packed[w];
      thread U* lanes = k + 4 * w;
      lanes[0] = word & 0x000f;
      lanes[1] = word & 0x00f0;
      lanes[2] = word & 0x0f00;
      lanes[3] = word & 0xf000;
    }
  }
}
|  | ||||
// Unpacks and dequantizes elem_per_thread value elements into v as
// value_scale * packed + value_bias.
//
// bits == 8: one byte per element.
// bits == 4: two elements per byte; the high nibble is masked without shifting
// and multiplied by value_scale / 16 instead, which removes the factor of 16.
// Other bit widths leave v untouched.
template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_values(
    const device uint32_t* values,
    thread U* v,
    U value_scale,
    U value_bias) {
  const device uint8_t* packed = (const device uint8_t*)values;
  if (bits == 8) {
    for (int j = 0; j < elem_per_thread; ++j) {
      v[j] = value_scale * packed[j] + value_bias;
    }
  } else if (bits == 4) {
    const U lo_scale = value_scale;
    const U hi_scale = value_scale / 16.0f;
    for (int p = 0; p < elem_per_thread / 2; ++p) {
      const uint8_t pair = packed[p];
      v[2 * p] = lo_scale * (pair & 0x0f) + value_bias;
      v[2 * p + 1] = hi_scale * (pair & 0xf0) + value_bias;
    }
  }
}
|  | ||||
// Single-pass quantized scaled-dot-product attention for one query vector per
// head (queries are offset by head index only).
//
// Work split as coded below:
//   - tid.y picks the query head; kv_head_idx = head_idx / gqa_factor picks the
//     key/value head it reads.
//   - Each quadgroup (BD = 4 threads) walks key/value rows quad_gid,
//     quad_gid + BN, ... and each thread of the quad owns
//     elem_per_thread = D / BD consecutive elements of the row.
//   - Per-quad partial results are merged at the end using a running-max
//     rescale (online softmax).
// The final reduction indexes the BN = 32 per-quad slots with simd_lid, so it
// expects 32 quadgroups, i.e. 128 threads per threadgroup.
// NOTE(review): keys and values are both offset with k_stride / group_stride,
// so values are assumed to share the keys' strides - confirm with the host code.
| template <typename T, int D, int group_size, int bits> | ||||
| [[kernel]] void quant_sdpa_vector( | ||||
|     const device T* queries [[buffer(0)]], | ||||
|     const device uint32_t* keys [[buffer(1)]], | ||||
|     const device T* key_scales [[buffer(2)]], | ||||
|     const device T* key_biases [[buffer(3)]], | ||||
|     const device uint32_t* values [[buffer(4)]], | ||||
|     const device T* value_scales [[buffer(5)]], | ||||
|     const device T* value_biases [[buffer(6)]], | ||||
|     device T* out [[buffer(7)]], | ||||
|     const constant int& gqa_factor, | ||||
|     const constant int& N, | ||||
|     const constant size_t& k_stride, | ||||
|     const constant size_t& group_stride, | ||||
|     const constant float& scale, | ||||
|     uint3 tid [[threadgroup_position_in_grid]], | ||||
|     uint simd_gid [[simdgroup_index_in_threadgroup]], | ||||
|     uint simd_lid [[thread_index_in_simdgroup]], | ||||
|     uint quad_gid [[quadgroup_index_in_threadgroup]], | ||||
|     uint quad_lid [[thread_index_in_quadgroup]]) { | ||||
|   constexpr int BN = 32; | ||||
|   constexpr int BD = 4; | ||||
|   constexpr int elem_per_thread = D / BD; | ||||
|   constexpr int pack_factor = 32 / bits; | ||||
|  | ||||
    // Elements between the rows one quad visits on consecutive iterations.
|   const int stride = BN * D; | ||||
|  | ||||
|   typedef float U; | ||||
|  | ||||
|   thread U q[elem_per_thread]; | ||||
|   thread U k[elem_per_thread]; | ||||
|   thread U v[elem_per_thread]; | ||||
|   thread U o[elem_per_thread]; | ||||
|  | ||||
|   threadgroup U outputs[BN * BD]; | ||||
|   threadgroup U max_scores[BN]; | ||||
|   threadgroup U sum_exp_scores[BN]; | ||||
|  | ||||
|   // Adjust positions | ||||
|   const int head_idx = tid.y; | ||||
|   const int kv_head_idx = head_idx / gqa_factor; | ||||
|   queries += head_idx * D + quad_lid * elem_per_thread; | ||||
|  | ||||
    // kv_idx is an element offset inside the head; dividing by pack_factor
    // converts it to uint32 words and dividing by group_size to a group index.
|   const int kv_idx = quad_gid * D + quad_lid * elem_per_thread; | ||||
|   const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor; | ||||
|   const int group_idx = kv_head_idx * group_stride + kv_idx / group_size; | ||||
|   keys += packed_idx; | ||||
|   key_scales += group_idx; | ||||
|   key_biases += group_idx; | ||||
|   values += packed_idx; | ||||
|   value_scales += group_idx; | ||||
|   value_biases += group_idx; | ||||
|  | ||||
    // After the final reduction simdgroup g holds output slice g of the head.
|   out += head_idx * D + simd_gid * elem_per_thread; | ||||
|  | ||||
|   // Read the query and 0 the output accumulator | ||||
    // query_sum (sum of this thread's scaled query elements) is used below to
    // apply the key bias.
|   U query_sum = load_queries<T, U, elem_per_thread, bits>( | ||||
|       queries, q, static_cast<U>(scale)); | ||||
|   for (int i = 0; i < elem_per_thread; i++) { | ||||
|     o[i] = 0; | ||||
|   } | ||||
|  | ||||
|   U max_score = -INFINITY; | ||||
|   U sum_exp_score = 0; | ||||
|  | ||||
|   // For each key | ||||
|   for (int i = quad_gid; i < N; i += BN) { | ||||
|     load_keys<U, elem_per_thread, bits>(keys, k); | ||||
|  | ||||
|     // Assumes this thread's elem_per_thread elements all lie in one quantization group | ||||
      // (only key_scales[0] / key_biases[0] are read).
      // NOTE(review): D % group_size == 0 alone does not guarantee this, e.g.
      // D = 96 with group_size = 32 gives 24 elements per thread.
|     U key_scale = key_scales[0]; | ||||
|     U key_bias = key_biases[0]; | ||||
|  | ||||
|     // Compute the i-th score | ||||
|     U score = 0; | ||||
|     for (int i = 0; i < elem_per_thread; i++) { | ||||
|       score += q[i] * k[i]; | ||||
|     } | ||||
      // Dequantized dot product without dequantizing k:
      // sum(q * (scale * k + bias)) = scale * sum(q * k) + bias * sum(q).
      // quad_sum then adds the four threads' partial results.
|     score = score * key_scale + query_sum * key_bias; | ||||
|     score = quad_sum(score); | ||||
|  | ||||
|     // Update the accumulators (rescale the running sum by exp(old_max - new_max)) | ||||
|     U new_max = max(max_score, score); | ||||
|     U factor = fast::exp(max_score - new_max); | ||||
|     U exp_score = fast::exp(score - new_max); | ||||
|  | ||||
|     max_score = new_max; | ||||
|     sum_exp_score = sum_exp_score * factor + exp_score; | ||||
|  | ||||
|     U value_scale = value_scales[0]; | ||||
|     U value_bias = value_biases[0]; | ||||
|  | ||||
|     // Load the values | ||||
|     load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias); | ||||
|  | ||||
|     // Update the output accumulator | ||||
|     for (int i = 0; i < elem_per_thread; i++) { | ||||
|       o[i] = o[i] * factor + exp_score * v[i]; | ||||
|     } | ||||
|  | ||||
|     // Move the pointers to the next kv | ||||
|     keys += stride / pack_factor; | ||||
|     key_scales += stride / group_size; | ||||
|     key_biases += stride / group_size; | ||||
|     values += stride / pack_factor; | ||||
|     value_scales += stride / group_size; | ||||
|     value_biases += stride / group_size; | ||||
|   } | ||||
|   threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|   // Each thread has a partial part of the output so we need to combine them. | ||||
|  | ||||
|   // First let's communicate the max and sum_exp | ||||
|   // Each quadgroup communicates its max score | ||||
|   if (quad_lid == 0) { | ||||
|     max_scores[quad_gid] = max_score; | ||||
|     sum_exp_scores[quad_gid] = sum_exp_score; | ||||
|   } | ||||
|   threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
    // Lane simd_lid of every simdgroup picks up quad simd_lid's statistics, so
    // one simd reduction covers all BN quads.
|   max_score = max_scores[simd_lid]; | ||||
|   U new_max = simd_max(max_score); | ||||
|   U factor = fast::exp(max_score - new_max); | ||||
|   sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); | ||||
|  | ||||
|   // Now we need to aggregate all the outputs | ||||
|   for (int i = 0; i < elem_per_thread; i++) { | ||||
|     // 128 threads, each holding elem_per_thread values (32 when D = 128) | ||||
      // Written by linear thread index, read back as [quad = simd_lid][lane = simd_gid],
      // so each simdgroup sums one quad lane over all quads using that quad's factor.
      // NOTE(review): relies on thread index == quad_gid * 4 + quad_lid.
|     outputs[simd_gid * BN + simd_lid] = o[i]; | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score; | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|   } | ||||
|  | ||||
|   // And write the output | ||||
|   if (simd_lid == 0) { | ||||
|     for (int i = 0; i < elem_per_thread; i++) { | ||||
|       out[i] = static_cast<T>(o[i]); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T, int D> | ||||
| [[kernel]] void sdpa_vector_2pass_1( | ||||
|     const device T* queries [[buffer(0)]], | ||||
| @@ -290,3 +492,158 @@ template <typename T, int D> | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
// First pass of the two-pass quantized attention for one query vector per head.
// Each threadgroup reduces a strided subset of the key/value rows of one head
// and writes an unnormalized partial output together with its exp-sum and
// running max, to be combined by a later pass.
//
//   - tid.y picks the query head, tid.z the block (the offsets below use
//     blocks = 32 slots per head).
//   - BN = 8 quadgroups of BD = 4 threads each take one row per iteration; a
//     quad visits rows block_idx * BN + quad_gid + m * blocks * BN.
//   - out receives D floats per (head, block); sums / maxs one float each.
// The statistics reduction reads the BN per-quad slots through simd_lid, which
// matches a threadgroup made of a single 32-thread simdgroup.
// NOTE(review): confirm the dispatch size against the host code.
| template <typename T, int D, int group_size, int bits> | ||||
| [[kernel]] void quant_sdpa_vector_2pass_1( | ||||
|     const device T* queries [[buffer(0)]], | ||||
|     const device uint32_t* keys [[buffer(1)]], | ||||
|     const device T* key_scales [[buffer(2)]], | ||||
|     const device T* key_biases [[buffer(3)]], | ||||
|     const device uint32_t* values [[buffer(4)]], | ||||
|     const device T* value_scales [[buffer(5)]], | ||||
|     const device T* value_biases [[buffer(6)]], | ||||
|     device float* out [[buffer(7)]], | ||||
|     device float* sums [[buffer(8)]], | ||||
|     device float* maxs [[buffer(9)]], | ||||
|     const constant int& gqa_factor, | ||||
|     const constant int& N, | ||||
|     const constant size_t& k_stride, | ||||
|     const constant size_t& v_stride, | ||||
|     const constant size_t& k_group_stride, | ||||
|     const constant size_t& v_group_stride, | ||||
|     const constant float& scale, | ||||
|     uint3 tid [[threadgroup_position_in_grid]], | ||||
|     uint simd_gid [[simdgroup_index_in_threadgroup]], | ||||
|     uint simd_lid [[thread_index_in_simdgroup]], | ||||
|     uint quad_gid [[quadgroup_index_in_threadgroup]], | ||||
|     uint quad_lid [[thread_index_in_quadgroup]]) { | ||||
|   constexpr int BN = 8; | ||||
|   constexpr int BD = 4; | ||||
|   constexpr int elem_per_thread = D / BD; | ||||
|   const int stride = BN * D; | ||||
|   constexpr int blocks = 32; | ||||
|   constexpr int pack_factor = 32 / bits; | ||||
|  | ||||
|   typedef float U; | ||||
|  | ||||
|   thread U q[elem_per_thread]; | ||||
|   thread U k[elem_per_thread]; | ||||
|   thread U v[elem_per_thread]; | ||||
|   thread U o[elem_per_thread]; | ||||
|  | ||||
|   threadgroup U outputs[BN * BD]; | ||||
|   threadgroup U max_scores[BN]; | ||||
|   threadgroup U sum_exp_scores[BN]; | ||||
|  | ||||
|   // Adjust positions | ||||
|   const int block_idx = tid.z; | ||||
|   const int head_idx = tid.y; | ||||
|   const int kv_head_idx = head_idx / gqa_factor; | ||||
|   queries += head_idx * D + quad_lid * elem_per_thread; | ||||
|  | ||||
    // kv_idx is an element offset inside the head; dividing by pack_factor
    // converts it to uint32 words and dividing by group_size to a group index.
|   const int kv_idx = | ||||
|       (block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread; | ||||
|   const int packed_idx = kv_idx / pack_factor; | ||||
|   const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size; | ||||
|   const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size; | ||||
|  | ||||
|   keys += kv_head_idx * k_stride + packed_idx; | ||||
|   key_scales += k_group_idx; | ||||
|   key_biases += k_group_idx; | ||||
|   values += kv_head_idx * v_stride + packed_idx; | ||||
|   value_scales += v_group_idx; | ||||
|   value_biases += v_group_idx; | ||||
|  | ||||
|   out += head_idx * blocks * D + block_idx * D + quad_lid * elem_per_thread; | ||||
|   sums += head_idx * blocks + block_idx; | ||||
|   maxs += head_idx * blocks + block_idx; | ||||
|  | ||||
|   // Read the query and 0 the output accumulator | ||||
    // query_sum (sum of this thread's scaled query elements) is used below to
    // apply the key bias.
|   U query_sum = load_queries<T, U, elem_per_thread, bits>( | ||||
|       queries, q, static_cast<U>(scale)); | ||||
|   for (int i = 0; i < elem_per_thread; i++) { | ||||
|     o[i] = 0; | ||||
|   } | ||||
|  | ||||
    // Finite stand-in for "no score seen yet".
|   U max_score = -1e9; | ||||
|   U sum_exp_score = 0; | ||||
|  | ||||
|   // For each key | ||||
|   for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) { | ||||
|     // Read the key | ||||
|     load_keys<U, elem_per_thread, bits>(keys, k); | ||||
|  | ||||
|     // Assumes this thread's elem_per_thread elements all lie in one quantization group | ||||
      // (only key_scales[0] / key_biases[0] are read).
      // NOTE(review): D % group_size == 0 alone does not guarantee this, e.g.
      // D = 96 with group_size = 32 gives 24 elements per thread.
|     U key_scale = key_scales[0]; | ||||
|     U key_bias = key_biases[0]; | ||||
|  | ||||
|     // Compute the i-th score | ||||
|     U score = 0; | ||||
|     for (int i = 0; i < elem_per_thread; i++) { | ||||
|       score += q[i] * k[i]; | ||||
|     } | ||||
      // Dequantized dot product without dequantizing k:
      // sum(q * (scale * k + bias)) = scale * sum(q * k) + bias * sum(q).
      // quad_sum then adds the four threads' partial results.
|     score = score * key_scale + query_sum * key_bias; | ||||
|     score = quad_sum(score); | ||||
|  | ||||
|     // Update the accumulators (rescale the running sum by exp(old_max - new_max)) | ||||
|     U new_max = max(max_score, score); | ||||
|     U factor = fast::exp(max_score - new_max); | ||||
|     U exp_score = fast::exp(score - new_max); | ||||
|  | ||||
|     max_score = new_max; | ||||
|     sum_exp_score = sum_exp_score * factor + exp_score; | ||||
|  | ||||
|     U value_scale = value_scales[0]; | ||||
|     U value_bias = value_biases[0]; | ||||
|     load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias); | ||||
|  | ||||
|     // Update the output accumulator | ||||
|     for (int i = 0; i < elem_per_thread; i++) { | ||||
|       o[i] = o[i] * factor + exp_score * v[i]; | ||||
|     } | ||||
|  | ||||
|     // Move the pointers to the next kv | ||||
|     keys += blocks * stride / pack_factor; | ||||
|     key_scales += blocks * stride / group_size; | ||||
|     key_biases += blocks * stride / group_size; | ||||
|     values += blocks * stride / pack_factor; | ||||
|     value_scales += blocks * stride / group_size; | ||||
|     value_biases += blocks * stride / group_size; | ||||
|   } | ||||
|  | ||||
|   // Each thread has a partial part of the output so we need to combine them. | ||||
|  | ||||
|   // First let's communicate the max and sum_exp | ||||
|   if (quad_lid == 0) { | ||||
|     max_scores[quad_gid] = max_score; | ||||
|     sum_exp_scores[quad_gid] = sum_exp_score; | ||||
|   } | ||||
|   threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
    // Only the first BN lanes carry a quad's statistics; the remaining lanes
    // contribute neutral values (-1e9 for the max, 0 for the sum).
|   max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9; | ||||
|   U new_max = simd_max(max_score); | ||||
|   U factor = fast::exp(max_score - new_max); | ||||
|   sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0; | ||||
|   sum_exp_score = simd_sum(sum_exp_score * factor); | ||||
|  | ||||
|   // Write the sum and new max | ||||
|   if (simd_gid == 0) { | ||||
|     sums[0] = sum_exp_score; | ||||
|     maxs[0] = new_max; | ||||
|   } | ||||
|  | ||||
|   // Now we need to aggregate all the outputs | ||||
|   for (int i = 0; i < elem_per_thread; i++) { | ||||
      // Each thread rescales its partial by its own quad's max, then the four
      // threads of quad 0 sum the BN entries belonging to their quad lane.
|     outputs[quad_lid * BN + quad_gid] = | ||||
|         o[i] * fast::exp(max_scores[quad_gid] - new_max); | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     if (quad_gid == 0) { | ||||
|       U output = outputs[quad_lid * BN]; | ||||
|       for (int j = 1; j < BN; j++) { | ||||
|         output += outputs[quad_lid * BN + j]; | ||||
|       } | ||||
        // NOTE(review): out is device float*, so casting to T first rounds the
        // partial result to T's precision before storing - confirm intended.
|       out[i] = static_cast<T>(output); | ||||
|     } | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|   } | ||||
| } | ||||
|   | ||||
| @@ -242,6 +242,171 @@ void sdpa_vector_2pass( | ||||
|   compute_encoder.dispatch_threadgroups(grid_dims, group_dims); | ||||
| } | ||||
|  | ||||
| void quant_sdpa_vector( | ||||
|     const Stream& s, | ||||
|     metal::Device& d, | ||||
|     const array& q, | ||||
|     const array& k, | ||||
|     const array& k_scales, | ||||
|     const array& k_biases, | ||||
|     const array& v, | ||||
|     const array& v_scales, | ||||
|     const array& v_biases, | ||||
|     array& out, | ||||
|     float scale, | ||||
|     int group_size, | ||||
|     int bits) { | ||||
|   // Set the kernel name | ||||
|   std::string kname; | ||||
|   kname.reserve(96); | ||||
|   kname += "quant_sdpa_vector_"; | ||||
|   kname += get_type_string(q.dtype()); | ||||
|   kname += "_"; | ||||
|   kname += std::to_string(q.shape(-1)); | ||||
|   kname += "_"; | ||||
|   kname += std::to_string(group_size); | ||||
|   kname += "_"; | ||||
|   kname += std::to_string(bits); | ||||
|  | ||||
|   // Compute the necessary sizes | ||||
|   int gqa_factor = q.shape(1) / k.shape(1); | ||||
|   int N = k.shape(2); | ||||
|   int B = q.shape(0) * q.shape(1); | ||||
|   size_t stride = k.strides()[1]; | ||||
|   size_t group_stride = k_scales.strides()[1]; | ||||
|   MTL::Size group_dims(128, 1, 1); | ||||
|   MTL::Size grid_dims(1, B, 1); | ||||
|  | ||||
|   // Get the kernel | ||||
|   auto& compute_encoder = d.get_command_encoder(s.index); | ||||
|   auto kernel = d.get_kernel(kname); | ||||
|   compute_encoder.set_compute_pipeline_state(kernel); | ||||
|  | ||||
|   // Set its arguments | ||||
|   compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0); | ||||
|   compute_encoder.set_input_array(k, 1); | ||||
|   compute_encoder.set_input_array(k_scales, 2); | ||||
|   compute_encoder.set_input_array(k_biases, 3); | ||||
|   compute_encoder.set_input_array(v, 4); | ||||
|   compute_encoder.set_input_array(v_scales, 5); | ||||
|   compute_encoder.set_input_array(v_biases, 6); | ||||
|   compute_encoder.set_output_array(out, 7); | ||||
|   compute_encoder.set_bytes(&gqa_factor, sizeof(int), 8); | ||||
|   compute_encoder.set_bytes(&N, sizeof(int), 9); | ||||
|   compute_encoder.set_bytes(&stride, sizeof(size_t), 10); | ||||
|   compute_encoder.set_bytes(&group_stride, sizeof(size_t), 11); | ||||
|   compute_encoder.set_bytes(&scale, sizeof(float), 12); | ||||
|  | ||||
|   // Launch | ||||
|   compute_encoder.dispatch_threadgroups(grid_dims, group_dims); | ||||
| } | ||||
|  | ||||
| void quant_sdpa_vector_2pass( | ||||
|     const Stream& s, | ||||
|     metal::Device& d, | ||||
|     const array& q, | ||||
|     const array& k, | ||||
|     const array& k_scales, | ||||
|     const array& k_biases, | ||||
|     const array& v, | ||||
|     const array& v_scales, | ||||
|     const array& v_biases, | ||||
|     array& out, | ||||
|     float scale, | ||||
|     int group_size, | ||||
|     int bits) { | ||||
|   // Set the kernel name | ||||
|   std::string kname; | ||||
|   kname.reserve(96); | ||||
|   kname += "quant_sdpa_vector_2pass_1_"; | ||||
|   kname += get_type_string(q.dtype()); | ||||
|   kname += "_"; | ||||
|   kname += std::to_string(q.shape(-1)); | ||||
|   kname += "_"; | ||||
|   kname += std::to_string(group_size); | ||||
|   kname += "_"; | ||||
|   kname += std::to_string(bits); | ||||
|  | ||||
|   // Compute the necessary sizes | ||||
|   int gqa_factor = q.shape(1) / k.shape(1); | ||||
|   int N = k.shape(2); | ||||
|   int blocks = 32; | ||||
|   int B = q.shape(0) * q.shape(1); | ||||
|   size_t k_stride = k.strides()[1]; | ||||
|   size_t v_stride = v.strides()[1]; | ||||
|   size_t k_group_stride = k_scales.strides()[1]; | ||||
|   size_t v_group_stride = v_scales.strides()[1]; | ||||
|   MTL::Size group_dims(8 * 4, 1, 1); | ||||
|   MTL::Size grid_dims(1, B, blocks); | ||||
|  | ||||
|   // Allocate the intermediates | ||||
|   std::vector<int> intermediate_shape; | ||||
|   intermediate_shape.reserve(out.ndim() + 1); | ||||
|   intermediate_shape.insert( | ||||
|       intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1); | ||||
|   intermediate_shape.push_back(blocks); | ||||
|   intermediate_shape.push_back(out.shape().back()); | ||||
|   array intermediate(intermediate_shape, float32, nullptr, {}); | ||||
|   intermediate_shape.pop_back(); | ||||
|   array sums(intermediate_shape, float32, nullptr, {}); | ||||
|   array maxs(std::move(intermediate_shape), float32, nullptr, {}); | ||||
|   intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); | ||||
|   sums.set_data(allocator::malloc_or_wait(sums.nbytes())); | ||||
|   maxs.set_data(allocator::malloc_or_wait(maxs.nbytes())); | ||||
|   d.add_temporary(intermediate, s.index); | ||||
|   d.add_temporary(sums, s.index); | ||||
|   d.add_temporary(maxs, s.index); | ||||
|  | ||||
|   // Get the kernel | ||||
|   auto& compute_encoder = d.get_command_encoder(s.index); | ||||
|   auto kernel = d.get_kernel(kname); | ||||
|   compute_encoder.set_compute_pipeline_state(kernel); | ||||
|  | ||||
|   // Set its arguments | ||||
|   compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0); | ||||
|   compute_encoder.set_input_array(k, 1); | ||||
|   compute_encoder.set_input_array(k_scales, 2); | ||||
|   compute_encoder.set_input_array(k_biases, 3); | ||||
|   compute_encoder.set_input_array(v, 4); | ||||
|   compute_encoder.set_input_array(v_scales, 5); | ||||
|   compute_encoder.set_input_array(v_biases, 6); | ||||
|   compute_encoder.set_output_array(intermediate, 7); | ||||
|   compute_encoder.set_output_array(sums, 8); | ||||
|   compute_encoder.set_output_array(maxs, 9); | ||||
|   compute_encoder.set_bytes(gqa_factor, 10); | ||||
|   compute_encoder.set_bytes(N, 11); | ||||
|   compute_encoder.set_bytes(k_stride, 12); | ||||
|   compute_encoder.set_bytes(v_stride, 13); | ||||
|   compute_encoder.set_bytes(k_group_stride, 14); | ||||
|   compute_encoder.set_bytes(v_group_stride, 15); | ||||
|   compute_encoder.set_bytes(scale, 16); | ||||
|  | ||||
|   // Launch | ||||
|   compute_encoder.dispatch_threadgroups(grid_dims, group_dims); | ||||
|  | ||||
|   // Final pass | ||||
|   kname.clear(); | ||||
|   kname += "sdpa_vector_2pass_2_"; | ||||
|   kname += get_type_string(q.dtype()); | ||||
|   kname += "_"; | ||||
|   kname += std::to_string(q.shape(-1)); | ||||
|  | ||||
|   // Get the kernel | ||||
|   kernel = d.get_kernel(kname); | ||||
|   compute_encoder.set_compute_pipeline_state(kernel); | ||||
|  | ||||
|   // Set its arguments | ||||
|   compute_encoder.set_input_array(intermediate, 0); | ||||
|   compute_encoder.set_input_array(sums, 1); | ||||
|   compute_encoder.set_input_array(maxs, 2); | ||||
|   compute_encoder.set_output_array(out, 3); | ||||
|  | ||||
|   // Launch | ||||
|   group_dims = MTL::Size(1024, 1, 1); | ||||
|   grid_dims = MTL::Size(1, B, 1); | ||||
|   compute_encoder.dispatch_threadgroups(grid_dims, group_dims); | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| void ScaledDotProductAttention::eval_gpu( | ||||
| @@ -254,7 +419,6 @@ void ScaledDotProductAttention::eval_gpu( | ||||
|  | ||||
|   auto& q_pre = inputs[0]; | ||||
|   auto& k_pre = inputs[1]; | ||||
|   auto& v_pre = inputs[2]; | ||||
|   auto& o = out; | ||||
|  | ||||
|   std::vector<array> copies; | ||||
| @@ -295,9 +459,7 @@ void ScaledDotProductAttention::eval_gpu( | ||||
|  | ||||
|   // We are in vector mode ie single query | ||||
|   if (q_pre.shape(2) == 1) { | ||||
|     const auto& q = copy_unless(is_contiguous, q_pre); | ||||
|     const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); | ||||
|     const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); | ||||
|     auto q = copy_unless(is_contiguous, q_pre); | ||||
|  | ||||
|     // Donate the query if possible | ||||
|     if (q.is_donatable()) { | ||||
| @@ -306,20 +468,55 @@ void ScaledDotProductAttention::eval_gpu( | ||||
|       o.set_data(allocator::malloc_or_wait(o.nbytes())); | ||||
|     } | ||||
|  | ||||
|     // We route to the 2 pass fused attention if | ||||
|     // - The device is large and the sequence length long | ||||
|     // - The sequence length is even longer and we have gqa | ||||
|     char devc = d.get_architecture().back(); | ||||
|     if ((devc == 'd' && k.shape(2) >= 1024) || | ||||
|         (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { | ||||
|       sdpa_vector_2pass(s, d, q, k, v, o, scale_); | ||||
|     if (quantized_) { | ||||
|       auto& k_scales_pre = inputs[2]; | ||||
|       auto& k_biases_pre = inputs[3]; | ||||
|       auto& v_pre = inputs[4]; | ||||
|       auto& v_scales_pre = inputs[5]; | ||||
|       auto& v_biases_pre = inputs[6]; | ||||
|  | ||||
|       auto k = copy_unless(is_contiguous_except_seq_len, k_pre); | ||||
|       auto k_scales = copy_unless(is_contiguous_except_seq_len, k_scales_pre); | ||||
|       auto k_biases = copy_unless(is_contiguous_except_seq_len, k_biases_pre); | ||||
|       auto v = copy_unless(is_contiguous_except_seq_len, v_pre); | ||||
|       auto v_scales = copy_unless(is_contiguous_except_seq_len, v_scales_pre); | ||||
|       auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre); | ||||
|  | ||||
|       quant_sdpa_vector_2pass( | ||||
|           s, | ||||
|           d, | ||||
|           q, | ||||
|           k, | ||||
|           k_scales, | ||||
|           k_biases, | ||||
|           v, | ||||
|           v_scales, | ||||
|           v_biases, | ||||
|           o, | ||||
|           scale_, | ||||
|           group_size_, | ||||
|           bits_); | ||||
|     } else { | ||||
|       sdpa_vector(s, d, q, k, v, o, scale_); | ||||
|       auto& k_pre = inputs[1]; | ||||
|       auto& v_pre = inputs[2]; | ||||
|  | ||||
|       const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); | ||||
|       const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); | ||||
|  | ||||
|       char devc = d.get_architecture().back(); | ||||
|       if ((devc == 'd' && k.shape(2) >= 1024) || | ||||
|           (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { | ||||
|         sdpa_vector_2pass(s, d, q, k, v, o, scale_); | ||||
|       } else { | ||||
|         sdpa_vector(s, d, q, k, v, o, scale_); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // Full attention mode | ||||
|   else { | ||||
|     auto& v_pre = inputs[2]; | ||||
|  | ||||
|     const auto& q = copy_unless(is_matrix_contiguous, q_pre); | ||||
|     const auto& k = copy_unless(is_matrix_contiguous, k_pre); | ||||
|     const auto& v = copy_unless(is_matrix_contiguous, v_pre); | ||||
|   | ||||
							
								
								
									
										127
									
								
								mlx/fast.cpp
									
									
									
									
									
								
							
							
						
						
									
										127
									
								
								mlx/fast.cpp
									
									
									
									
									
								
							| @@ -664,7 +664,7 @@ array scaled_dot_product_attention( | ||||
|         std::move(out_shape), | ||||
|         final_type, | ||||
|         std::make_shared<ScaledDotProductAttention>( | ||||
|             stream, fallback, scale, false), | ||||
|             stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/false), | ||||
|         {q, k, v}); | ||||
|   } | ||||
|  | ||||
| @@ -678,7 +678,130 @@ array scaled_dot_product_attention( | ||||
| // Two attention primitives are interchangeable only when every parameter | ||||
| // that affects the computation matches. For the quantized variant this | ||||
| // includes the quantization settings, since the same packed inputs decode | ||||
| // differently under a different group size or bit width. | ||||
| bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { | ||||
|   const ScaledDotProductAttention& a_other = | ||||
|       static_cast<const ScaledDotProductAttention&>(other); | ||||
|   return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_ && | ||||
|       quantized_ == a_other.quantized_ && | ||||
|       group_size_ == a_other.group_size_ && bits_ == a_other.bits_; | ||||
| } | ||||
|  | ||||
| // Scaled dot-product attention over quantized keys and values. | ||||
| // | ||||
| // Computes softmax((scale * Q) @ K^T + mask) @ V, where K and V are passed in | ||||
| // packed quantized form together with their scales and biases. | ||||
| // | ||||
| // A single-query call (queries.shape(2) == 1) without a mask is dispatched to | ||||
| // the ScaledDotProductAttention primitive. Every other case, including any | ||||
| // call with a mask, is evaluated directly by composing quantized_matmul and | ||||
| // softmax. | ||||
| array quantized_scaled_dot_product_attention( | ||||
|     const array& queries, | ||||
|     const array& keys, | ||||
|     const array& key_scales, | ||||
|     const array& key_biases, | ||||
|     const array& values, | ||||
|     const array& value_scales, | ||||
|     const array& value_biases, | ||||
|     const float scale, | ||||
|     const std::optional<array>& mask, | ||||
|     const int group_size, | ||||
|     const int bits, | ||||
|     StreamOrDevice s) { | ||||
|   // `32 / bits` quantized elements are packed per stored integer, so the | ||||
|   // unpacked head dimension is the packed last dimension times that factor. | ||||
|   int el_per_int = 32 / bits; | ||||
|   int out_dim = values.shape(-1) * el_per_int; | ||||
|  | ||||
|   auto n_q_heads = queries.shape(-3); | ||||
|   auto n_kv_heads = keys.shape(-3); | ||||
|  | ||||
|   auto out_shape = std::vector<int>( | ||||
|       {queries.shape(0), queries.shape(1), queries.shape(2), out_dim}); | ||||
|   auto stream = to_stream(s); | ||||
|   bool needs_mask = mask.has_value(); | ||||
|  | ||||
|   // `s` is captured by value: the primitive keeps this closure and may call | ||||
|   // it after this function has returned, at which point a reference to the | ||||
|   // by-value parameter `s` would dangle. | ||||
|   auto fallback = | ||||
|       [scale, needs_mask, n_q_heads, n_kv_heads, group_size, bits, s]( | ||||
|           const std::vector<array>& inputs) -> std::vector<array> { | ||||
|     int n_repeats = n_q_heads / n_kv_heads; | ||||
|  | ||||
|     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); | ||||
|  | ||||
|     auto k = inputs[1]; | ||||
|     auto k_scales = inputs[2]; | ||||
|     auto k_biases = inputs[3]; | ||||
|     auto v = inputs[4]; | ||||
|     auto v_scales = inputs[5]; | ||||
|     auto v_biases = inputs[6]; | ||||
|  | ||||
|     int B = q.shape(0); | ||||
|     int L = q.shape(2); | ||||
|  | ||||
|     // When several query heads share one KV head, split the query heads into | ||||
|     // (n_kv_heads, n_repeats) and give the KV tensors a matching singleton | ||||
|     // axis so the matmuls broadcast over the repeats. | ||||
|     if (n_repeats > 1) { | ||||
|       q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s); | ||||
|       k = expand_dims(k, 2, s); | ||||
|       k_scales = expand_dims(k_scales, 2, s); | ||||
|       k_biases = expand_dims(k_biases, 2, s); | ||||
|       v = expand_dims(v, 2, s); | ||||
|       v_scales = expand_dims(v_scales, 2, s); | ||||
|       v_biases = expand_dims(v_biases, 2, s); | ||||
|     } | ||||
|  | ||||
|     array scores = quantized_matmul( | ||||
|         q, | ||||
|         k, | ||||
|         k_scales, | ||||
|         k_biases, | ||||
|         /*transpose=*/true, | ||||
|         /*group_size=*/group_size, | ||||
|         /*bits=*/bits, | ||||
|         s); | ||||
|     // The mask, when present, is the eighth input and is additive. | ||||
|     if (needs_mask) { | ||||
|       scores = add(scores, inputs[7], s); | ||||
|     } | ||||
|     scores = softmax(scores, std::vector<int>{-1}, true, s); | ||||
|     array out = quantized_matmul( | ||||
|         scores, | ||||
|         v, | ||||
|         v_scales, | ||||
|         v_biases, | ||||
|         /*transpose=*/false, | ||||
|         /*group_size=*/group_size, | ||||
|         /*bits=*/bits, | ||||
|         s); | ||||
|     // Undo the head split so the output has one slice per query head again. | ||||
|     if (n_repeats > 1) { | ||||
|       out = reshape(out, {B, n_q_heads, L, -1}, s); | ||||
|     } | ||||
|     return std::vector<array>{out}; | ||||
|   }; | ||||
|  | ||||
|   std::vector<array> inputs = { | ||||
|       queries, | ||||
|       keys, | ||||
|       key_scales, | ||||
|       key_biases, | ||||
|       values, | ||||
|       value_scales, | ||||
|       value_biases}; | ||||
|   if (needs_mask) { | ||||
|     inputs.push_back(mask.value()); | ||||
|   } | ||||
|  | ||||
|   // The primitive is constructed without a mask input, so any masked call | ||||
|   // must take the fallback; otherwise the mask would be silently ignored and | ||||
|   // the stored fallback would index past the end of its inputs. | ||||
|   int L = queries.shape(2); | ||||
|   if (L > 1 || needs_mask) { | ||||
|     return fallback(inputs)[0]; | ||||
|   } | ||||
|  | ||||
|   return array( | ||||
|       std::move(out_shape), | ||||
|       queries.dtype(), | ||||
|       std::make_shared<ScaledDotProductAttention>( | ||||
|           stream, | ||||
|           fallback, | ||||
|           scale, | ||||
|           /*needs_mask=*/false, | ||||
|           /*quantized=*/true, | ||||
|           group_size, | ||||
|           bits), | ||||
|       std::move(inputs)); | ||||
| } | ||||
|  | ||||
| array pack_and_quantize( | ||||
|   | ||||
							
								
								
									
										15
									
								
								mlx/fast.h
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								mlx/fast.h
									
									
									
									
									
								
							| @@ -41,6 +41,21 @@ array scaled_dot_product_attention( | ||||
|     const std::optional<int> memory_efficient_threshold = std::nullopt, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
| /** | ||||
|  * Computes: `O = softmax((scale * Q) @ K.T + mask) @ V` where K and V are | ||||
|  * quantized. | ||||
|  * | ||||
|  * `keys` and `values` are the quantized tensors; `key_scales` / `key_biases` | ||||
|  * and `value_scales` / `value_biases` are their quantization parameters. | ||||
|  * `group_size` and `bits` are the settings used in the KV quantization. | ||||
|  * `mask`, when provided, is added to the query-key scores before the | ||||
|  * softmax. | ||||
|  **/ | ||||
| array quantized_scaled_dot_product_attention( | ||||
|     const array& queries, | ||||
|     const array& keys, | ||||
|     const array& key_scales, | ||||
|     const array& key_biases, | ||||
|     const array& values, | ||||
|     const array& value_scales, | ||||
|     const array& value_biases, | ||||
|     const float scale, | ||||
|     const std::optional<array>& mask = std::nullopt, | ||||
|     const int group_size = 64, | ||||
|     const int bits = 4, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
| std::tuple<array, array, array> affine_quantize( | ||||
|     const array& w, | ||||
|     int group_size = 64, | ||||
|   | ||||
| @@ -190,8 +190,16 @@ class ScaledDotProductAttention : public Custom { | ||||
|       Stream stream, | ||||
|       std::function<std::vector<array>(std::vector<array>)> fallback, | ||||
|       const float scale, | ||||
|       const bool needs_mask) | ||||
|       : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {} | ||||
|       const bool needs_mask, | ||||
|       const bool quantized, | ||||
|       const int group_size = 64, | ||||
|       const int bits = 4) | ||||
|       : Custom(stream, fallback), | ||||
|         scale_(scale), | ||||
|         needs_mask_(needs_mask), | ||||
|         quantized_(quantized), | ||||
|         group_size_(group_size), | ||||
|         bits_(bits) {} | ||||
|  | ||||
|   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) | ||||
|       override { | ||||
| @@ -212,6 +220,9 @@ class ScaledDotProductAttention : public Custom { | ||||
|   std::function<std::vector<array>(std::vector<array>)> fallback_; | ||||
|   float scale_; | ||||
|   bool needs_mask_; | ||||
|   bool quantized_; | ||||
|   int group_size_; | ||||
|   int bits_; | ||||
| }; | ||||
|  | ||||
| class AffineQuantize : public Custom { | ||||
|   | ||||
| @@ -161,6 +161,45 @@ void init_fast(nb::module_& parent_module) { | ||||
|             array: The output array. | ||||
|       )pbdoc"); | ||||
|  | ||||
|   // Python binding for fast::quantized_scaled_dot_product_attention. All | ||||
|   // arguments after the seven tensors are keyword-only; the explicit | ||||
|   // signature below must list every one of them so generated help and stubs | ||||
|   // match what the binding accepts. | ||||
|   m.def( | ||||
|       "quantized_scaled_dot_product_attention", | ||||
|       &fast::quantized_scaled_dot_product_attention, | ||||
|       "q"_a, | ||||
|       "k"_a, | ||||
|       "k_scales"_a, | ||||
|       "k_biases"_a, | ||||
|       "v"_a, | ||||
|       "v_scales"_a, | ||||
|       "v_biases"_a, | ||||
|       nb::kw_only(), | ||||
|       "scale"_a, | ||||
|       "mask"_a = nb::none(), | ||||
|       "group_size"_a = 64, | ||||
|       "bits"_a = 4, | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, group_size: int = 64, bits: int = 4, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|       R"pbdoc( | ||||
|         A fast implementation of multi-head attention where the keys and values are quantized. | ||||
|  | ||||
|         see :func:`scaled_dot_product_attention` for more details. | ||||
|  | ||||
|         Args: | ||||
|             q (array): Input query array. | ||||
|             k (array): Input keys array. | ||||
|             k_scales (array): Scales for the quantized keys array. | ||||
|             k_biases (array): Biases for the quantized keys array. | ||||
|             v (array): Input values array. | ||||
|             v_scales (array): Scales for the quantized values array. | ||||
|             v_biases (array): Biases for the quantized values array. | ||||
|             scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``) | ||||
|             mask (array, optional): An additive mask to apply to the query-key scores. | ||||
|             group_size (int): The group size used in the KV quantization. | ||||
|             bits (int): The bits used in the KV quantization. | ||||
|         Returns: | ||||
|             array: The output array. | ||||
|       )pbdoc"); | ||||
|  | ||||
|   m.def( | ||||
|       "metal_kernel", | ||||
|       [](const std::string& name, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron