Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00
Kv cache (#643)
* in place kv_cache
* fix
* fix kv cache size
* partially fix kv cache dtype
* step kv cache
* multiple of step size
* more tests + kv cache
* more kv cache
* update all models to use kv cache
@@ -51,30 +51,24 @@ class Attention(nn.Module):
         B, L, _ = q.shape

-        q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
-        k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
-        v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)

         if cache is not None:
-            k_cache, v_cache = cache
-            q = self.rotary_emb(q, offset=k_cache.shape[2])
-            k = self.rotary_emb(k, offset=k_cache.shape[2])
-            k = mx.concatenate([k_cache, k], axis=2)
-            v = mx.concatenate([v_cache, v], axis=2)
+            queries = self.rotary_emb(queries, offset=cache.offset)
+            keys = self.rotary_emb(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
-            q = self.rotary_emb(q)
-            k = self.rotary_emb(k)
+            queries = self.rotary_emb(queries)
+            keys = self.rotary_emb(keys)

-        scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
-
-        if mask is not None:
-            scores = scores + mask
-
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.c_proj(v_hat), (k, v)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.c_proj(output)


 class MLP(nn.Module):
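For orientation, the cache object that Attention now receives is expected to expose an offset attribute and an update_and_fetch(keys, values) method, and, per the commit message, it updates its buffers in place and grows them in multiples of a fixed step. Below is a minimal sketch of such a cache, assuming keys and values shaped (B, n_kv_heads, L, head_dim); the class name, step size, and growth policy are illustrative rather than the exact mlx_lm implementation.

import mlx.core as mx


class KVCache:
    # Illustrative sketch only -- not the exact mlx_lm implementation.
    step = 256  # backing buffers grow in multiples of this many positions

    def __init__(self, head_dim, n_kv_heads):
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        self.keys = None
        self.values = None
        self.offset = 0  # number of positions already stored

    def update_and_fetch(self, keys, values):
        # keys/values hold only the new tokens: (B, n_kv_heads, L_new, head_dim)
        prev = self.offset
        new_len = prev + keys.shape[2]
        if self.keys is None or new_len > self.keys.shape[2]:
            # allocate (or extend) the buffers to the next multiple of `step`
            B = keys.shape[0]
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            shape = (B, self.n_kv_heads, n_steps * self.step, self.head_dim)
            grown_k = mx.zeros(shape, keys.dtype)
            grown_v = mx.zeros(shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys[..., :prev, :], grown_k], axis=2)
                self.values = mx.concatenate([self.values[..., :prev, :], grown_v], axis=2)
            else:
                self.keys, self.values = grown_k, grown_v
        # write the new tokens in place and hand back the valid prefix
        self.offset = new_len
        self.keys[..., prev:new_len, :] = keys
        self.values[..., prev:new_len, :] = values
        return self.keys[..., :new_len, :], self.values[..., :new_len, :]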
@@ -109,13 +103,13 @@ class TransformerBlock(nn.Module):
     def __call__(self, x, mask=None, cache=None):
         residual = x
         x = self.ln_1(x)
-        x, cache = self.attn(x, mask=mask, cache=cache)
+        x = self.attn(x, mask=mask, cache=cache)
         residual = x + residual
         x = self.ln_2(residual)
         x = self.mlp(x)
         x = x + residual

-        return x, cache
+        return x


 class QwenModel(nn.Module):
@@ -137,10 +131,10 @@ class QwenModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.h)

-        for e, layer in enumerate(self.h):
-            x, cache[e] = layer(x, mask, cache[e])
+        for layer, c in zip(self.h, cache):
+            x = layer(x, mask, c)

-        return self.ln_f(x), cache
+        return self.ln_f(x)


 class Model(nn.Module):
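With this hunk the cache becomes a list with one cache object per layer; the caller allocates it once and the attention layers mutate it in place, so QwenModel no longer returns it. A hedged sketch of the resulting calling pattern, reusing the illustrative KVCache above (model, prompt_tokens, and next_token are stand-ins, and the sizes are arbitrary):

# one cache object per transformer block, owned by the caller
cache = [KVCache(head_dim=128, n_kv_heads=16) for _ in range(len(model.h))]

# prefill on the prompt, then decode token by token; the same list is
# passed back each step and every layer's cache advances in place
hidden = model(prompt_tokens, mask, cache)
hidden = model(next_token, None, cache)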
@@ -151,6 +145,7 @@ class Model(nn.Module):
         self.lm_head = nn.Linear(
             config.hidden_size, config.vocab_size, bias=not config.no_bias
         )
+        self.args = config

     def __call__(
         self,
@@ -158,9 +153,17 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        y, cache = self.transformer(x, mask, cache)
-        return self.lm_head(y), cache
+        y = self.transformer(x, mask, cache)
+        return self.lm_head(y)

     @property
     def layers(self):
         return self.transformer.h
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_attention_heads
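Putting the pieces together, the new layers, head_dim, and n_kv_heads properties give callers everything needed to allocate the per-layer caches up front. A greedy decoding loop against the new interface might look like the sketch below; it reuses the illustrative KVCache from above, and the mask handling and sampling are simplifying assumptions rather than the exact mlx_lm generation code.

import mlx.core as mx
import mlx.nn as nn


def greedy_generate(model, prompt, max_tokens=32):
    # one cache per layer, sized from the model's new properties
    caches = [KVCache(model.head_dim, model.n_kv_heads) for _ in model.layers]

    # prefill: run the whole prompt with a causal mask and keep the last logits
    # (casting the mask to the model's dtype is omitted for brevity)
    mask = nn.MultiHeadAttention.create_additive_causal_mask(prompt.shape[0])
    logits = model(prompt[None], mask=mask, cache=caches)
    y = mx.argmax(logits[:, -1, :], axis=-1)

    tokens = [y.item()]
    for _ in range(max_tokens - 1):
        # decode one token at a time; a single query needs no mask
        logits = model(y[None], cache=caches)
        y = mx.argmax(logits[:, -1, :], axis=-1)
        tokens.append(y.item())
    return tokens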