mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 09:48:54 +08:00
Kv cache (#643)
* in place kv_cache * fix * fix kv cache size * partially fix kv cache dtype * step kv cache * multiple of step size * more tests + kv cache * more kv cache * update all models to use kv cache
This commit is contained in:
@@ -18,6 +18,7 @@ from mlx.utils import tree_flatten
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models.base import KVCache
|
||||
from .sample_utils import top_p_sampling
|
||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||
from .tuner.utils import apply_lora_layers
|
||||
@@ -160,7 +161,12 @@ def generate_step(
|
||||
)
|
||||
|
||||
y = prompt
|
||||
cache = None
|
||||
kv_heads = (
|
||||
[model.n_kv_heads] * len(model.layers)
|
||||
if isinstance(model.n_kv_heads, int)
|
||||
else model.n_kv_heads
|
||||
)
|
||||
cache = [KVCache(model.head_dim, n) for n in kv_heads]
|
||||
|
||||
repetition_context = prompt.tolist()
|
||||
|
||||
@@ -168,8 +174,8 @@ def generate_step(
|
||||
repetition_context = repetition_context[-repetition_context_size:]
|
||||
|
||||
def _step(y):
|
||||
nonlocal cache, repetition_context
|
||||
logits, cache = model(y[None], cache=cache)
|
||||
nonlocal repetition_context
|
||||
logits = model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if repetition_penalty:
|
||||
@@ -445,9 +451,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
card.text = dedent(
|
||||
f"""
|
||||
# {upload_repo}
|
||||
|
||||
|
||||
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
|
||||
|
||||
|
||||
## Use with mlx
|
||||
|
||||
```bash
|
||||
|
||||
Reference in New Issue
Block a user