From ef14b1e9c382e9172c3ad7436ae670e73c29dc14 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Tue, 22 Oct 2024 19:20:45 -0700
Subject: [PATCH] 4 bit working

---
 benchmarks/python/sdpa_vector_bench.py  | 31 +++++++++++++++----------
 mlx/backend/metal/kernels/sdpa_vector.h |  5 ++--
 mlx/fast.cpp                            |  4 ++--
 3 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py
index c05fb8f39..c8c4eea30 100644
--- a/benchmarks/python/sdpa_vector_bench.py
+++ b/benchmarks/python/sdpa_vector_bench.py
@@ -2,7 +2,7 @@ import mlx.core as mx
 import numpy as np
 from time_utils import time_fn
 
-L = 30000
+L = 16
 H = 32
 H_k = 32 // 4
 D = 128
@@ -29,13 +29,15 @@ def sdpa(q, k, v):
     v = mx.quantize(v)
     k = mx.dequantize(*k)
     v = mx.dequantize(*v)
-    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=0.08, mask=None)
 
 
 def quant_sdpa(q, k, v):
     k = mx.quantize(k)
     v = mx.quantize(v)
-    return mx.fast.quantized_scaled_dot_product_attention(q, *k, *v, scale=1.0)
+    return mx.fast.quantized_scaled_dot_product_attention(
+        q, *k, *v, scale=0.08, mask=None
+    )
 
 
 def time_self_attention_primitives(q, k, v):
@@ -52,9 +54,14 @@ def time_self_attention_quant_sdpa(q, k, v):
 
 if __name__ == "__main__":
     mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 10, D))
-    k = mx.random.uniform(shape=(1, H_k, L, D))
-    v = mx.random.uniform(shape=(1, H_k, L, D))
+    # q = mx.random.uniform(shape=(1, H, 1, D))
+    # k = mx.random.uniform(shape=(1, H_k, L, D))
+    # v = mx.random.uniform(shape=(1, H_k, L, D))
+    q = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/queries.npy"))
+    k = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/keys.npy"))
+    v = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/values.npy"))
+    print(q.dtype)
+    print(q.shape, k.shape, v.shape)
     mx.eval(q, k, v)
 
     k_quant = mx.quantize(k)
@@ -66,12 +73,12 @@ if __name__ == "__main__":
     # time_self_attention_primitives(q, k, v)
     q_sdpa = quant_sdpa(q, k, v)
     print(q_sdpa)
-    o_attention = attention(q, k, v)
-    print(o_attention)
-    np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
-    # o_sdpa = sdpa(q, k, v)
-    # print(o_sdpa)
-    # np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
+    # o_attention = attention(q, k, v)
+    # print(o_attention)
+    # np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
+    o_sdpa = sdpa(q, k, v)
+    print(o_sdpa)
+    np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
     # print(o_sdpa[..., :64])
     # print()
     # print(o_attention[..., :64])
diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index e5961f97c..ac2da6567 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -178,9 +178,10 @@ template
   U shifts[4] = {1, 16, 256, 4096};
   for (int i = 0; i < elem_per_thread; i++) {
     // Shift by the appropriate amount here
-    query_sum += queries[i];
     U shift = shifts[i % 4];
-    q[i] = static_cast<U>(scale) * queries[i] / shift;
+    q[i] = static_cast<U>(scale) * queries[i];
+    query_sum += q[i];
+    q[i] /= shift;
   }
   for (int i = 0; i < elem_per_thread; i++) {
     o[i] = 0;
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 8573d0988..a6f377cd6 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -687,7 +687,6 @@ array quantized_scaled_dot_product_attention(
   auto n_q_heads = queries.shape(-3);
   auto n_kv_heads = keys.shape(-3);
 
-  std::cout << "group bits " << group_size << " " << bits << std::endl;
   auto out_shape = std::vector<int>(
       {queries.shape(0), queries.shape(1), queries.shape(2), out_dim});
   auto stream = to_stream(s);
@@ -747,7 +746,8 @@ array quantized_scaled_dot_product_attention(
     return std::vector<array>{out};
   };
 
-  if (true) {
+  int L = queries.shape(2);
+  if (L > 1) {
     if (needs_mask) {
       return fallback(
           {queries,
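
Note (not part of the patch): a minimal sketch of the consistency check the updated benchmark performs, comparing the fused quantized SDPA path against dequantize-then-regular-SDPA. It uses random inputs and an illustrative 1/sqrt(D) scale in place of the hard-coded .npy dumps and the 0.08 scale above, and assumes mx.quantize/mx.dequantize with their default group_size=64, bits=4 accept the 4-D K/V tensors, as the benchmark already does.

import math

import mlx.core as mx
import numpy as np

# One query token (decode step) attending to L cached KV positions,
# with grouped KV heads as in the benchmark.
B, H, H_k, L, D = 1, 32, 8, 256, 128

q = mx.random.uniform(shape=(B, H, 1, D))
k = mx.random.uniform(shape=(B, H_k, L, D))
v = mx.random.uniform(shape=(B, H_k, L, D))
scale = 1.0 / math.sqrt(D)

# mx.quantize returns (packed weights, scales, biases); quantize K/V once.
k_q = mx.quantize(k)
v_q = mx.quantize(v)

# Fused quantized SDPA (the kernel path this patch exercises).
out_quant = mx.fast.quantized_scaled_dot_product_attention(
    q, *k_q, *v_q, scale=scale, mask=None
)

# Reference: dequantize K/V and run the regular fused SDPA.
out_ref = mx.fast.scaled_dot_product_attention(
    q, mx.dequantize(*k_q), mx.dequantize(*v_q), scale=scale, mask=None
)

np.testing.assert_allclose(out_quant, out_ref, atol=1e-5)

A single query token keeps queries.shape(2) == 1, so with the fast.cpp change above the call reaches the quantized vector kernel rather than the L > 1 fallback.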