More cache improvements (#1015)

* fix rotating kv cache for chat use case

* reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat

* nit in chat

* fix tests

* fix tests

* fix tests

* docs

* chat command

* comments + docs

* Define meta_state on all Cache implementations

* fixes + trim_prompt_cache api

* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-10-07 20:45:51 -07:00
committed by GitHub
parent 9bc53fc210
commit fca087be49
43 changed files with 1151 additions and 691 deletions

View File

@@ -2,12 +2,12 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs):
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Optional[Dict] = None
rope_scaling: Dict = None
attention_bias: bool = False
@@ -172,12 +172,11 @@ class DeepseekV2Attention(nn.Module):
bias=config.attention_bias,
)
if self.config.rope_scaling is not None:
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
@@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -347,7 +346,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module):
def __call__(
self,
x: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
@@ -395,7 +394,7 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
@@ -416,14 +415,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
self.args.v_head_dim,
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads