Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Kv cache (#643)
* in place kv_cache
* fix
* fix kv cache size
* partially fix kv cache dtype
* step kv cache
* multiple of step size
* more tests + kv cache
* more kv cache
* update all models to use kv cache
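The change replaces the per-layer (keys, values) tuples that used to be threaded through every call with a cache object that is updated in place and grows its buffers in fixed-size steps. The diff below only shows how plamo.py consumes that object, through cache.offset and cache.update_and_fetch(keys, values). As a rough sketch of what such a step-allocated cache could look like (the class name KVCache, the step size of 256, and the growth logic here are illustrative assumptions, not code from this commit):

import mlx.core as mx


class KVCache:
    """Illustrative in-place KV cache: buffers grow in multiples of `step`."""

    def __init__(self, head_dim: int, n_kv_heads: int, step: int = 256):
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        self.step = step
        self.keys = None
        self.values = None
        self.offset = 0  # number of positions already written

    def update_and_fetch(self, keys: mx.array, values: mx.array):
        # keys/values arrive as (batch, n_kv_heads, new_tokens, head_dim)
        prev = self.offset
        needed = prev + keys.shape[2]
        if self.keys is None or needed > self.keys.shape[2]:
            # Allocate (or extend to) the next multiple of `step` positions.
            n_steps = (needed + self.step - 1) // self.step
            shape = (keys.shape[0], self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, keys.dtype)
            new_v = mx.zeros(shape, values.dtype)
            if self.keys is not None:
                new_k[..., :prev, :] = self.keys[..., :prev, :]
                new_v[..., :prev, :] = self.values[..., :prev, :]
            self.keys, self.values = new_k, new_v
        # Write the new tokens in place and return views over the valid prefix.
        self.offset = needed
        self.keys[..., prev:needed, :] = keys
        self.values[..., prev:needed, :] = values
        return self.keys[..., :needed, :], self.values[..., :needed, :]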
@@ -64,44 +64,38 @@ class Attention(nn.Module):
     ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         bsz, q_len, _ = hidden_states.shape
 
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        queries = self.q_proj(hidden_states)
+        keys = self.k_proj(hidden_states)
+        values = self.v_proj(hidden_states)
 
         # Prepare the queries, keys and values for the attention computation
-        query_states = query_states.reshape(
-            bsz, q_len, self.q_num_heads, self.qk_dim
-        ).transpose(0, 2, 1, 3)
-        key_states = key_states.reshape(
-            bsz, q_len, self.k_num_heads, self.qk_dim
-        ).transpose(0, 2, 1, 3)
-        value_states = value_states.reshape(
-            bsz, q_len, self.v_num_heads, self.v_dim
-        ).transpose(0, 2, 1, 3)
-
-        # expand shared kv
-        assert self.k_num_heads == self.v_num_heads
-
-        kv_seq_len = 0
-        if cache is not None:
-            kv_seq_len += cache[0].shape[-2]
-        query_states = self.rotary_emb(query_states, offset=kv_seq_len)
-        key_states = self.rotary_emb(key_states, offset=kv_seq_len)
+        queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(
+            0, 2, 1, 3
+        )
+        keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(
+            0, 2, 1, 3
+        )
+        values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(
+            0, 2, 1, 3
+        )
 
         if cache is not None:
-            # reuse k, v, self_attention
-            key_states = mx.concatenate([cache[0], key_states], axis=2)
-            value_states = mx.concatenate([cache[1], value_states], axis=2)
+            queries = self.rotary_emb(queries, offset=cache.offset)
+            keys = self.rotary_emb(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rotary_emb(queries)
+            keys = self.rotary_emb(keys)
 
         output = mx.fast.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
+            queries,
+            keys,
+            values,
             scale=self.scale,
             mask=attention_mask,
         )
         output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
-        return self.o_proj(output), (key_states, value_states)
+        return self.o_proj(output)
 
 
 class MLP(nn.Module):
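With the in-place cache, the rotary embedding offset comes from cache.offset instead of the cached key length, and the per-step mx.concatenate over the whole history disappears. A small illustration of the cached path, reusing the KVCache sketch above with made-up shapes (batch 1, 4 KV heads, head dim 64):

cache = KVCache(head_dim=64, n_kv_heads=4)

# Prefill with an 8-token prompt: in the new attention code, RoPE is applied
# with offset=cache.offset (0 here) before the cache update.
k = mx.random.normal((1, 4, 8, 64))
v = mx.random.normal((1, 4, 8, 64))
keys, values = cache.update_and_fetch(k, v)
assert keys.shape[2] == 8 and cache.offset == 8

# Decode one token: RoPE uses offset=8, then attention sees all 9 positions.
k1 = mx.random.normal((1, 4, 1, 64))
v1 = mx.random.normal((1, 4, 1, 64))
keys, values = cache.update_and_fetch(k1, v1)
assert keys.shape[2] == 9 and cache.offset == 9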
@@ -139,7 +133,7 @@ class PlamoDecoderLayer(nn.Module):
         hidden_states = self.norm(hidden_states)
 
         # Self Attention
-        hidden_states_sa, cache = self.self_attn(
+        hidden_states_sa = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             cache=cache,
@@ -149,7 +143,7 @@ class PlamoDecoderLayer(nn.Module):
         hidden_states_mlp = self.mlp(hidden_states)
 
         hidden_states = residual + hidden_states_sa + hidden_states_mlp
-        return hidden_states, cache
+        return hidden_states
 
 
 class PlamoDecoder(nn.Module):
@@ -185,14 +179,10 @@ class PlamoModel(nn.Module):
         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]
 
-        for e, layer in enumerate(self.layers.layers):
-            h, c = layer(h, mask, cache[e])
-            if cache is not None:
-                cache[e] = c
-            else:
-                cache.append(c)
+        for layer, c in zip(self.layers.layers, cache):
+            h = layer(h, mask, cache=c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class Model(nn.Module):
@@ -203,15 +193,24 @@ class Model(nn.Module):
         self.lm_head: nn.Module = nn.Linear(
             args.hidden_size, args.vocab_size, bias=False
         )
+        self.args = args
 
     def __call__(
         self,
         inputs: mx.array,
         cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
     ) -> Tuple[mx.array, mx.array]:
-        out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        out = self.model(inputs, cache)
+        return self.lm_head(out)
 
     @property
     def layers(self):
         return self.model.layers.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_attention_heads // self.args.n_shared_head
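The new head_dim and n_kv_heads properties let a caller build the per-layer caches up front instead of collecting returned tuples. A hypothetical decode loop along those lines, using the KVCache sketch from above (model, prompt_token_ids, and max_new_tokens are assumed to exist; the actual mlx_lm generation code may differ):

caches = [KVCache(model.head_dim, model.n_kv_heads) for _ in model.layers]

y = mx.array([prompt_token_ids])              # shape (1, prompt_len)
for _ in range(max_new_tokens):
    logits = model(y, cache=caches)           # caches are updated in place
    y = mx.argmax(logits[:, -1, :], axis=-1)[None]  # greedy next token, (1, 1)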