mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 17:58:54 +08:00
Kv cache (#643)
* in place kv_cache * fix * fix kv cache size * partially fix kv cache dtype * step kv cache * multiple of step size * more teests + kv cache * more kv cache * udpate all models to use kv cache
This commit is contained in:
@@ -107,11 +107,9 @@ class Attention(nn.Module):
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
@@ -125,7 +123,7 @@ class Attention(nn.Module):
|
||||
queries, keys, values, scale=scale, mask=mask
|
||||
).astype(values.dtype)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
@@ -157,7 +155,7 @@ class DecoderLayer(nn.Module):
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.input_layernorm(x)
|
||||
r, cache = self.self_attn(h, mask, cache)
|
||||
r = self.self_attn(h, mask, cache)
|
||||
|
||||
if self.use_parallel_residual:
|
||||
out = x + r + self.mlp(h)
|
||||
@@ -165,7 +163,7 @@ class DecoderLayer(nn.Module):
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
return out
|
||||
|
||||
|
||||
class StableLM(nn.Module):
|
||||
@@ -180,9 +178,10 @@ class StableLM(nn.Module):
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
x, cache[e] = layer(x, mask, cache[e])
|
||||
return self.norm(x), cache
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, mask, cache=c)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
@@ -191,6 +190,7 @@ class Model(nn.Module):
|
||||
self.model_type = config.model_type
|
||||
self.model = StableLM(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -203,9 +203,17 @@ class Model(nn.Module):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
y, cache = self.model(x, mask, cache)
|
||||
return self.lm_head(y), cache
|
||||
y = self.model(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
||||
Reference in New Issue
Block a user