mirror of https://github.com/ml-explore/mlx.git
	4 bit working
@@ -2,7 +2,7 @@ import mlx.core as mx
 import numpy as np
 from time_utils import time_fn
 
-L = 30000
+L = 16
 H = 32
 H_k = 32 // 4
 D = 128
@@ -29,13 +29,15 @@ def sdpa(q, k, v):
     v = mx.quantize(v)
     k = mx.dequantize(*k)
     v = mx.dequantize(*v)
-    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=0.08, mask=None)
 
 
 def quant_sdpa(q, k, v):
     k = mx.quantize(k)
     v = mx.quantize(v)
-    return mx.fast.quantized_scaled_dot_product_attention(q, *k, *v, scale=1.0)
+    return mx.fast.quantized_scaled_dot_product_attention(
+        q, *k, *v, scale=0.08, mask=None
+    )
 
 
 def time_self_attention_primitives(q, k, v):
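Note on this hunk: both paths now pass scale=0.08 (roughly the usual 1/sqrt(D) = 1/sqrt(128) ≈ 0.088 attention scaling, apparently hard-coded) and an explicit mask=None. The *k and *v unpacking works because mx.quantize returns a (w_q, scales, biases) tuple. A minimal sketch of that round trip, assuming MLX's default group_size=64 and bits=4 (the 4-bit case this commit targets):

    import mlx.core as mx

    D = 128
    x = mx.random.uniform(shape=(1, 8, 16, D))

    # mx.quantize returns packed 4-bit codes plus per-group scales and biases.
    w_q, scales, biases = mx.quantize(x, group_size=64, bits=4)

    # Round trip back to floats; the error is on the order of the 4-bit step size.
    x_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
    print(mx.abs(x - x_hat).max())

    # The benchmark's scale=0.08 is close to the conventional 1/sqrt(D) factor.
    print(1.0 / D**0.5)  # ~0.088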
@@ -52,9 +54,14 @@ def time_self_attention_quant_sdpa(q, k, v):
 
 if __name__ == "__main__":
     mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 10, D))
-    k = mx.random.uniform(shape=(1, H_k, L, D))
-    v = mx.random.uniform(shape=(1, H_k, L, D))
+    # q = mx.random.uniform(shape=(1, H, 1, D))
+    # k = mx.random.uniform(shape=(1, H_k, L, D))
+    # v = mx.random.uniform(shape=(1, H_k, L, D))
+    q = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/queries.npy"))
+    k = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/keys.npy"))
+    v = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/values.npy"))
+    print(q.dtype)
+    print(q.shape, k.shape, v.shape)
     mx.eval(q, k, v)
 
     k_quant = mx.quantize(k)
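Note on this hunk: the random benchmark inputs are swapped for arrays dumped from a real model run (the absolute paths are local to the author's machine), and the dtype/shape prints confirm what the kernel is actually fed. If those .npy dumps are not at hand, synthetic stand-ins with the same GQA layout exercise the same code path; a sketch (the fp16 cast is a guess at the dumped activations' dtype):

    import mlx.core as mx

    H, H_k, L, D = 32, 32 // 4, 16, 128  # 4 query heads per KV head

    # Same layout as the dumped arrays: (batch, heads, seq_len, head_dim).
    q = mx.random.uniform(shape=(1, H, L, D)).astype(mx.float16)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(mx.float16)
    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(mx.float16)
    print(q.dtype)
    print(q.shape, k.shape, v.shape)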
@@ -66,12 +73,12 @@ if __name__ == "__main__":
     # time_self_attention_primitives(q, k, v)
     q_sdpa = quant_sdpa(q, k, v)
     print(q_sdpa)
-    o_attention = attention(q, k, v)
-    print(o_attention)
-    np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
-    # o_sdpa = sdpa(q, k, v)
-    # print(o_sdpa)
-    # np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
+    # o_attention = attention(q, k, v)
+    # print(o_attention)
+    # np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
+    o_sdpa = sdpa(q, k, v)
+    print(o_sdpa)
+    np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
     # print(o_sdpa[..., :64])
     # print()
     # print(o_attention[..., :64])
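Note on this hunk: the assert now compares the fused quantized kernel against sdpa, the fake-quantized reference (quantize, dequantize, then the regular SDPA), rather than against full-precision attention. Both sides then see the same 4-bit K/V, so the 1e-5 tolerance checks only the new kernel's numerics; against full precision the difference would be dominated by quantization error. A quick way to see the actual gap rather than a pass/fail assert (the names are the script's own):

    import numpy as np

    q_sdpa = quant_sdpa(q, k, v)  # fused quantized kernel
    o_sdpa = sdpa(q, k, v)        # dequantize-then-SDPA reference
    print(np.abs(np.array(q_sdpa) - np.array(o_sdpa)).max())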
@@ -178,9 +178,10 @@ template <typename T, int D, int group_size, int bits>
   U shifts[4] = {1, 16, 256, 4096};
   for (int i = 0; i < elem_per_thread; i++) {
     // Shift by the appropriate amount here
-    query_sum += queries[i];
     U shift = shifts[i % 4];
-    q[i] = static_cast<U>(scale) * queries[i] / shift;
+    q[i] = static_cast<U>(scale) * queries[i];
+    query_sum += q[i];
+    q[i] /= shift;
   }
   for (int i = 0; i < elem_per_thread; i++) {
     o[i] = 0;
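Note on this hunk: query_sum previously accumulated the raw query values; it now accumulates scale * queries[i] before the per-nibble shift is divided out. If, as it appears, query_sum is later multiplied by the per-group biases of the quantized keys, it must carry the softmax scale but not the 4-bit packing shift, because the bias term of an affine-quantized dot product factors through the plain sum of the query entries: q · (s*w + b) = s*(q · w) + b*sum(q) within a group. A small sketch of that identity (per-group affine quantization assumed; the variable names are illustrative, not the kernel's):

    import numpy as np

    rng = np.random.default_rng(0)
    group = 64
    q = rng.standard_normal(group).astype(np.float32)
    w_q = rng.integers(0, 16, size=group).astype(np.float32)  # 4-bit codes
    s, b = 0.05, -0.4                                          # per-group scale / bias

    k_hat = s * w_q + b                     # dequantized keys for this group
    direct = q @ k_hat                      # reference dot product
    factored = s * (q @ w_q) + b * q.sum()  # scale/bias pulled out; needs sum(q)
    print(np.allclose(direct, factored))    # True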
@@ -687,7 +687,6 @@ array quantized_scaled_dot_product_attention(
   auto n_q_heads = queries.shape(-3);
   auto n_kv_heads = keys.shape(-3);
 
-  std::cout << "group bits " << group_size << " " << bits << std::endl;
   auto out_shape = std::vector<int>(
       {queries.shape(0), queries.shape(1), queries.shape(2), out_dim});
   auto stream = to_stream(s);
@@ -747,7 +746,8 @@ array quantized_scaled_dot_product_attention(
     return std::vector<array>{out};
   };
 
-  if (true) {
+  int L = queries.shape(2);
+  if (L > 1) {
     if (needs_mask) {
       return fallback(
           {queries,
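Note on this hunk: the unconditional if (true) guard around the fallback path becomes a check on the query sequence length, which suggests the specialized quantized kernel is intended for the single-token decode case (L == 1), with longer prompts still taking the dequantize-based fallback shown here. A Python-level sketch of what that dispatch looks like under that reading, built from the ops used in the benchmark above (illustrative only, not the real C++ dispatch):

    import mlx.core as mx

    def quant_sdpa_dispatch(q, k_quant, v_quant, scale):
        # k_quant / v_quant are (w_q, scales, biases) tuples from mx.quantize.
        L = q.shape[2]
        if L > 1:
            # Prompt processing: dequantize and use the regular fused SDPA.
            k = mx.dequantize(*k_quant)
            v = mx.dequantize(*v_quant)
            return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
        # Single-token decode: the specialized quantized kernel path.
        return mx.fast.quantized_scaled_dot_product_attention(
            q, *k_quant, *v_quant, scale=scale, mask=None
        )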