mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Add boolean mask support in vector SDPA (#1757)
This commit is contained in:
@@ -12,7 +12,7 @@ dtype = mx.float16
|
||||
loops = 10
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
def attention(q, k, v, mask=None):
|
||||
def _sdpa(q, k, v):
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
@@ -20,6 +20,9 @@ def attention(q, k, v):
|
||||
k = k[:, :, None, :, :]
|
||||
v = v[:, :, None, :, :]
|
||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||
if mask is not None:
|
||||
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
|
||||
s = mx.where(m, s, mx.finfo(s.dtype).min)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ v
|
||||
return o.reshape(B, Hq, L, D)
|
||||
@@ -29,9 +32,9 @@ def attention(q, k, v):
|
||||
return q
|
||||
|
||||
|
||||
def sdpa(q, k, v):
|
||||
def sdpa(q, k, v, mask=None):
|
||||
for i in range(loops):
|
||||
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
||||
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||
return q
|
||||
|
||||
|
||||
@@ -53,6 +56,26 @@ def time_self_attention_sdpa():
|
||||
time_fn(sdpa, q, k, v)
|
||||
|
||||
|
||||
def time_self_attention_sdpa_with_mask():
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
mask = mx.full((L,), True)
|
||||
mask[L // 2 :] = False
|
||||
mx.eval(q, k, v, mask)
|
||||
|
||||
def sdpa_mask(*args):
|
||||
return sdpa(*args, mask=mask)
|
||||
|
||||
def attention_mask(*args):
|
||||
return attention(*args, mask=mask)
|
||||
|
||||
time_fn(attention_mask, q, k, v)
|
||||
time_fn(sdpa_mask, q, k, v)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_self_attention_sdpa()
|
||||
time_self_attention_primitives()
|
||||
time_self_attention_sdpa_with_mask()
|
||||
|
Reference in New Issue
Block a user