mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Change gqa to use repeat instead of concatenate (#443)
This commit is contained in:
parent
06ddb8414d
commit
f71e965d57
@ -107,12 +107,9 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
|
@ -73,11 +73,8 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
|
@ -93,11 +93,8 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
|
@ -93,12 +93,9 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
|
@ -95,12 +95,9 @@ class MixtralAttention(nn.Module):
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.num_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
|
@ -86,12 +86,9 @@ class PhiAttention(nn.Module):
|
||||
B, L, self.num_key_value_heads, self.head_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.num_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
|
@ -93,12 +93,9 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
|
@ -87,12 +87,9 @@ class Attention(nn.Module):
|
||||
B, L, self.num_key_value_heads, self.head_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.num_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user