* in place kv_cache

* fix

* fix kv cache size

* partially fix kv cache dtype

* step kv cache

* multiple of step size

* more tests + kv cache

* more kv cache

* update all models to use kv cache

Author: Awni Hannun
Date: 2024-05-08 08:18:13 -07:00
Committed by: GitHub
Parent: bfbc0e434a
Commit: ee60e2a9d5
22 changed files with 534 additions and 298 deletions


@@ -107,11 +107,9 @@ class Attention(nn.Module):
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
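
The hunk above shows only the call-site contract: the cache carries an offset used for RoPE, and an update_and_fetch method that appends new keys/values in place. The cache class itself is not part of this file's diff. A minimal sketch consistent with that contract and with the "step kv cache" / "multiple of step size" commit messages might look like the following; the class name, constructor arguments, and step size of 256 are assumptions, not taken from this diff:

import mlx.core as mx


class KVCache:
    """Grows key/value buffers in fixed-size steps and writes new entries
    in place, instead of concatenating freshly allocated arrays per token."""

    def __init__(self, head_dim, n_kv_heads):
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.keys = None
        self.values = None
        self.offset = 0  # positions written so far; the RoPE offset above
        self.step = 256  # assumed allocation step size

    def update_and_fetch(self, keys, values):
        prev = self.offset
        # Reallocate only when the incoming tokens would overflow the buffer,
        # rounding the new capacity up to a multiple of the step size.
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, keys.dtype)
            new_v = mx.zeros(shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2)
                self.values = mx.concatenate([self.values[..., :prev, :], new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v
        # In-place slice assignment; callers get views up to the new offset.
        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]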
@@ -125,7 +123,7 @@
             queries, keys, values, scale=scale, mask=mask
         ).astype(values.dtype)
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)
 
 
 class MLP(nn.Module):
@@ -157,7 +155,7 @@ class DecoderLayer(nn.Module):
     def __call__(self, x, mask, cache):
         h = self.input_layernorm(x)
-        r, cache = self.self_attn(h, mask, cache)
+        r = self.self_attn(h, mask, cache)
         if self.use_parallel_residual:
             out = x + r + self.mlp(h)
@@ -165,7 +163,7 @@
             h = x + r
             r = self.mlp(self.post_attention_layernorm(h))
             out = h + r
-        return out, cache
+        return out
 
 
 class StableLM(nn.Module):
@@ -180,9 +178,10 @@ class StableLM(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
-        for e, layer in enumerate(self.layers):
-            x, cache[e] = layer(x, mask, cache[e])
-        return self.norm(x), cache
+        for layer, c in zip(self.layers, cache):
+            x = layer(x, mask, cache=c)
+        return self.norm(x)
 
 
 class Model(nn.Module):
@@ -191,6 +190,7 @@ class Model(nn.Module):
self.model_type = config.model_type
self.model = StableLM(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.args = config
def __call__(
self,
@@ -203,9 +203,17 @@ class Model(nn.Module):
             mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
             mask = mask.astype(x.dtype)
-        y, cache = self.model(x, mask, cache)
-        return self.lm_head(y), cache
+        y = self.model(x, mask, cache)
+        return self.lm_head(y)
 
     @property
     def layers(self):
         return self.model.layers
 
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
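
With layers, head_dim, and n_kv_heads exposed as properties, generation code can construct per-layer caches without model-specific knowledge. A hedged caller-side sketch, assuming the KVCache shape sketched earlier and a Model.__call__(x, cache=None) signature (the diff truncates the actual signature):

# Hypothetical greedy-decode loop: one cache per decoder layer, reused
# across steps; each forward call advances the caches in place.
cache = [KVCache(model.head_dim, model.n_kv_heads) for _ in model.layers]

y = model(prompt[None], cache=cache)     # prefill with the full prompt
token = mx.argmax(y[:, -1, :], axis=-1)
for _ in range(max_tokens):
    y = model(token[None], cache=cache)  # decode a single new token
    token = mx.argmax(y[:, -1, :], axis=-1)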