* in place kv_cache

* fix

* fix kv cache size

* partially fix kv cache dtype

* step kv cache

* multiple of step size

* more tests + kv cache

* more kv cache

* update all models to use kv cache
Author: Awni Hannun
Date: 2024-05-08 08:18:13 -07:00
Committed by: GitHub
Parent: bfbc0e434a
Commit: ee60e2a9d5

22 changed files with 534 additions and 298 deletions
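
The hunks below call `cache.offset` and `cache.update_and_fetch(keys, values)` on a per-layer cache object. A minimal sketch of what such a step-allocated, in-place cache could look like (the class name, the step size of 256, and the growth logic here are assumptions for illustration, not taken from this diff):

```python
import mlx.core as mx


class KVCache:
    """Sketch of a key/value cache that grows in fixed steps and is written in place."""

    step = 256  # assumed step size; buffer length is always a multiple of this

    def __init__(self, head_dim, n_kv_heads):
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.keys = None
        self.values = None
        self.offset = 0  # number of positions already stored

    def update_and_fetch(self, keys, values):
        # keys/values: (B, n_kv_heads, num_new_tokens, head_dim)
        prev = self.offset
        num_new = keys.shape[2]
        if self.keys is None or prev + num_new > self.keys.shape[2]:
            # Allocate extra space, rounded up to a multiple of `step`.
            n_steps = (num_new + self.step - 1) // self.step
            shape = (keys.shape[0], self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, dtype=keys.dtype)
            new_v = mx.zeros(shape, dtype=values.dtype)
            if self.keys is None:
                self.keys, self.values = new_k, new_v
            else:
                self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2)
                self.values = mx.concatenate([self.values[..., :prev, :], new_v], axis=2)

        # Write the new tokens into the preallocated slice and return the valid prefix.
        self.offset += num_new
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
```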


@@ -51,30 +51,24 @@ class Attention(nn.Module):
         B, L, _ = q.shape

-        q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
-        k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
-        v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)

         if cache is not None:
-            k_cache, v_cache = cache
-            q = self.rotary_emb(q, offset=k_cache.shape[2])
-            k = self.rotary_emb(k, offset=k_cache.shape[2])
-            k = mx.concatenate([k_cache, k], axis=2)
-            v = mx.concatenate([v_cache, v], axis=2)
+            queries = self.rotary_emb(queries, offset=cache.offset)
+            keys = self.rotary_emb(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
-            q = self.rotary_emb(q)
-            k = self.rotary_emb(k)
+            queries = self.rotary_emb(queries)
+            keys = self.rotary_emb(keys)

-        scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores = scores + mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.c_proj(v_hat), (k, v)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.c_proj(output)


 class MLP(nn.Module):
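
The manual `softmax((q * scale) @ k^T) @ v` path above is replaced by MLX's fused kernel. A quick standalone check of the call, with arbitrary shapes (batch, heads, sequence length, head dim):

```python
import mlx.core as mx

B, n_heads, L, head_dim = 1, 8, 16, 64
q = mx.random.normal((B, n_heads, L, head_dim))
k = mx.random.normal((B, n_heads, L, head_dim))
v = mx.random.normal((B, n_heads, L, head_dim))

# Fused attention: softmax(q @ k^T * scale + mask) @ v in one kernel.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=head_dim**-0.5, mask=None)
print(out.shape)  # (1, 8, 16, 64)
```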
@@ -109,13 +103,13 @@ class TransformerBlock(nn.Module):
     def __call__(self, x, mask=None, cache=None):
         residual = x
         x = self.ln_1(x)
-        x, cache = self.attn(x, mask=mask, cache=cache)
+        x = self.attn(x, mask=mask, cache=cache)
         residual = x + residual
         x = self.ln_2(residual)
         x = self.mlp(x)
         x = x + residual
-        return x, cache
+        return x


 class QwenModel(nn.Module):
@@ -137,10 +131,10 @@ class QwenModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.h)

-        for e, layer in enumerate(self.h):
-            x, cache[e] = layer(x, mask, cache[e])
+        for layer, c in zip(self.h, cache):
+            x = layer(x, mask, c)

-        return self.ln_f(x), cache
+        return self.ln_f(x)


 class Model(nn.Module):
@@ -151,6 +145,7 @@ class Model(nn.Module):
         self.lm_head = nn.Linear(
             config.hidden_size, config.vocab_size, bias=not config.no_bias
         )
+        self.args = config

     def __call__(
         self,
@@ -158,9 +153,17 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        y, cache = self.transformer(x, mask, cache)
-        return self.lm_head(y), cache
+        y = self.transformer(x, mask, cache)
+        return self.lm_head(y)

     @property
     def layers(self):
         return self.transformer.h
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_attention_heads
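
With the return values simplified and the new `layers`, `head_dim`, and `n_kv_heads` properties, generation code can build one cache per layer up front and let the model mutate it in place. A hedged sketch, assuming the `KVCache` from the sketch above, a `model` loaded via mlx_lm, and that the model builds its own causal mask for multi-token inputs:

```python
import mlx.core as mx

# One cache object per transformer layer, sized from the new model properties.
cache = [KVCache(model.head_dim, model.n_kv_heads) for _ in model.layers]

prompt = mx.array([[1, 2, 3, 4]])        # (1, L) token ids, placeholder values
logits = model(prompt, cache=cache)      # prefill; every layer's cache advances by L
next_tok = mx.argmax(logits[:, -1, :], axis=-1)

# Subsequent steps pass only the newest token; cached keys/values supply the rest.
logits = model(next_tok[None], cache=cache)
```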