* in place kv_cache

* fix

* fix kv cache size

* partially fix kv cache dtype

* step kv cache

* multiple of step size

* more tests + kv cache

* more kv cache

* update all models to use kv cache
Awni Hannun
2024-05-08 08:18:13 -07:00
committed by GitHub
parent bfbc0e434a
commit ee60e2a9d5
22 changed files with 534 additions and 298 deletions
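The diff below (shown here for the PLaMo model) switches the attention layers from returning and concatenating `(keys, values)` tuples to calling into a cache object that exposes an `offset` and an `update_and_fetch` method. The commit message mentions updating the cache in place and growing it in multiples of a step size; a minimal sketch of a cache with that interface might look like the following. The class name `KVCache`, the default step of 256, and the buffer shapes are assumptions for illustration, not code taken from this commit.

```python
import mlx.core as mx


class KVCache:
    def __init__(self, head_dim: int, n_kv_heads: int, step: int = 256):
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        self.step = step    # buffers grow in multiples of this many tokens
        self.offset = 0     # number of tokens currently stored
        self.keys = None
        self.values = None

    def update_and_fetch(self, keys: mx.array, values: mx.array):
        # keys/values: (batch, n_kv_heads, n_new_tokens, head_dim)
        prev = self.offset
        needed = prev + keys.shape[2]
        if self.keys is None or needed > self.keys.shape[2]:
            # Allocate (or extend) the buffers in multiples of `step`, so
            # token-by-token decoding does not reallocate on every call.
            n_steps = (keys.shape[2] + self.step - 1) // self.step
            shape = (keys.shape[0], self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, keys.dtype)
            new_v = mx.zeros(shape, values.dtype)
            if self.keys is None:
                self.keys, self.values = new_k, new_v
            else:
                self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2)
                self.values = mx.concatenate([self.values[..., :prev, :], new_v], axis=2)
        # Write the new tokens in place and return only the valid prefix.
        self.offset = needed
        self.keys[..., prev:needed, :] = keys
        self.values[..., prev:needed, :] = values
        return self.keys[..., :needed, :], self.values[..., :needed, :]
```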


@@ -64,44 +64,38 @@ class Attention(nn.Module):
     ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         bsz, q_len, _ = hidden_states.shape
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        queries = self.q_proj(hidden_states)
+        keys = self.k_proj(hidden_states)
+        values = self.v_proj(hidden_states)
         # Prepare the queries, keys and values for the attention computation
-        query_states = query_states.reshape(
-            bsz, q_len, self.q_num_heads, self.qk_dim
-        ).transpose(0, 2, 1, 3)
-        key_states = key_states.reshape(
-            bsz, q_len, self.k_num_heads, self.qk_dim
-        ).transpose(0, 2, 1, 3)
-        value_states = value_states.reshape(
-            bsz, q_len, self.v_num_heads, self.v_dim
-        ).transpose(0, 2, 1, 3)
-        # expand shared kv
-        assert self.k_num_heads == self.v_num_heads
-        kv_seq_len = 0
-        if cache is not None:
-            kv_seq_len += cache[0].shape[-2]
-        query_states = self.rotary_emb(query_states, offset=kv_seq_len)
-        key_states = self.rotary_emb(key_states, offset=kv_seq_len)
+        queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(
+            0, 2, 1, 3
+        )
+        keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(
+            0, 2, 1, 3
+        )
+        values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(
+            0, 2, 1, 3
+        )
         if cache is not None:
-            # reuse k, v, self_attention
-            key_states = mx.concatenate([cache[0], key_states], axis=2)
-            value_states = mx.concatenate([cache[1], value_states], axis=2)
+            queries = self.rotary_emb(queries, offset=cache.offset)
+            keys = self.rotary_emb(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rotary_emb(queries)
+            keys = self.rotary_emb(keys)
         output = mx.fast.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
+            queries,
+            keys,
+            values,
             scale=self.scale,
             mask=attention_mask,
         )
         output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
-        return self.o_proj(output), (key_states, value_states)
+        return self.o_proj(output)
 
 class MLP(nn.Module):
@@ -139,7 +133,7 @@ class PlamoDecoderLayer(nn.Module):
         hidden_states = self.norm(hidden_states)
         # Self Attention
-        hidden_states_sa, cache = self.self_attn(
+        hidden_states_sa = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             cache=cache,
@@ -149,7 +143,7 @@ class PlamoDecoderLayer(nn.Module):
         hidden_states_mlp = self.mlp(hidden_states)
         hidden_states = residual + hidden_states_sa + hidden_states_mlp
-        return hidden_states, cache
+        return hidden_states
 
 class PlamoDecoder(nn.Module):
@@ -185,14 +179,10 @@ class PlamoModel(nn.Module):
         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]
-        for e, layer in enumerate(self.layers.layers):
-            h, c = layer(h, mask, cache[e])
-            if cache is not None:
-                cache[e] = c
-            else:
-                cache.append(c)
+        for layer, c in zip(self.layers.layers, cache):
+            h = layer(h, mask, cache=c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 class Model(nn.Module):
@@ -203,15 +193,24 @@ class Model(nn.Module):
         self.lm_head: nn.Module = nn.Linear(
             args.hidden_size, args.vocab_size, bias=False
         )
+        self.args = args
 
     def __call__(
         self,
         inputs: mx.array,
        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
     ) -> Tuple[mx.array, mx.array]:
-        out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        out = self.model(inputs, cache)
+        return self.lm_head(out)
 
     @property
     def layers(self):
         return self.model.layers.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_attention_heads // self.args.n_shared_head
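
With the per-layer `(keys, values)` tuples gone, a caller needs a way to size and construct one cache per transformer block, which is presumably what the new `layers`, `head_dim`, and `n_kv_heads` properties support. A hypothetical caller-side sketch, reusing the `KVCache` sketch above and treating `model` and `inputs` as placeholders:

```python
from typing import List


def make_caches(model) -> List[KVCache]:
    # One cache per transformer block, sized from the model's new properties.
    return [KVCache(model.head_dim, model.n_kv_heads) for _ in model.layers]


# caches = make_caches(model)
# logits = model(inputs, cache=caches)  # each layer updates its cache in place
```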