mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 11:45:16 +08:00
fix tests and sliding window attention
This commit is contained in:
parent
5d8b36ce7c
commit
52595dafae
@ -48,8 +48,12 @@ class Attention(nn.Module):
|
|||||||
dim = args.hidden_size
|
dim = args.hidden_size
|
||||||
self.n_heads = n_heads = args.num_attention_heads
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
|
self.head_dim = head_dim = args.head_dim
|
||||||
head_dim = args.hidden_size // args.num_attention_heads
|
if (head_dim * n_heads) != dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
|
||||||
|
f" and `num_heads`: {n_heads})."
|
||||||
|
)
|
||||||
self.scale = head_dim**-0.5
|
self.scale = head_dim**-0.5
|
||||||
|
|
||||||
attetion_bias = args.attention_bias
|
attetion_bias = args.attention_bias
|
||||||
@ -77,11 +81,8 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||||
|
|
||||||
queries = queries.reshape(B, L, self.n_heads, -1)
|
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||||
keys = keys.reshape(B, L, self.n_kv_heads, -1)
|
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
queries = queries.transpose(0, 2, 1, 3)
|
|
||||||
keys = keys.transpose(0, 2, 1, 3)
|
|
||||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
@ -94,10 +95,10 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
# sliding window attention
|
# sliding window attention
|
||||||
if self.sliding_window is not None:
|
if self.sliding_window is not None:
|
||||||
keys = keys[:, : -self.sliding_window :, :]
|
keys = keys[:, :, -self.sliding_window :, :]
|
||||||
values = values[:, : -self.sliding_window :, :]
|
values = values[:, :, -self.sliding_window :, :]
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
mask = mask[:, : -self.sliding_window, :]
|
mask = mask[:, -self.sliding_window :]
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = mx.fast.scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
@ -200,7 +201,7 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def head_dim(self):
|
def head_dim(self):
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
return self.args.head_dim
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_kv_heads(self):
|
def n_kv_heads(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user