mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 01:51:18 +08:00
4 bit working
This commit is contained in:
parent
5824626c0b
commit
ef14b1e9c3
@ -2,7 +2,7 @@ import mlx.core as mx
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
|
|
||||||
L = 30000
|
L = 16
|
||||||
H = 32
|
H = 32
|
||||||
H_k = 32 // 4
|
H_k = 32 // 4
|
||||||
D = 128
|
D = 128
|
||||||
@ -29,13 +29,15 @@ def sdpa(q, k, v):
|
|||||||
v = mx.quantize(v)
|
v = mx.quantize(v)
|
||||||
k = mx.dequantize(*k)
|
k = mx.dequantize(*k)
|
||||||
v = mx.dequantize(*v)
|
v = mx.dequantize(*v)
|
||||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
return mx.fast.scaled_dot_product_attention(q, k, v, scale=0.08, mask=None)
|
||||||
|
|
||||||
|
|
||||||
def quant_sdpa(q, k, v):
|
def quant_sdpa(q, k, v):
|
||||||
k = mx.quantize(k)
|
k = mx.quantize(k)
|
||||||
v = mx.quantize(v)
|
v = mx.quantize(v)
|
||||||
return mx.fast.quantized_scaled_dot_product_attention(q, *k, *v, scale=1.0)
|
return mx.fast.quantized_scaled_dot_product_attention(
|
||||||
|
q, *k, *v, scale=0.08, mask=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def time_self_attention_primitives(q, k, v):
|
def time_self_attention_primitives(q, k, v):
|
||||||
@ -52,9 +54,14 @@ def time_self_attention_quant_sdpa(q, k, v):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mx.random.seed(3)
|
mx.random.seed(3)
|
||||||
q = mx.random.uniform(shape=(1, H, 10, D))
|
# q = mx.random.uniform(shape=(1, H, 1, D))
|
||||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
# k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
# v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||||
|
q = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/queries.npy"))
|
||||||
|
k = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/keys.npy"))
|
||||||
|
v = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/values.npy"))
|
||||||
|
print(q.dtype)
|
||||||
|
print(q.shape, k.shape, v.shape)
|
||||||
mx.eval(q, k, v)
|
mx.eval(q, k, v)
|
||||||
|
|
||||||
k_quant = mx.quantize(k)
|
k_quant = mx.quantize(k)
|
||||||
@ -66,12 +73,12 @@ if __name__ == "__main__":
|
|||||||
# time_self_attention_primitives(q, k, v)
|
# time_self_attention_primitives(q, k, v)
|
||||||
q_sdpa = quant_sdpa(q, k, v)
|
q_sdpa = quant_sdpa(q, k, v)
|
||||||
print(q_sdpa)
|
print(q_sdpa)
|
||||||
o_attention = attention(q, k, v)
|
# o_attention = attention(q, k, v)
|
||||||
print(o_attention)
|
# print(o_attention)
|
||||||
np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
|
# np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
|
||||||
# o_sdpa = sdpa(q, k, v)
|
o_sdpa = sdpa(q, k, v)
|
||||||
# print(o_sdpa)
|
print(o_sdpa)
|
||||||
# np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
|
np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
|
||||||
# print(o_sdpa[..., :64])
|
# print(o_sdpa[..., :64])
|
||||||
# print()
|
# print()
|
||||||
# print(o_attention[..., :64])
|
# print(o_attention[..., :64])
|
||||||
|
@ -178,9 +178,10 @@ template <typename T, int D, int group_size, int bits>
|
|||||||
U shifts[4] = {1, 16, 256, 4096};
|
U shifts[4] = {1, 16, 256, 4096};
|
||||||
for (int i = 0; i < elem_per_thread; i++) {
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
// Shift by the appropriate amount here
|
// Shift by the appropriate amount here
|
||||||
query_sum += queries[i];
|
|
||||||
U shift = shifts[i % 4];
|
U shift = shifts[i % 4];
|
||||||
q[i] = static_cast<U>(scale) * queries[i] / shift;
|
q[i] = static_cast<U>(scale) * queries[i];
|
||||||
|
query_sum += q[i];
|
||||||
|
q[i] /= shift;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < elem_per_thread; i++) {
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
o[i] = 0;
|
o[i] = 0;
|
||||||
|
@ -687,7 +687,6 @@ array quantized_scaled_dot_product_attention(
|
|||||||
auto n_q_heads = queries.shape(-3);
|
auto n_q_heads = queries.shape(-3);
|
||||||
auto n_kv_heads = keys.shape(-3);
|
auto n_kv_heads = keys.shape(-3);
|
||||||
|
|
||||||
std::cout << "group bits " << group_size << " " << bits << std::endl;
|
|
||||||
auto out_shape = std::vector<int>(
|
auto out_shape = std::vector<int>(
|
||||||
{queries.shape(0), queries.shape(1), queries.shape(2), out_dim});
|
{queries.shape(0), queries.shape(1), queries.shape(2), out_dim});
|
||||||
auto stream = to_stream(s);
|
auto stream = to_stream(s);
|
||||||
@ -747,7 +746,8 @@ array quantized_scaled_dot_product_attention(
|
|||||||
return std::vector<array>{out};
|
return std::vector<array>{out};
|
||||||
};
|
};
|
||||||
|
|
||||||
if (true) {
|
int L = queries.shape(2);
|
||||||
|
if (L > 1) {
|
||||||
if (needs_mask) {
|
if (needs_mask) {
|
||||||
return fallback(
|
return fallback(
|
||||||
{queries,
|
{queries,
|
||||||
|
Loading…
Reference in New Issue
Block a user