Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 07:30:06 +08:00)
Merge pull request #1 from N8python/add-cohere2-arch-rotating-kv-cache
Add rotating kvcache to save space
Commit 20d792576c
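Background on the change: a rotating KV cache bounds the memory of a sliding-window attention layer by retaining only the most recent `max_size` key/value entries, instead of growing without bound like a standard `KVCache`. Below is a minimal sketch of the idea; the class name and shapes are illustrative only, and the real `RotatingKVCache` in mlx-lm's cache module is more elaborate (it rotates entries in place rather than copying, and supports a `keep` count for pinned tokens, which is the `keep=0` argument passed in `make_cache` further down).

```python
import mlx.core as mx


class BoundedKVCache:
    """Illustrative cache keeping at most `max_size` timesteps of K/V.

    Safe for sliding-window attention: positions older than the window
    can never be attended to anyway, so they may be discarded.
    """

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.keys = None  # (B, n_kv_heads, S, head_dim)
        self.values = None
        self.offset = 0  # total tokens seen; may exceed max_size

    def update_and_fetch(self, keys: mx.array, values: mx.array):
        if self.keys is None:
            self.keys, self.values = keys, values
        else:
            self.keys = mx.concatenate([self.keys, keys], axis=2)
            self.values = mx.concatenate([self.values, values], axis=2)
        if self.keys.shape[2] > self.max_size:
            # Drop the oldest timesteps; memory stays O(max_size).
            self.keys = self.keys[..., -self.max_size :, :]
            self.values = self.values[..., -self.max_size :, :]
        self.offset += keys.shape[2]
        return self.keys, self.values
```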
llms/mlx_lm/models/base.py:

```diff
@@ -34,13 +34,15 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
     return mask * -1e9


-def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_idx: Optional[int] = None):
     T = h.shape[1]
     if T > 1:
         window_size = None
         offset = 0
         if cache is not None and cache[0] is not None:
             c = cache[0]
+            if reference_idx is not None:
+                c = cache[reference_idx]
             if hasattr(c, "max_size"):
                 offset = min(c.max_size, c.offset)
                 window_size = c.max_size
```
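The new `reference_idx` argument lets the caller choose which layer's cache determines the mask geometry instead of always inspecting `cache[0]`. That matters for models that mix cache types across layers: in Cohere2, `cache[0]` is a `RotatingKVCache`, so keying the mask off it would clamp every layer's prompt mask to the sliding window. The window behavior itself can be checked directly through `create_causal_mask`, whose signature appears in the hunk header above; a quick check, assuming the `mlx_lm` package from this repo is installed (the numbers are arbitrary):

```python
import mlx.core as mx
from mlx_lm.models.base import create_causal_mask

# 4 new tokens arriving after 6 cached ones, with a 5-token window:
# each query row masks out everything beyond its last 5 positions.
mask = create_causal_mask(4, offset=6, window_size=5)
print(mask.shape)  # (4, 10): rows are new tokens, columns all positions
```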
llms/mlx_lm/models/cohere2.py:

```diff
@@ -6,8 +6,8 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .cache import KVCache, RotatingKVCache
 
 @dataclass
 class ModelArgs(BaseModelArgs):
```
llms/mlx_lm/models/cohere2.py:

```diff
@@ -95,7 +95,6 @@ class Attention(nn.Module):
 
         if cache is not None:
             keys, values = cache.update_and_fetch(keys, values)
-
         # Apply sliding window attention if enabled
         if self.sliding_window is not None:
             window_size = self.sliding_window
@@ -104,8 +103,8 @@ class Attention(nn.Module):
             if mask is not None:
                 mask = mask[..., -window_size:]
 
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
 
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
```
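Why route through the `.base` wrapper instead of calling `mx.fast.scaled_dot_product_attention` directly: once `maybe_quantize_kv_cache` (last hunk below) converts a cache to `QuantizedKVCache`, `update_and_fetch` returns packed `(data, scales, biases)` tuples, which the fused floating-point kernel cannot consume. A sketch of the dispatch idea, assuming `n_kv_heads == n_heads` and eliding details such as GQA head repetition (paraphrased, not the exact `base.py` source):

```python
import mlx.core as mx
from mlx_lm.models.cache import QuantizedKVCache


def scaled_dot_product_attention(queries, keys, values, cache=None, scale=1.0, mask=None):
    if isinstance(cache, QuantizedKVCache):
        # keys/values are (packed, scales, biases) tuples here.
        scores = mx.quantized_matmul(
            queries * scale, *keys, transpose=True,
            group_size=cache.group_size, bits=cache.bits,
        )
        if mask is not None:
            scores = scores + mask
        probs = mx.softmax(scores, axis=-1)
        return mx.quantized_matmul(
            probs, *values, transpose=False,
            group_size=cache.group_size, bits=cache.bits,
        )
    # Regular caches keep using the fused kernel, as before this change.
    return mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )
```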
llms/mlx_lm/models/cohere2.py:

```diff
@@ -171,7 +170,7 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        mask = create_attention_mask(h, cache, reference_idx=self.args.sliding_window_pattern - 1)
 
         if cache is None:
             cache = [None] * len(self.layers)
```
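Here `reference_idx = sliding_window_pattern - 1` points the mask logic at the first global-attention layer (see `make_cache` in the next hunk). That layer holds a standard `KVCache` with no `max_size` attribute, so the prompt mask is built over the full history, while clipping to the window for the local layers happens inside `Attention` via `mask[..., -window_size:]`.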
llms/mlx_lm/models/cohere2.py:

```diff
@@ -198,6 +197,15 @@ class Model(nn.Module):
         out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out
 
+    def make_cache(self):
+        caches = []
+        for i in range(self.args.num_hidden_layers):
+            if i % self.args.sliding_window_pattern == self.args.sliding_window_pattern - 1:
+                caches.append(KVCache())
+            else:
+                caches.append(RotatingKVCache(max_size=self.args.sliding_window, keep=0))
+        return caches
+
     @property
     def layers(self):
```
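To see the cache layout `make_cache` produces, here is its selection rule evaluated standalone, assuming `sliding_window_pattern = 4` and eight layers (hypothetical numbers for illustration):

```python
pattern, n_layers = 4, 8  # illustrative values, not read from a config
kinds = [
    "KVCache" if i % pattern == pattern - 1 else "RotatingKVCache"
    for i in range(n_layers)
]
print(kinds)
# ['RotatingKVCache', 'RotatingKVCache', 'RotatingKVCache', 'KVCache',
#  'RotatingKVCache', 'RotatingKVCache', 'RotatingKVCache', 'KVCache']
```

Only every `pattern`-th layer's cache grows with the sequence; the rest stay at `sliding_window` entries, which is where the memory saving in the PR title comes from.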
llms/mlx_lm/utils.py:

```diff
@@ -187,11 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
         and prompt_cache[0].offset > quantized_kv_start
     ):
         for i in range(len(prompt_cache)):
-            prompt_cache[i] = prompt_cache[i].to_quantized(
-                group_size=kv_group_size, bits=kv_bits
-            )
-
-
+            if isinstance(prompt_cache[i], cache.KVCache):
+                prompt_cache[i] = prompt_cache[i].to_quantized(
+                    group_size=kv_group_size, bits=kv_bits
+                )
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
```
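Note the effect of the new `isinstance` guard: only full-history `cache.KVCache` entries are converted to quantized form, while the `RotatingKVCache` layers created by `make_cache` are skipped, since they are already bounded at `sliding_window` entries. Previously the loop called `to_quantized` on every entry unconditionally, which assumed a homogeneous list of standard caches and would not hold for a mixed-cache model like this one.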