* in place kv_cache

* fix

* fix kv cache size

* partially fix kv cache dtype

* step kv cache

* multiple of step size

* more tests + kv cache

* more kv cache

* update all models to use kv cache
Author: Awni Hannun
Date: 2024-05-08 08:18:13 -07:00
Committed by: GitHub
Parent: bfbc0e434a
Commit: ee60e2a9d5
22 changed files with 534 additions and 298 deletions


@@ -18,6 +18,7 @@ from mlx.utils import tree_flatten
 from transformers import AutoTokenizer, PreTrainedTokenizer
 
 # Local imports
+from .models.base import KVCache
 from .sample_utils import top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
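The `KVCache` imported here lives in `models/base.py`, which this diff does not show. The following is a hedged sketch of what the class plausibly looks like, inferred from the constructor call `KVCache(model.head_dim, n)` below and the "step kv cache" / "multiple of step size" bullets in the commit message; the step value of 256, the field names, and the `update_and_fetch` method name are assumptions, not confirmed by this diff:

```python
# Hypothetical sketch of KVCache; the step size and update_and_fetch API are
# inferred from the commit notes, not shown in this diff.
import mlx.core as mx


class KVCache:
    def __init__(self, head_dim, n_kv_heads):
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.keys = None
        self.values = None
        self.offset = 0
        self.step = 256  # grow the backing buffers in multiples of this

    def update_and_fetch(self, keys, values):
        # keys/values: (batch, n_kv_heads, num_new_tokens, head_dim)
        prev = self.offset
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            # Allocate enough whole steps to hold the incoming tokens.
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, keys.dtype)
            new_v = mx.zeros(shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2)
                self.values = mx.concatenate([self.values[..., :prev, :], new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        # Write the new tokens into the buffer and return views up to the offset.
        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
```

Growing the buffers a step at a time amortizes allocation cost across decode steps instead of reallocating and copying on every generated token.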
@@ -160,7 +161,12 @@ def generate_step(
         )
 
     y = prompt
-    cache = None
+    kv_heads = (
+        [model.n_kv_heads] * len(model.layers)
+        if isinstance(model.n_kv_heads, int)
+        else model.n_kv_heads
+    )
+    cache = [KVCache(model.head_dim, n) for n in kv_heads]
 
     repetition_context = prompt.tolist()
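With this change every layer gets its own pre-constructed cache, sized by that layer's KV-head count, so grouped-query models whose `n_kv_heads` differs from `n_heads` (or varies per layer) are handled uniformly. A small illustration with made-up dimensions, assuming the `KVCache` sketch above:

```python
# Made-up model dimensions, for illustration only.
n_kv_heads = 8   # a single int shared by every layer (the isinstance branch)
head_dim = 128
num_layers = 32

kv_heads = [n_kv_heads] * num_layers if isinstance(n_kv_heads, int) else n_kv_heads
cache = [KVCache(head_dim, n) for n in kv_heads]  # one independent cache per layer
```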
@@ -168,8 +174,8 @@ def generate_step(
         repetition_context = repetition_context[-repetition_context_size:]
 
     def _step(y):
-        nonlocal cache, repetition_context
-        logits, cache = model(y[None], cache=cache)
+        nonlocal repetition_context
+        logits = model(y[None], cache=cache)
         logits = logits[:, -1, :]
 
         if repetition_penalty:
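This signature change is the "in place kv_cache" bullet from the commit message: the model now mutates the `KVCache` objects it is handed and returns only logits, instead of returning a new cache that the caller must rebind each step. A tiny check of that pattern against the sketch above (shapes are made up):

```python
import mlx.core as mx

cache = KVCache(head_dim=64, n_kv_heads=4)

# Prefill with a 10-token prompt, then decode one token; the cache grows in
# place and the caller never rebinds it.
k = v = mx.zeros((1, 4, 10, 64))
keys, values = cache.update_and_fetch(k, v)
print(cache.offset, keys.shape)  # 10 (1, 4, 10, 64)

k = v = mx.zeros((1, 4, 1, 64))
keys, values = cache.update_and_fetch(k, v)
print(cache.offset, keys.shape)  # 11 (1, 4, 11, 64)
```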
@@ -445,9 +451,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
     card.text = dedent(
         f"""
         # {upload_repo}
         The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
 
         ## Use with mlx
 
         ```bash