mirror of https://github.com/ml-explore/mlx.git
synced 2025-11-04 10:38:10 +08:00

Commit: working qsdpa
@@ -1,58 +1,94 @@
-import argparse
-import math
 import mlx.core as mx
+import numpy as np
+from mlx.utils import tree_map
 from time_utils import time_fn
 
-L = 16384
+L = 32768
 H = 32
 H_k = H // 4
 D = 128
 dtype = mx.float16
-loops = 10
+bits = 8
+
+loops = 20
 
 
 def attention(q, k, v):
-    def _sdpa(q, k, v):
+    for _ in range(loops):
         B, Hq, L, D = q.shape
         _, Hk, S, _ = k.shape
         q = q.reshape(B, Hk, Hq // Hk, L, D)
-        k = k[:, :, None, :, :]
-        v = v[:, :, None, :, :]
-        s = q @ k.transpose(0, 1, 2, 4, 3)
+        ke = k[:, :, None, :, :]
+        ve = v[:, :, None, :, :]
+        s = q @ ke.transpose(0, 1, 2, 4, 3)
         p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
-        o = p @ v
-        return o.reshape(B, Hq, L, D)
-
-    for i in range(loops):
-        q = _sdpa(q, k, v)
+        q = p @ ve
+        q = q.reshape(B, Hq, L, D)
     return q
 
 
 def sdpa(q, k, v):
-    for i in range(loops):
-        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
+    for _ in range(loops):
+        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
     return q
 
 
-def time_self_attention_primitives():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
-    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    mx.eval(q, k, v)
+def quant_sdpa(q, k, v, bits=4):
+    for _ in range(loops):
+        q = mx.fast.quantized_scaled_dot_product_attention(
+            q, *k, *v, scale=1.0, mask=None, bits=bits
+        )
+    return q
+
+
+def quant_attention(q, k, v, bits=4):
+    for _ in range(loops):
+        B, Hq, L, D = q.shape
+        Hk = k[0].shape[1]
+
+        q = q.reshape((B, Hk, Hq // Hk, L, D))
+        ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
+        ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
+
+        scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits)
+        scores = mx.softmax(scores, axis=-1)
+
+        q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits)
+        q = q.reshape((B, Hq, L, D))
+    return q
+
+
+def time_self_attention_primitives(q, k, v):
     time_fn(attention, q, k, v)
 
 
-def time_self_attention_sdpa():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
-    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    mx.eval(q, k, v)
+def time_self_attention_sdpa(q, k, v):
     time_fn(sdpa, q, k, v)
 
 
+def time_self_attention_quant_sdpa(q, k, v, bits=4):
+    time_fn(quant_sdpa, q, k, v, bits)
+
+
+def time_self_attention_quant_primitives(q, k, v, bits=4):
+    time_fn(quant_attention, q, k, v, bits)
+
+
 if __name__ == "__main__":
-    time_self_attention_sdpa()
-    time_self_attention_primitives()
+    mx.random.seed(3)
+    q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype)
+    k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
+    v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
+    mx.eval(q, k, v)
+
+    k_quant = mx.quantize(k, bits=bits)
+    v_quant = mx.quantize(v, bits=bits)
+    mx.eval(k_quant, v_quant)
+
+    k = mx.dequantize(*k_quant, bits=bits)
+    v = mx.dequantize(*v_quant, bits=bits)
+
+    time_self_attention_sdpa(q, k, v)
+    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
+    time_self_attention_primitives(q, k, v)
+    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
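A note on the calling convention above: mx.quantize returns a (packed, scales, biases) triple, which is why quant_sdpa and quant_attention splat their k and v arguments with *k / *v. A minimal Python sketch (illustrative shapes, not part of the diff):

import mlx.core as mx

k = mx.random.uniform(shape=(1, 8, 1024, 128), dtype=mx.float16)
k_quant = mx.quantize(k, bits=8)              # (k_packed, k_scales, k_biases)
k_packed, k_scales, k_biases = k_quant

# quant_sdpa(q, k_quant, v_quant) above splats the triple with *k, so the
# fast op receives k_packed, k_scales, k_biases as three separate arrays.
k_restored = mx.dequantize(*k_quant, bits=8)  # lossy round trip
print(mx.allclose(k, k_restored, atol=1e-1))  # True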
@@ -20,4 +20,33 @@ using namespace metal;
 instantiate_sdpa_vector_heads(float)
 instantiate_sdpa_vector_heads(bfloat16_t)
 instantiate_sdpa_vector_heads(float16_t)
+
+// Quantized SDPA vector instantiations
+#define instantiate_quant_sdpa_vector(name, type, head_dim, group_size, bits) \
+  instantiate_kernel(                                                         \
+      #name "_" #type "_" #head_dim "_" #group_size "_" #bits,                \
+      name, type, head_dim, group_size, bits)
+
+#define instantiate_quant_sdpa_vector_passes(type, heads, group_size, bits)       \
+  instantiate_quant_sdpa_vector(quant_sdpa_vector, type, heads, group_size, bits) \
+  instantiate_quant_sdpa_vector(quant_sdpa_vector_2pass_1, type, heads, group_size, bits)
+
+#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
+  instantiate_quant_sdpa_vector_passes(type, heads, group_size, 4)  \
+  instantiate_quant_sdpa_vector_passes(type, heads, group_size, 8)
+
+#define instantiate_quant_sdpa_vector_group_size(type, heads) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 32)         \
+  instantiate_quant_sdpa_vector_bits(type, heads, 64)         \
+  instantiate_quant_sdpa_vector_bits(type, heads, 128)
+
+#define instantiate_quant_sdpa_vector_heads(type)     \
+  instantiate_quant_sdpa_vector_group_size(type, 64)  \
+  instantiate_quant_sdpa_vector_group_size(type, 96)  \
+  instantiate_quant_sdpa_vector_group_size(type, 128)
+
+instantiate_quant_sdpa_vector_heads(float)
+instantiate_quant_sdpa_vector_heads(bfloat16_t)
+instantiate_quant_sdpa_vector_heads(float16_t)
     // clang-format on
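For reference, the macros above stringize their arguments into kernel names of the form <name>_<type>_<head_dim>_<group_size>_<bits>, and the host code later rebuilds the same string to fetch a pipeline. A small Python sketch of the naming scheme (illustrative, not part of the diff):

def kernel_name(name, type_str, head_dim, group_size, bits):
    # Mirrors: #name "_" #type "_" #head_dim "_" #group_size "_" #bits
    return f"{name}_{type_str}_{head_dim}_{group_size}_{bits}"

print(kernel_name("quant_sdpa_vector", "float16_t", 128, 64, 4))
# quant_sdpa_vector_float16_t_128_64_4
print(kernel_name("quant_sdpa_vector_2pass_1", "bfloat16_t", 96, 32, 8))
# quant_sdpa_vector_2pass_1_bfloat16_t_96_32_8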
@@ -113,6 +113,208 @@ template <typename T, int D>
   }
 }
 
+template <typename T, typename U, int elem_per_thread, int bits>
+METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
+  U query_sum = 0;
+  if (bits == 4) {
+    for (int i = 0; i < elem_per_thread; i += 4) {
+      q[i] = scale * queries[i];
+      q[i + 1] = scale * queries[i + 1];
+      q[i + 2] = scale * queries[i + 2];
+      q[i + 3] = scale * queries[i + 3];
+      query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
+      q[i + 1] /= 16.0f;
+      q[i + 2] /= 256.0f;
+      q[i + 3] /= 4096.0f;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      q[i] = scale * queries[i];
+      query_sum += q[i];
+    }
+  }
+  return query_sum;
+}
+
+template <typename U, int elem_per_thread, int bits>
+METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
+  if (bits == 4) {
+    auto ks = (const device uint16_t*)keys;
+    for (int i = 0; i < elem_per_thread / 4; i++) {
+      k[4 * i] = ks[i] & 0x000f;
+      k[4 * i + 1] = ks[i] & 0x00f0;
+      k[4 * i + 2] = ks[i] & 0x0f00;
+      k[4 * i + 3] = ks[i] & 0xf000;
+    }
+  } else if (bits == 8) {
+    auto ks = (const device uint8_t*)keys;
+    for (int i = 0; i < elem_per_thread; i++) {
+      k[i] = ks[i];
+    }
+  }
+}
+
+template <typename U, int elem_per_thread, int bits>
+METAL_FUNC void load_values(
+    const device uint32_t* values,
+    thread U* v,
+    U value_scale,
+    U value_bias) {
+  auto vs = (const device uint8_t*)values;
+  if (bits == 4) {
+    U s[2] = {value_scale, value_scale / 16.0f};
+    for (int i = 0; i < elem_per_thread / 2; i++) {
+      v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
+      v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      v[i] = value_scale * vs[i] + value_bias;
+    }
+  }
+}
+
+template <typename T, int D, int group_size, int bits>
+[[kernel]] void quant_sdpa_vector(
+    const device T* queries [[buffer(0)]],
+    const device uint32_t* keys [[buffer(1)]],
+    const device T* key_scales [[buffer(2)]],
+    const device T* key_biases [[buffer(3)]],
+    const device uint32_t* values [[buffer(4)]],
+    const device T* value_scales [[buffer(5)]],
+    const device T* value_biases [[buffer(6)]],
+    device T* out [[buffer(7)]],
+    const constant int& gqa_factor,
+    const constant int& N,
+    const constant size_t& k_stride,
+    const constant size_t& group_stride,
+    const constant float& scale,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]],
+    uint quad_gid [[quadgroup_index_in_threadgroup]],
+    uint quad_lid [[thread_index_in_quadgroup]]) {
+  constexpr int BN = 32;
+  constexpr int BD = 4;
+  constexpr int elem_per_thread = D / BD;
+  constexpr int pack_factor = 32 / bits;
+
+  const int stride = BN * D;
+
+  typedef float U;
+
+  thread U q[elem_per_thread];
+  thread U k[elem_per_thread];
+  thread U v[elem_per_thread];
+  thread U o[elem_per_thread];
+
+  threadgroup U outputs[BN * BD];
+  threadgroup U max_scores[BN];
+  threadgroup U sum_exp_scores[BN];
+
+  // Adjust positions
+  const int head_idx = tid.y;
+  const int kv_head_idx = head_idx / gqa_factor;
+  queries += head_idx * D + quad_lid * elem_per_thread;
+
+  const int kv_idx = quad_gid * D + quad_lid * elem_per_thread;
+  const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor;
+  const int group_idx = kv_head_idx * group_stride + kv_idx / group_size;
+  keys += packed_idx;
+  key_scales += group_idx;
+  key_biases += group_idx;
+  values += packed_idx;
+  value_scales += group_idx;
+  value_biases += group_idx;
+
+  out += head_idx * D + simd_gid * elem_per_thread;
+
+  // Read the query and 0 the output accumulator
+  U query_sum = load_queries<T, U, elem_per_thread, bits>(
+      queries, q, static_cast<U>(scale));
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = 0;
+  }
+
+  U max_score = -INFINITY;
+  U sum_exp_score = 0;
+
+  // For each key
+  for (int i = quad_gid; i < N; i += BN) {
+    load_keys<U, elem_per_thread, bits>(keys, k);
+
+    // Assume D % group_size == 0 so all the keys are in the same group
+    U key_scale = key_scales[0];
+    U key_bias = key_biases[0];
+
+    // Compute the i-th score
+    U score = 0;
+    for (int i = 0; i < elem_per_thread; i++) {
+      score += q[i] * k[i];
+    }
+    score = score * key_scale + query_sum * key_bias;
+    score = quad_sum(score);
+
+    // Update the accumulators
+    U new_max = max(max_score, score);
+    U factor = fast::exp(max_score - new_max);
+    U exp_score = fast::exp(score - new_max);
+
+    max_score = new_max;
+    sum_exp_score = sum_exp_score * factor + exp_score;
+
+    U value_scale = value_scales[0];
+    U value_bias = value_biases[0];
+
+    // Load the values
+    load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
+
+    // Update the output accumulator
+    for (int i = 0; i < elem_per_thread; i++) {
+      o[i] = o[i] * factor + exp_score * v[i];
+    }
+
+    // Move the pointers to the next kv
+    keys += stride / pack_factor;
+    key_scales += stride / group_size;
+    key_biases += stride / group_size;
+    values += stride / pack_factor;
+    value_scales += stride / group_size;
+    value_biases += stride / group_size;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Each thread has a partial output so we need to combine them.
+
+  // First let's communicate the max and sum_exp
+  // Each quadgroup communicates its max score
+  if (quad_lid == 0) {
+    max_scores[quad_gid] = max_score;
+    sum_exp_scores[quad_gid] = sum_exp_score;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  max_score = max_scores[simd_lid];
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
+
+  // Now we need to aggregate all the outputs
+  for (int i = 0; i < elem_per_thread; i++) {
+    // 128 threads with 32 values per thread
+    outputs[simd_gid * BN + simd_lid] = o[i];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // And write the output
+  if (simd_lid == 0) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      out[i] = static_cast<T>(o[i]);
+    }
+  }
+}
+
 template <typename T, int D>
 [[kernel]] void sdpa_vector_2pass_1(
     const device T* queries [[buffer(0)]],
@@ -290,3 +492,158 @@ template <typename T, int D>
     }
   }
 }
+
+template <typename T, int D, int group_size, int bits>
+[[kernel]] void quant_sdpa_vector_2pass_1(
+    const device T* queries [[buffer(0)]],
+    const device uint32_t* keys [[buffer(1)]],
+    const device T* key_scales [[buffer(2)]],
+    const device T* key_biases [[buffer(3)]],
+    const device uint32_t* values [[buffer(4)]],
+    const device T* value_scales [[buffer(5)]],
+    const device T* value_biases [[buffer(6)]],
+    device float* out [[buffer(7)]],
+    device float* sums [[buffer(8)]],
+    device float* maxs [[buffer(9)]],
+    const constant int& gqa_factor,
+    const constant int& N,
+    const constant size_t& k_stride,
+    const constant size_t& v_stride,
+    const constant size_t& k_group_stride,
+    const constant size_t& v_group_stride,
+    const constant float& scale,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]],
+    uint quad_gid [[quadgroup_index_in_threadgroup]],
+    uint quad_lid [[thread_index_in_quadgroup]]) {
+  constexpr int BN = 8;
+  constexpr int BD = 4;
+  constexpr int elem_per_thread = D / BD;
+  const int stride = BN * D;
+  constexpr int blocks = 32;
+  constexpr int pack_factor = 32 / bits;
+
+  typedef float U;
+
+  thread U q[elem_per_thread];
+  thread U k[elem_per_thread];
+  thread U v[elem_per_thread];
+  thread U o[elem_per_thread];
+
+  threadgroup U outputs[BN * BD];
+  threadgroup U max_scores[BN];
+  threadgroup U sum_exp_scores[BN];
+
+  // Adjust positions
+  const int block_idx = tid.z;
+  const int head_idx = tid.y;
+  const int kv_head_idx = head_idx / gqa_factor;
+  queries += head_idx * D + quad_lid * elem_per_thread;
+
+  const int kv_idx =
+      (block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread;
+  const int packed_idx = kv_idx / pack_factor;
+  const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size;
+  const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size;
+
+  keys += kv_head_idx * k_stride + packed_idx;
+  key_scales += k_group_idx;
+  key_biases += k_group_idx;
+  values += kv_head_idx * v_stride + packed_idx;
+  value_scales += v_group_idx;
+  value_biases += v_group_idx;
+
+  out += head_idx * blocks * D + block_idx * D + quad_lid * elem_per_thread;
+  sums += head_idx * blocks + block_idx;
+  maxs += head_idx * blocks + block_idx;
+
+  // Read the query and 0 the output accumulator
+  U query_sum = load_queries<T, U, elem_per_thread, bits>(
+      queries, q, static_cast<U>(scale));
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = 0;
+  }
+
+  U max_score = -1e9;
+  U sum_exp_score = 0;
+
+  // For each key
+  for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) {
+    // Read the key
+    load_keys<U, elem_per_thread, bits>(keys, k);
+
+    // Assume D % group_size == 0 so all the keys are in the same group
+    U key_scale = key_scales[0];
+    U key_bias = key_biases[0];
+
+    // Compute the i-th score
+    U score = 0;
+    for (int i = 0; i < elem_per_thread; i++) {
+      score += q[i] * k[i];
+    }
+    score = score * key_scale + query_sum * key_bias;
+    score = quad_sum(score);
+
+    // Update the accumulators
+    U new_max = max(max_score, score);
+    U factor = fast::exp(max_score - new_max);
+    U exp_score = fast::exp(score - new_max);
+
+    max_score = new_max;
+    sum_exp_score = sum_exp_score * factor + exp_score;
+
+    U value_scale = value_scales[0];
+    U value_bias = value_biases[0];
+    load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
+
+    // Update the output accumulator
+    for (int i = 0; i < elem_per_thread; i++) {
+      o[i] = o[i] * factor + exp_score * v[i];
+    }
+
+    // Move the pointers to the next kv
+    keys += blocks * stride / pack_factor;
+    key_scales += blocks * stride / group_size;
+    key_biases += blocks * stride / group_size;
+    values += blocks * stride / pack_factor;
+    value_scales += blocks * stride / group_size;
+    value_biases += blocks * stride / group_size;
+  }
+
+  // Each thread has a partial output so we need to combine them.
+
+  // First let's communicate the max and sum_exp
+  if (quad_lid == 0) {
+    max_scores[quad_gid] = max_score;
+    sum_exp_scores[quad_gid] = sum_exp_score;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
+  sum_exp_score = simd_sum(sum_exp_score * factor);
+
+  // Write the sum and new max
+  if (simd_gid == 0) {
+    sums[0] = sum_exp_score;
+    maxs[0] = new_max;
+  }
+
+  // Now we need to aggregate all the outputs
+  for (int i = 0; i < elem_per_thread; i++) {
+    outputs[quad_lid * BN + quad_gid] =
+        o[i] * fast::exp(max_scores[quad_gid] - new_max);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (quad_gid == 0) {
+      U output = outputs[quad_lid * BN];
+      for (int j = 1; j < BN; j++) {
+        output += outputs[quad_lid * BN + j];
+      }
+      out[i] = static_cast<T>(output);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+}
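The 2-pass kernel above writes, per block, an unnormalized partial output together with its running max and sum of exponentials; the second pass merges the partials by rescaling each with exp(block_max - global_max). A NumPy sketch of that merge rule (illustrative shapes, not part of the diff):

import numpy as np

def merge(blocks):
    # blocks: list of (block_max, block_sum_exp, unnormalized block output)
    new_max = max(m for m, _, _ in blocks)
    sum_exp = sum(s * np.exp(m - new_max) for m, s, _ in blocks)
    out = sum(o * np.exp(m - new_max) for m, _, o in blocks)
    return out / sum_exp

def partial(s, v):
    m = s.max()
    e = np.exp(s - m)
    return m, e.sum(), e @ v   # unnormalized partial output

scores = np.random.default_rng(1).standard_normal(64)
values = np.random.default_rng(2).standard_normal((64, 4))
blocks = [partial(scores[i:i + 16], values[i:i + 16]) for i in range(0, 64, 16)]

w = np.exp(scores - scores.max())
reference = (w / w.sum()) @ values
assert np.allclose(merge(blocks), reference)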
@@ -242,6 +242,171 @@ void sdpa_vector_2pass(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
+void quant_sdpa_vector(
+    const Stream& s,
+    metal::Device& d,
+    const array& q,
+    const array& k,
+    const array& k_scales,
+    const array& k_biases,
+    const array& v,
+    const array& v_scales,
+    const array& v_biases,
+    array& out,
+    float scale,
+    int group_size,
+    int bits) {
+  // Set the kernel name
+  std::string kname;
+  kname.reserve(96);
+  kname += "quant_sdpa_vector_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(group_size);
+  kname += "_";
+  kname += std::to_string(bits);
+
+  // Compute the necessary sizes
+  int gqa_factor = q.shape(1) / k.shape(1);
+  int N = k.shape(2);
+  int B = q.shape(0) * q.shape(1);
+  size_t stride = k.strides()[1];
+  size_t group_stride = k_scales.strides()[1];
+  MTL::Size group_dims(128, 1, 1);
+  MTL::Size grid_dims(1, B, 1);
+
+  // Get the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Set its arguments
+  compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
+  compute_encoder.set_input_array(k, 1);
+  compute_encoder.set_input_array(k_scales, 2);
+  compute_encoder.set_input_array(k_biases, 3);
+  compute_encoder.set_input_array(v, 4);
+  compute_encoder.set_input_array(v_scales, 5);
+  compute_encoder.set_input_array(v_biases, 6);
+  compute_encoder.set_output_array(out, 7);
+  compute_encoder.set_bytes(&gqa_factor, sizeof(int), 8);
+  compute_encoder.set_bytes(&N, sizeof(int), 9);
+  compute_encoder.set_bytes(&stride, sizeof(size_t), 10);
+  compute_encoder.set_bytes(&group_stride, sizeof(size_t), 11);
+  compute_encoder.set_bytes(&scale, sizeof(float), 12);
+
+  // Launch
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
+void quant_sdpa_vector_2pass(
+    const Stream& s,
+    metal::Device& d,
+    const array& q,
+    const array& k,
+    const array& k_scales,
+    const array& k_biases,
+    const array& v,
+    const array& v_scales,
+    const array& v_biases,
+    array& out,
+    float scale,
+    int group_size,
+    int bits) {
+  // Set the kernel name
+  std::string kname;
+  kname.reserve(96);
+  kname += "quant_sdpa_vector_2pass_1_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(group_size);
+  kname += "_";
+  kname += std::to_string(bits);
+
+  // Compute the necessary sizes
+  int gqa_factor = q.shape(1) / k.shape(1);
+  int N = k.shape(2);
+  int blocks = 32;
+  int B = q.shape(0) * q.shape(1);
+  size_t k_stride = k.strides()[1];
+  size_t v_stride = v.strides()[1];
+  size_t k_group_stride = k_scales.strides()[1];
+  size_t v_group_stride = v_scales.strides()[1];
+  MTL::Size group_dims(8 * 4, 1, 1);
+  MTL::Size grid_dims(1, B, blocks);
+
+  // Allocate the intermediates
+  std::vector<int> intermediate_shape;
+  intermediate_shape.reserve(out.ndim() + 1);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
+  intermediate_shape.push_back(blocks);
+  intermediate_shape.push_back(out.shape().back());
+  array intermediate(intermediate_shape, float32, nullptr, {});
+  intermediate_shape.pop_back();
+  array sums(intermediate_shape, float32, nullptr, {});
+  array maxs(std::move(intermediate_shape), float32, nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
+  maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
+  d.add_temporary(intermediate, s.index);
+  d.add_temporary(sums, s.index);
+  d.add_temporary(maxs, s.index);
+
+  // Get the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Set its arguments
+  compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
+  compute_encoder.set_input_array(k, 1);
+  compute_encoder.set_input_array(k_scales, 2);
+  compute_encoder.set_input_array(k_biases, 3);
+  compute_encoder.set_input_array(v, 4);
+  compute_encoder.set_input_array(v_scales, 5);
+  compute_encoder.set_input_array(v_biases, 6);
+  compute_encoder.set_output_array(intermediate, 7);
+  compute_encoder.set_output_array(sums, 8);
+  compute_encoder.set_output_array(maxs, 9);
+  compute_encoder.set_bytes(gqa_factor, 10);
+  compute_encoder.set_bytes(N, 11);
+  compute_encoder.set_bytes(k_stride, 12);
+  compute_encoder.set_bytes(v_stride, 13);
+  compute_encoder.set_bytes(k_group_stride, 14);
+  compute_encoder.set_bytes(v_group_stride, 15);
+  compute_encoder.set_bytes(scale, 16);
+
+  // Launch
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+
+  // Final pass
+  kname.clear();
+  kname += "sdpa_vector_2pass_2_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));
+
+  // Get the kernel
+  kernel = d.get_kernel(kname);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Set its arguments
+  compute_encoder.set_input_array(intermediate, 0);
+  compute_encoder.set_input_array(sums, 1);
+  compute_encoder.set_input_array(maxs, 2);
+  compute_encoder.set_output_array(out, 3);
+
+  // Launch
+  group_dims = MTL::Size(1024, 1, 1);
+  grid_dims = MTL::Size(1, B, 1);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
 } // namespace
 
 void ScaledDotProductAttention::eval_gpu(

@@ -254,7 +419,6 @@ void ScaledDotProductAttention::eval_gpu(
 
   auto& q_pre = inputs[0];
   auto& k_pre = inputs[1];
-  auto& v_pre = inputs[2];
   auto& o = out;
 
   std::vector<array> copies;

@@ -295,9 +459,7 @@ void ScaledDotProductAttention::eval_gpu(
 
   // We are in vector mode ie single query
   if (q_pre.shape(2) == 1) {
-    const auto& q = copy_unless(is_contiguous, q_pre);
-    const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
-    const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
+    auto q = copy_unless(is_contiguous, q_pre);
 
     // Donate the query if possible
     if (q.is_donatable()) {

@@ -306,20 +468,55 @@ void ScaledDotProductAttention::eval_gpu(
       o.set_data(allocator::malloc_or_wait(o.nbytes()));
     }
 
-    // We route to the 2 pass fused attention if
-    // - The device is large and the sequence length long
-    // - The sequence length is even longer and we have gqa
-    char devc = d.get_architecture().back();
-    if ((devc == 'd' && k.shape(2) >= 1024) ||
-        (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
-      sdpa_vector_2pass(s, d, q, k, v, o, scale_);
+    if (quantized_) {
+      auto& k_scales_pre = inputs[2];
+      auto& k_biases_pre = inputs[3];
+      auto& v_pre = inputs[4];
+      auto& v_scales_pre = inputs[5];
+      auto& v_biases_pre = inputs[6];
+
+      auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
+      auto k_scales = copy_unless(is_contiguous_except_seq_len, k_scales_pre);
+      auto k_biases = copy_unless(is_contiguous_except_seq_len, k_biases_pre);
+      auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
+      auto v_scales = copy_unless(is_contiguous_except_seq_len, v_scales_pre);
+      auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre);
+
+      quant_sdpa_vector_2pass(
+          s,
+          d,
+          q,
+          k,
+          k_scales,
+          k_biases,
+          v,
+          v_scales,
+          v_biases,
+          o,
+          scale_,
+          group_size_,
+          bits_);
     } else {
-      sdpa_vector(s, d, q, k, v, o, scale_);
+      auto& k_pre = inputs[1];
+      auto& v_pre = inputs[2];
+
+      const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
+      const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
+
+      char devc = d.get_architecture().back();
+      if ((devc == 'd' && k.shape(2) >= 1024) ||
+          (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
+        sdpa_vector_2pass(s, d, q, k, v, o, scale_);
+      } else {
+        sdpa_vector(s, d, q, k, v, o, scale_);
+      }
     }
   }
 
   // Full attention mode
   else {
+    auto& v_pre = inputs[2];
+
     const auto& q = copy_unless(is_matrix_contiguous, q_pre);
     const auto& k = copy_unless(is_matrix_contiguous, k_pre);
     const auto& v = copy_unless(is_matrix_contiguous, v_pre);
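For orientation, the launch geometry and intermediate buffers set up in quant_sdpa_vector_2pass above work out as follows for a single-query decode step (a Python sketch that mirrors the code; the concrete values are illustrative):

B, n_heads, head_dim, blocks = 1, 32, 128, 32

threads_per_group = 8 * 4                  # group_dims(8 * 4, 1, 1): BN=8 quads x 4 lanes
grid = (1, B * n_heads, blocks)            # grid_dims(1, B, blocks); B folds batch * heads
intermediate = (B, n_heads, 1, blocks, head_dim)   # float32 partial outputs
sums = maxs = (B, n_heads, 1, blocks)              # per-block softmax statistics
print(grid, threads_per_group, intermediate)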
							
								
								
									
mlx/fast.cpp | 127

@@ -664,7 +664,7 @@ array scaled_dot_product_attention(
         std::move(out_shape),
         final_type,
         std::make_shared<ScaledDotProductAttention>(
-            stream, fallback, scale, false),
+            stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/false),
         {q, k, v});
   }

@@ -678,7 +678,130 @@ array scaled_dot_product_attention(
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   const ScaledDotProductAttention& a_other =
       static_cast<const ScaledDotProductAttention&>(other);
-  return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
+  return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_ &&
+      quantized_ == a_other.quantized_;
+}
+
+array quantized_scaled_dot_product_attention(
+    const array& queries,
+    const array& keys,
+    const array& key_scales,
+    const array& key_biases,
+    const array& values,
+    const array& value_scales,
+    const array& value_biases,
+    const float scale,
+    const std::optional<array>& mask,
+    const int group_size,
+    const int bits,
+    StreamOrDevice s) {
+  int el_per_int = 32 / bits;
+  int out_dim = values.shape(-1) * el_per_int;
+
+  auto n_q_heads = queries.shape(-3);
+  auto n_kv_heads = keys.shape(-3);
+
+  auto out_shape = std::vector<int>(
+      {queries.shape(0), queries.shape(1), queries.shape(2), out_dim});
+  auto stream = to_stream(s);
+  bool needs_mask = mask.has_value();
+  auto fallback =
+      [scale, needs_mask, n_q_heads, n_kv_heads, group_size, bits, &s](
+          const std::vector<array>& inputs) -> std::vector<array> {
+    int n_repeats = n_q_heads / n_kv_heads;
+
+    auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
+
+    auto k = inputs[1];
+    auto k_scales = inputs[2];
+    auto k_biases = inputs[3];
+    auto v = inputs[4];
+    auto v_scales = inputs[5];
+    auto v_biases = inputs[6];
+
+    int B = q.shape(0);
+    int L = q.shape(2);
+
+    if (n_repeats > 1) {
+      q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
+      k = expand_dims(k, 2, s);
+      k_scales = expand_dims(k_scales, 2, s);
+      k_biases = expand_dims(k_biases, 2, s);
+      v = expand_dims(v, 2, s);
+      v_scales = expand_dims(v_scales, 2, s);
+      v_biases = expand_dims(v_biases, 2, s);
+    }
+
+    array scores = quantized_matmul(
+        q,
+        k,
+        k_scales,
+        k_biases,
+        /*transpose=*/true,
+        /*group_size=*/group_size,
+        /*bits=*/bits,
+        s);
+    if (needs_mask) {
+      scores = add(scores, inputs[7], s);
+    }
+    scores = softmax(scores, std::vector<int>{-1}, true, s);
+    array out = quantized_matmul(
+        scores,
+        v,
+        v_scales,
+        v_biases,
+        /*transpose=*/false,
+        /*group_size=*/group_size,
+        /*bits=*/bits,
+        s);
+    if (n_repeats > 1) {
+      out = reshape(out, {B, n_q_heads, L, -1}, s);
+    }
+    return std::vector<array>{out};
+  };
+
+  int L = queries.shape(2);
+  if (L > 1) {
+    if (needs_mask) {
+      return fallback(
+          {queries,
+           keys,
+           key_scales,
+           key_biases,
+           values,
+           value_scales,
+           value_biases,
+           mask.value()})[0];
+    } else {
+      return fallback(
+          {queries,
+           keys,
+           key_scales,
+           key_biases,
+           values,
+           value_scales,
+           value_biases})[0];
+    }
+  } else {
+    return array(
+        std::move(out_shape),
+        queries.dtype(),
+        std::make_shared<ScaledDotProductAttention>(
+            stream,
+            fallback,
+            scale,
+            /*needs_mask=*/false,
+            /*quantized=*/true,
+            group_size,
+            bits),
+        {queries,
+         keys,
+         key_scales,
+         key_biases,
+         values,
+         value_scales,
+         value_biases});
+  }
 }
+
 array pack_and_quantize(
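The GQA handling in the fallback above (n_repeats = n_q_heads / n_kv_heads) reshapes the queries and inserts a broadcast axis into K and V so that one KV head serves n_repeats query heads. A Python sketch of the same shape bookkeeping with plain (unquantized) arrays standing in for the quantized triples (illustrative, not part of the diff):

import mlx.core as mx

B, n_q_heads, n_kv_heads, Lq, S, D = 1, 32, 8, 4, 256, 128
n_repeats = n_q_heads // n_kv_heads

q = mx.zeros((B, n_q_heads, Lq, D))
k = mx.zeros((B, n_kv_heads, S, D))
v = mx.zeros((B, n_kv_heads, S, D))

q = q.reshape(B, n_kv_heads, n_repeats, Lq, D)
k = mx.expand_dims(k, 2)                       # (B, n_kv, 1, S, D), broadcasts
v = mx.expand_dims(v, 2)
scores = q @ k.transpose(0, 1, 2, 4, 3)        # (B, n_kv, n_repeats, Lq, S)
out = mx.softmax(scores, axis=-1) @ v          # (B, n_kv, n_repeats, Lq, D)
print(out.reshape(B, n_q_heads, Lq, D).shape)  # (1, 32, 4, 128)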
							
								
								
									
mlx/fast.h | 15

@@ -41,6 +41,21 @@ array scaled_dot_product_attention(
     const std::optional<int> memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});
 
+/** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/
+array quantized_scaled_dot_product_attention(
+    const array& queries,
+    const array& keys,
+    const array& key_scales,
+    const array& key_biases,
+    const array& values,
+    const array& value_scales,
+    const array& value_biases,
+    const float scale,
+    const std::optional<array>& mask = std::nullopt,
+    const int group_size = 64,
+    const int bits = 4,
+    StreamOrDevice s = {});
+
 std::tuple<array, array, array> affine_quantize(
     const array& w,
     int group_size = 64,

@@ -190,8 +190,16 @@ class ScaledDotProductAttention : public Custom {
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback,
       const float scale,
-      const bool needs_mask)
-      : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
+      const bool needs_mask,
+      const bool quantized,
+      const int group_size = 64,
+      const int bits = 4)
+      : Custom(stream, fallback),
+        scale_(scale),
+        needs_mask_(needs_mask),
+        quantized_(quantized),
+        group_size_(group_size),
+        bits_(bits) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {

@@ -212,6 +220,9 @@ class ScaledDotProductAttention : public Custom {
   std::function<std::vector<array>(std::vector<array>)> fallback_;
   float scale_;
   bool needs_mask_;
+  bool quantized_;
+  int group_size_;
+  int bits_;
 };
 
 class AffineQuantize : public Custom {
@@ -161,6 +161,45 @@ void init_fast(nb::module_& parent_module) {
             array: The output array.
       )pbdoc");
 
+  m.def(
+      "quantized_scaled_dot_product_attention",
+      &fast::quantized_scaled_dot_product_attention,
+      "q"_a,
+      "k"_a,
+      "k_scales"_a,
+      "k_biases"_a,
+      "v"_a,
+      "v_scales"_a,
+      "v_biases"_a,
+      nb::kw_only(),
+      "scale"_a,
+      "mask"_a = nb::none(),
+      "group_size"_a = 64,
+      "bits"_a = 4,
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, group_size: int = 64, bits: int = 4, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        A fast implementation of multi-head attention where the keys and values are quantized.
+
+        See :func:`scaled_dot_product_attention` for more details.
+
+        Args:
+            q (array): Input query array.
+            k (array): Input keys array.
+            k_scales (array): Scales for the quantized keys array.
+            k_biases (array): Biases for the quantized keys array.
+            v (array): Input values array.
+            v_scales (array): Scales for the quantized values array.
+            v_biases (array): Biases for the quantized values array.
+            scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
+            mask (array, optional): An additive mask to apply to the query-key scores.
+            group_size (int): The group size used in the KV quantization.
+            bits (int): The bits used in the KV quantization.
+        Returns:
+            array: The output array.
+      )pbdoc");
+
   m.def(
       "metal_kernel",
       [](const std::string& name,
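End-to-end, the binding registered above can be exercised like this (a sketch against the API as added in this commit; shapes are illustrative):

import math
import mlx.core as mx

B, n_heads, n_kv_heads, S, D = 1, 32, 8, 2048, 128
q = mx.random.normal((B, n_heads, 1, D)).astype(mx.float16)
k = mx.random.normal((B, n_kv_heads, S, D)).astype(mx.float16)
v = mx.random.normal((B, n_kv_heads, S, D)).astype(mx.float16)

# Quantize the KV cache once, then pass packed arrays plus scales/biases.
k_q, k_s, k_b = mx.quantize(k, group_size=64, bits=4)
v_q, v_s, v_b = mx.quantize(v, group_size=64, bits=4)

o = mx.fast.quantized_scaled_dot_product_attention(
    q, k_q, k_s, k_b, v_q, v_s, v_b,
    scale=1.0 / math.sqrt(D), mask=None, group_size=64, bits=4,
)
print(o.shape)  # (1, 32, 1, 128)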