Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-29 18:17:07 +08:00

Commit 20d792576c: Merge pull request #1 from N8python/add-cohere2-arch-rotating-kv-cache

Add rotating kvcache to save space
In base.py, create_attention_mask gains an optional reference_idx so the mask geometry (offset and window size) can be taken from a chosen layer's cache instead of always from cache[0]:

@@ -34,13 +34,15 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
     return mask * -1e9


-def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_idx: Optional[int] = None):
     T = h.shape[1]
     if T > 1:
         window_size = None
         offset = 0
         if cache is not None and cache[0] is not None:
             c = cache[0]
+            if reference_idx is not None:
+                c = cache[reference_idx]
             if hasattr(c, "max_size"):
                 offset = min(c.max_size, c.offset)
                 window_size = c.max_size
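Why this matters for Cohere2: with interleaved attention, cache[0] is a RotatingKVCache whose max_size would force a sliding-window mask onto every layer, while the layer at reference_idx holds a plain KVCache with the full history. A minimal stand-alone sketch of the selection logic; the stand-in classes and values are illustrative, not from the source:

from typing import Any, List, Optional

class FakeRotatingCache:
    """Stand-in for RotatingKVCache: bounded, so it exposes max_size."""
    def __init__(self, max_size: int, offset: int):
        self.max_size, self.offset = max_size, offset

class FakeFullCache:
    """Stand-in for KVCache: unbounded, so no max_size attribute."""
    def __init__(self, offset: int):
        self.offset = offset

def mask_geometry(cache: Optional[List[Any]], reference_idx: Optional[int] = None):
    # Mirrors the hunk above: pick which layer's cache sizes the mask.
    window_size: Optional[int] = None
    offset = 0
    if cache is not None and cache[0] is not None:
        c = cache[0]
        if reference_idx is not None:
            c = cache[reference_idx]
        if hasattr(c, "max_size"):
            offset = min(c.max_size, c.offset)
            window_size = c.max_size
        else:
            # Assumed from the full function (outside the hunk shown): a
            # plain cache contributes its whole history as the mask offset.
            offset = c.offset
    return offset, window_size  # the real code feeds these to create_causal_mask

# Three sliding-window layers, then one global-attention layer, after
# 5000 tokens have been processed (illustrative numbers):
caches = [FakeRotatingCache(4096, 5000) for _ in range(3)] + [FakeFullCache(5000)]
print(mask_geometry(caches))                   # (4096, 4096): windowed mask
print(mask_geometry(caches, reference_idx=3))  # (5000, None): full causal mask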
In the Cohere2 model file, the imports pull in the shared attention helper and both cache types:

@@ -6,8 +6,8 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn

-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .cache import KVCache, RotatingKVCache

 @dataclass
 class ModelArgs(BaseModelArgs):
In Attention.__call__, the cache update and sliding-window handling are unchanged context; the attention call itself is rerouted through the shared helper and now receives the cache:

@@ -95,7 +95,6 @@ class Attention(nn.Module):

         if cache is not None:
             keys, values = cache.update_and_fetch(keys, values)

         # Apply sliding window attention if enabled
         if self.sliding_window is not None:
             window_size = self.sliding_window
@@ -104,8 +103,8 @@ class Attention(nn.Module):
             if mask is not None:
                 mask = mask[..., -window_size:]

-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )

         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
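The point of routing through scaled_dot_product_attention from base.py is that the helper can choose an attention kernel based on the cache it is handed. A plausible shape for that dispatch, under the assumption that the extra branch exists for quantized caches; the real helper's internals are not shown in this diff:

import mlx.core as mx

def scaled_dot_product_attention(queries, keys, values, cache=None, scale=1.0, mask=None):
    # Assumption: a quantized cache returns packed (weights, scales, biases)
    # tuples instead of plain arrays and needs a dequantizing kernel; plain
    # KVCache and RotatingKVCache take the fused fast path below.
    if cache is not None and getattr(cache, "bits", None) is not None:
        raise NotImplementedError("quantized attention path, elided in this sketch")
    # Default path matches the old call site exactly.
    return mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )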
CohereModel.__call__ then sizes the mask against the last layer of each sliding-window block, which is the global-attention layer holding a full KVCache:

@@ -171,7 +170,7 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = create_attention_mask(h, cache)
+        mask = create_attention_mask(h, cache, reference_idx=self.args.sliding_window_pattern - 1)

         if cache is None:
             cache = [None] * len(self.layers)
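For a concrete feel of the index arithmetic, a tiny sketch; the pattern value of 4 is an assumption taken from the usual Cohere2 config convention, not from this diff:

# Layers whose index i satisfies i % pattern == pattern - 1 are the
# global-attention layers; reference_idx points at the first of them.
sliding_window_pattern = 4  # assumed config value
reference_idx = sliding_window_pattern - 1
print(reference_idx)  # 3: its plain KVCache sizes the prefill mask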
Model.make_cache builds the interleaved cache list: a full KVCache for each global-attention layer and a bounded RotatingKVCache for the rest, which is where the memory saving comes from:

@@ -198,6 +197,15 @@ class Model(nn.Module):
         out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out

+    def make_cache(self):
+        caches = []
+        for i in range(self.args.num_hidden_layers):
+            if i % self.args.sliding_window_pattern == self.args.sliding_window_pattern - 1:
+                caches.append(KVCache())
+            else:
+                caches.append(RotatingKVCache(max_size=self.args.sliding_window, keep=0))
+        return caches
+
     @property
     def layers(self):
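A quick sketch of the resulting layout; the layer count and window size are illustrative, not the model's real config:

num_hidden_layers = 8  # illustrative; the real model is deeper
sliding_window = 4096  # illustrative window size
pattern = 4            # every fourth layer attends globally

for i in range(num_hidden_layers):
    if i % pattern == pattern - 1:
        print(i, "KVCache (full history)")
    else:
        print(i, f"RotatingKVCache (at most ~{sliding_window} positions)")
# Long prompts only grow the two KVCache layers; the other six stay
# bounded regardless of sequence length, which is the memory saving.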
Finally, in the generation utilities (utils.py), maybe_quantize_kv_cache learns to skip entries that are not plain KVCache instances, since the cache list can now mix types:

@@ -187,11 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
         and prompt_cache[0].offset > quantized_kv_start
     ):
         for i in range(len(prompt_cache)):
-            prompt_cache[i] = prompt_cache[i].to_quantized(
-                group_size=kv_group_size, bits=kv_bits
-            )
+            if isinstance(prompt_cache[i], cache.KVCache):
+                prompt_cache[i] = prompt_cache[i].to_quantized(
+                    group_size=kv_group_size, bits=kv_bits
+                )


 def generate_step(
     prompt: mx.array,
     model: nn.Module,
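To see what the guard does with the mixed list from make_cache, a hedged sketch; the import path and the assumption that RotatingKVCache is not a KVCache subclass are mine, not from the diff:

from mlx_lm.models.cache import KVCache, RotatingKVCache  # assumed import path

# A mixed list like Model.make_cache() would produce for pattern=4:
prompt_cache = [RotatingKVCache(max_size=4096, keep=0) for _ in range(3)]
prompt_cache.append(KVCache())

# Only plain KVCache entries are converted by to_quantized; the rotating
# caches keep their bounded float buffers (assumed: RotatingKVCache is not
# a KVCache subclass, so isinstance filters it out).
print([isinstance(c, KVCache) for c in prompt_cache])
# -> [False, False, False, True]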