mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-04 18:48:15 +08:00)

Commit: 8 bit working
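The diff below threads user-selectable group_size and bits parameters through quantized scaled-dot-product attention, from the Python benchmark down through the Metal kernel instantiations and the host-side dispatch, so the fused kernel supports 8-bit (in addition to the previously hard-coded 4-bit) quantized K/V.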
@@ -25,18 +25,18 @@ def attention(q, k, v):
 
 
 def sdpa(q, k, v):
-    k = mx.quantize(k)
-    v = mx.quantize(v)
-    k = mx.dequantize(*k)
-    v = mx.dequantize(*v)
-    return mx.fast.scaled_dot_product_attention(q, k, v, scale=0.08, mask=None)
+    k = mx.quantize(k, bits=8)
+    v = mx.quantize(v, bits=8)
+    k = mx.dequantize(*k, bits=8)
+    v = mx.dequantize(*v, bits=8)
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
 
 
 def quant_sdpa(q, k, v):
-    k = mx.quantize(k)
-    v = mx.quantize(v)
+    k = mx.quantize(k, bits=8)
+    v = mx.quantize(v, bits=8)
     return mx.fast.quantized_scaled_dot_product_attention(
-        q, *k, *v, scale=0.08, mask=None
+        q, *k, *v, scale=1.0, mask=None, bits=8
     )
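This first hunk is the Python SDPA benchmark: both paths move to 8-bit quantization and a common scale, so the dequantize-then-attend baseline (sdpa) and the fused kernel path (quant_sdpa) measure the same computation. A minimal standalone sketch of that comparison follows; the shapes and the tolerance are assumptions, not taken from the diff.

import mlx.core as mx

q = mx.random.normal(shape=(1, 32, 1, 128))     # one decoded query token
k = mx.random.normal(shape=(1, 32, 2048, 128))  # cached keys
v = mx.random.normal(shape=(1, 32, 2048, 128))  # cached values

# mx.quantize returns a (packed_weights, scales, biases) tuple.
k_q = mx.quantize(k, bits=8)
v_q = mx.quantize(v, bits=8)

# Baseline: dequantize, then run the ordinary fused SDPA kernel.
k_ref = mx.dequantize(*k_q, bits=8)
v_ref = mx.dequantize(*v_q, bits=8)
ref = mx.fast.scaled_dot_product_attention(q, k_ref, v_ref, scale=1.0, mask=None)

# Fused path exercised by this commit: attend over the packed tensors.
out = mx.fast.quantized_scaled_dot_product_attention(
    q, *k_q, *v_q, scale=1.0, mask=None, bits=8
)

# Both paths see identical quantized K/V, so they should agree up to
# accumulation order.
print(mx.allclose(ref, out, atol=1e-3))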
@@ -939,13 +939,25 @@ instantiate_sdpa_vector_heads(bfloat16_t)
 instantiate_sdpa_vector_heads(float16_t)
 
 // Quantized SDPA vector instantiations
-#define instantiate_quant_sdpa_vector(type, head_dim)                              \
-  instantiate_kernel("quant_sdpa_vector_" #type "_" #head_dim, quant_sdpa_vector, type, head_dim, 64, 4)
+#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
+  instantiate_kernel(                                                   \
+    "quant_sdpa_vector_" #type "_" #head_dim "_" #group_size "_" #bits, \
+    quant_sdpa_vector, type, head_dim, group_size, bits)
+
+#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
+  instantiate_quant_sdpa_vector(type, heads, group_size, 4)         \
+  instantiate_quant_sdpa_vector(type, heads, group_size, 8)
+
+#define instantiate_quant_sdpa_vector_group_size(type, heads) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 32)         \
+  instantiate_quant_sdpa_vector_bits(type, heads, 64)         \
+  instantiate_quant_sdpa_vector_bits(type, heads, 128)
 
 #define instantiate_quant_sdpa_vector_heads(type) \
-  instantiate_quant_sdpa_vector(type, 64)         \
-  instantiate_quant_sdpa_vector(type, 96)         \
-  instantiate_quant_sdpa_vector(type, 128)
+  instantiate_quant_sdpa_vector_group_size(type, 64)         \
+  instantiate_quant_sdpa_vector_group_size(type, 96)         \
+  instantiate_quant_sdpa_vector_group_size(type, 128)
 
 
 instantiate_quant_sdpa_vector_heads(float)
 instantiate_quant_sdpa_vector_heads(bfloat16_t)
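Every kernel variant has to exist at compile time, so the macro tower above stamps out the full cross product per dtype: three head dims, three group sizes, two bit widths. A sketch (not part of the diff) that enumerates the generated names:

# 2 dtypes x 3 head dims x 3 group sizes x 2 bit widths = 36 kernels.
for dtype in ["float", "bfloat16_t"]:
    for head_dim in [64, 96, 128]:
        for group_size in [32, 64, 128]:
            for bits in [4, 8]:
                print(f"quant_sdpa_vector_{dtype}_{head_dim}_{group_size}_{bits}")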
@@ -118,6 +118,67 @@ template <typename T, int D>
   }
 }
 
+template <typename T, typename U, int elem_per_thread, int bits>
+METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
+  U query_sum = 0;
+  if (bits == 4) {
+    for (int i = 0; i < elem_per_thread; i += 4) {
+      q[i] = scale * queries[i];
+      q[i + 1] = scale * queries[i + 1];
+      q[i + 2] = scale * queries[i + 2];
+      q[i + 3] = scale * queries[i + 3];
+      query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
+      q[i + 1] /= 16.0f;
+      q[i + 2] /= 256.0f;
+      q[i + 3] /= 4096.0f;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      q[i] = scale * queries[i];
+      query_sum += q[i];
+    }
+  }
+  return query_sum;
+}
+
+template <typename U, int elem_per_thread, int bits>
+METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
+  if (bits == 4) {
+    auto ks = (const device uint16_t*)keys;
+    for (int i = 0; i < elem_per_thread / 4; i++) {
+      k[4 * i] = ks[i] & 0x000f;
+      k[4 * i + 1] = ks[i] & 0x00f0;
+      k[4 * i + 2] = ks[i] & 0x0f00;
+      k[4 * i + 3] = ks[i] & 0xf000;
+    }
+  } else if (bits == 8) {
+    auto ks = (const device uint8_t*)keys;
+    for (int i = 0; i < elem_per_thread; i++) {
+      k[i] = ks[i];
+    }
+  }
+}
+
+template <typename U, int elem_per_thread, int bits>
+METAL_FUNC void load_values(
+    const device uint32_t* values,
+    thread U* v,
+    U value_scale,
+    U value_bias) {
+  auto vs = (const device uint8_t*)values;
+  if (bits == 4) {
+    U s[2] = {value_scale, value_scale / 16.0f};
+    for (int i = 0; i < elem_per_thread / 2; i++) {
+      v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
+      v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      v[i] = value_scale * vs[i] + value_bias;
+    }
+  }
+}
+
 template <typename T, int D, int group_size, int bits>
 [[kernel]] void quant_sdpa_vector(
     const device T* queries [[buffer(0)]],
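The bits == 4 branches keep a trick from the original kernel: load_keys extracts each nibble with a mask but no shift, so nibble j arrives scaled by 16^j; load_queries compensates by pre-dividing query element j by the same factor, and load_values folds the factor into the dequantization scale instead. The bits == 8 branches need none of this, since each byte already holds a complete value. A plain-Python sketch (not part of the diff) of why mask-without-shift still gives the right dot product:

import math

def pack4(nibbles):
    # Pack four 4-bit values into one 16-bit word, lowest nibble first.
    w = 0
    for j, n in enumerate(nibbles):
        w |= (n & 0xF) << (4 * j)
    return w

q = [0.5, -1.25, 2.0, 0.75]  # four query elements, already scaled
w = pack4([3, 15, 7, 1])     # four quantized key nibbles

# Reference: unpack each nibble with a shift, then dot with q.
ref = sum(qi * ((w >> (4 * j)) & 0xF) for j, qi in enumerate(q))

# Kernel's version: mask only. Nibble j stays scaled by 16**j, which the
# pre-divided query element q[j] / 16**j cancels exactly.
masks = [0x000F, 0x00F0, 0x0F00, 0xF000]
alt = sum((qi / 16**j) * (w & masks[j]) for j, qi in enumerate(q))

assert math.isclose(ref, alt)

Note also that load_values now reads uint8_t (two nibbles per byte) where the old inline code read uint16_t, so two scale factors suffice instead of four.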
@@ -174,15 +235,8 @@ template <typename T, int D, int group_size, int bits>
   out += head_idx * D + simd_gid * elem_per_thread;
 
   // Read the query and 0 the output accumulator
-  U query_sum = 0;
-  U shifts[4] = {1, 16, 256, 4096};
-  for (int i = 0; i < elem_per_thread; i++) {
-    // Shift by the appropriate amount here
-    U shift = shifts[i % 4];
-    q[i] = static_cast<U>(scale) * queries[i];
-    query_sum += q[i];
-    q[i] /= shift;
-  }
+  U query_sum = load_queries<T, U, elem_per_thread, bits>(
+      queries, q, static_cast<U>(scale));
   for (int i = 0; i < elem_per_thread; i++) {
     o[i] = 0;
   }
@@ -192,15 +246,9 @@ template <typename T, int D, int group_size, int bits>
 
   // For each key
   for (int i = quad_gid; i < N; i += BN) {
-    // Read the key
-    auto ks = (const device uint16_t*)keys;
-    for (int i = 0; i < elem_per_thread / 4; i++) {
-      k[4 * i] = ks[i] & 0x000f;
-      k[4 * i + 1] = ks[i] & 0x00f0;
-      k[4 * i + 2] = ks[i] & 0x0f00;
-      k[4 * i + 3] = ks[i] & 0xf000;
-    }
-    // All the keys in a set are in the same group
+    load_keys<U, elem_per_thread, bits>(keys, k);
+
+    // Assume D % group_size == 0 so all the keys are in the same group
     U key_scale = key_scales[0];
     U key_bias = key_biases[0];
 
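The new comment makes the old one's assumption explicit: the refactored loop still reads only key_scales[0] and key_biases[0], which is valid only while D % group_size == 0, so the slice of a key row that one thread dequantizes never straddles a quantization group boundary.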
@@ -224,18 +272,7 @@ template <typename T, int D, int group_size, int bits>
     U value_bias = value_biases[0];
 
     // Load the values
-    auto vs = (const device uint16_t*)values;
-    U s[4] = {
-        value_scale,
-        value_scale / 16.0f,
-        value_scale / 256.0f,
-        value_scale / 4096.0f};
-    for (int i = 0; i < elem_per_thread / 4; i++) {
-      v[4 * i] = s[0] * (vs[i] & 0x000f) + value_bias;
-      v[4 * i + 1] = s[1] * (vs[i] & 0x00f0) + value_bias;
-      v[4 * i + 2] = s[2] * (vs[i] & 0x0f00) + value_bias;
-      v[4 * i + 3] = s[3] * (vs[i] & 0xf000) + value_bias;
-    }
+    load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
 
     // Update the output accumulator
     for (int i = 0; i < elem_per_thread; i++) {
@@ -198,7 +198,9 @@ void quant_sdpa_vector(
     const array& v_scales,
     const array& v_biases,
     array& out,
-    float scale) {
+    float scale,
+    int group_size,
+    int bits) {
   // Set the kernel name
   std::string kname;
   kname.reserve(96);
@@ -206,6 +208,10 @@ void quant_sdpa_vector(
   kname += get_type_string(q.dtype());
   kname += "_";
   kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(group_size);
+  kname += "_";
+  kname += std::to_string(bits);
 
   // Compute the necessary sizes
   int gqa_factor = q.shape(1) / k.shape(1);
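The name assembled here has to match one stamped out by the instantiation macros earlier, or the kernel lookup fails at runtime. A sketch (not in the diff) of the resulting string:

dtype, head_dim, group_size, bits = "bfloat16_t", 128, 64, 8
print(f"quant_sdpa_vector_{dtype}_{head_dim}_{group_size}_{bits}")
# -> quant_sdpa_vector_bfloat16_t_128_64_8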
@@ -314,7 +320,19 @@ void ScaledDotProductAttention::eval_gpu(
     }
 
     quant_sdpa_vector(
-        s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_);
+        s,
+        d,
+        q,
+        k,
+        k_scales,
+        k_biases,
+        v,
+        v_scales,
+        v_biases,
+        o,
+        scale_,
+        group_size_,
+        bits_);
   }
 
@@ -773,7 +773,13 @@ array quantized_scaled_dot_product_attention(
         std::move(out_shape),
         queries.dtype(),
         std::make_shared<ScaledDotProductAttention>(
-            stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/true),
+            stream,
+            fallback,
+            scale,
+            /*needs_mask=*/false,
+            /*quantized=*/true,
+            group_size,
+            bits),
         {queries,
          keys,
          key_scales,
@@ -191,11 +191,15 @@ class ScaledDotProductAttention : public Custom {
       std::function<std::vector<array>(std::vector<array>)> fallback,
       const float scale,
       const bool needs_mask,
-      const bool quantized)
+      const bool quantized,
+      const int group_size = 64,
+      const int bits = 4)
       : Custom(stream, fallback),
         scale_(scale),
         needs_mask_(needs_mask),
-        quantized_(quantized) {}
+        quantized_(quantized),
+        group_size_(group_size),
+        bits_(bits) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
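The defaults (group_size = 64, bits = 4) reproduce the configuration the old kernel hard-coded in its instantiation (64, 4), and appear to mirror mx.quantize's own defaults, so existing callers that never pass the new arguments keep the pre-commit behavior.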
@@ -217,6 +221,8 @@ class ScaledDotProductAttention : public Custom {
   float scale_;
   bool needs_mask_;
   bool quantized_;
+  int group_size_;
+  int bits_;
 };
 
 class AffineQuantize : public Custom {