Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00
Change gqa to use repeat instead of concatenate (#443)

parent 06ddb8414d
commit f71e965d57
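For context: grouped-query attention (GQA) shares each key/value head across several query heads, so the K/V tensors must be expanded to the query head count before the attention matmul. The old code built that expansion by stacking copies with mx.concatenate and folding the new axis into the head axis; the new code does it in one call with mx.repeat along the head axis, which preserves the same [h0, h0, h1, h1, ...] ordering. A minimal sketch verifying the two are equivalent (the sizes are made up for illustration):

import mlx.core as mx

# Hypothetical sizes for illustration only.
B, n_kv_heads, L, head_dim = 1, 4, 8, 64
repeats = 2  # n_heads // n_kv_heads

kv = mx.random.normal((B, n_kv_heads, L, head_dim))

# Old approach: stack `repeats` copies along a new axis 2, then fold
# that axis into the head axis, yielding [h0, h0, h1, h1, ...].
old = mx.concatenate([mx.expand_dims(kv, 2)] * repeats, axis=2)
old = old.reshape([B, n_kv_heads * repeats, L, -1])

# New approach: a single repeat along the head axis, same element order.
new = mx.repeat(kv, repeats, axis=1)

assert mx.array_equal(old, new)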
@@ -107,12 +107,9 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)

         if cache is not None:
             key_cache, value_cache = cache
@@ -73,11 +73,8 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
-        keys, values = map(repeat, (keys, values))
+        keys = mx.repeat(keys, self.repeats, axis=1)
+        values = mx.repeat(values, self.repeats, axis=1)

         if cache is not None:
             key_cache, value_cache = cache
@@ -93,11 +93,8 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
-        keys, values = map(repeat, (keys, values))
+        keys = mx.repeat(keys, self.repeats, axis=1)
+        values = mx.repeat(values, self.repeats, axis=1)

         if cache is not None:
             key_cache, value_cache = cache
@@ -93,12 +93,9 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)

         if cache is not None:
             key_cache, value_cache = cache
@@ -95,12 +95,9 @@ class MixtralAttention(nn.Module):
             0, 2, 1, 3
         )

-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.num_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)

         if cache is not None:
             key_cache, value_cache = cache
@@ -86,12 +86,9 @@ class PhiAttention(nn.Module):
             B, L, self.num_key_value_heads, self.head_dim
         ).transpose(0, 2, 1, 3)

-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.num_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)

         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
@@ -93,12 +93,9 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)

         if cache is not None:
             key_cache, value_cache = cache
@@ -87,12 +87,9 @@ class Attention(nn.Module):
             B, L, self.num_key_value_heads, self.head_dim
         ).transpose(0, 2, 1, 3)

-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.num_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)

         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
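To see where the repeated heads flow, here is a condensed GQA attention sketch in the style of these models; the sizes, the random inputs, and the unmasked softmax are simplifying assumptions for illustration, not code from the repo:

import mlx.core as mx

# Hypothetical sizes for illustration only.
B, L, n_heads, n_kv_heads, head_dim = 1, 8, 8, 2, 64
repeats = n_heads // n_kv_heads
scale = head_dim ** -0.5

queries = mx.random.normal((B, n_heads, L, head_dim))
keys = mx.random.normal((B, n_kv_heads, L, head_dim))
values = mx.random.normal((B, n_kv_heads, L, head_dim))

# Broadcast the shared K/V heads up to the query head count.
if repeats > 1:
    keys = mx.repeat(keys, repeats, axis=1)
    values = mx.repeat(values, repeats, axis=1)

# Standard scaled dot-product attention over the expanded heads.
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
out = mx.softmax(scores, axis=-1) @ values
print(out.shape)  # (1, 8, 8, 64)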