Handle longer prompt/generation (#931)

* rebase

* nits

* nit

* fix rotating cache with step prefill

* update version
This commit is contained in:
Awni Hannun 2024-08-16 15:28:39 -07:00 committed by GitHub
parent e196fa3208
commit 7be292c0c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 255 additions and 13 deletions

View File

@ -76,7 +76,12 @@ def setup_arg_parser():
type=int, type=int,
default=None, default=None,
help="Set the MLX cache limit in GB", help="Set the MLX cache limit in GB",
required=False, )
parser.add_argument(
"--max-kv-size",
type=int,
default=1024,
help="Set the maximum key-value cache size",
) )
return parser return parser
@ -154,6 +159,7 @@ def main():
formatter=formatter, formatter=formatter,
temp=args.temp, temp=args.temp,
top_p=args.top_p, top_p=args.top_p,
max_kv_size=args.max_kv_size,
) )

View File

@ -1,6 +1,8 @@
# Copyright © 2023-2024 Apple Inc.
import inspect import inspect
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import Any, List, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -44,6 +46,100 @@ class KVCache:
self.values[..., prev : self.offset, :] = values self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
def state(self):
    """Return the raw ``(keys, values)`` buffers held by the cache.

    Unlike ``update_and_fetch``, the buffers are returned whole rather than
    sliced to ``offset``.
    """
    keys, values = self.keys, self.values
    return keys, values
class RotatingKVCache:
    """Key/value cache whose length is bounded by ``max_size``.

    Arrays are laid out as ``(batch, n_kv_heads, sequence, head_dim)``; all
    concatenation and slicing happens along the sequence axis (axis 2).

    * Multi-token updates ("prefill") are concatenated onto the cache, after
      dropping the oldest non-kept entries so that at most
      ``max_size - 1`` previous positions remain.
    * Single-token updates ("generation") are written in place. The buffer
      grows in chunks of up to ``step`` positions until it holds
      ``max_size`` positions; after that the write position wraps around
      and overwrites the oldest entry.
    * The first ``keep`` positions are never trimmed or overwritten.
    """

    def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
        """
        Args:
            head_dim: Head dimension shared by keys and values (``int``), or
                a ``(k_head_dim, v_head_dim)`` tuple.
            n_kv_heads: Number of key/value heads.
            max_size: Number of positions held once the cache rotates.
            keep: Number of leading positions that are never evicted.
            step: Maximum number of positions added per growth step in
                generation mode.

        Raises:
            ValueError: If ``head_dim`` is neither an int nor a 2-tuple.
        """
        self.n_kv_heads = n_kv_heads
        if isinstance(head_dim, int):
            self.k_head_dim = self.v_head_dim = head_dim
        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
            self.k_head_dim, self.v_head_dim = head_dim
        else:
            raise ValueError("head_dim must be an int or a tuple of two ints")
        self.keep = keep
        self.keys = None
        self.values = None
        # Total number of positions ever written (not the buffer length).
        self.offset = 0
        self.max_size = max_size
        self.step = step
        # Next write position inside the buffer.
        self._idx = 0

    def _trim(self, trim_size, v, append=None):
        """Drop ``trim_size`` positions right after the kept prefix of ``v``.

        If ``append`` is given it is concatenated at the end. With
        ``trim_size <= 0`` nothing is dropped.
        """
        if trim_size > 0:
            to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
        else:
            to_cat = [v]
        if append is not None:
            to_cat.append(append)
        return mx.concatenate(to_cat, axis=2)

    def _temporal_order(self, v):
        """Return the used part of ``v`` ordered from oldest to newest."""
        if self._idx == v.shape[2]:
            # Buffer is exactly full and not wrapped: already in order.
            return v
        if self._idx < self.offset:
            # The write position has wrapped around: entries in
            # [_idx, end) are older than those in [keep, _idx).
            return mx.concatenate(
                [
                    v[..., : self.keep, :],
                    v[..., self._idx :, :],
                    v[..., self.keep : self._idx, :],
                ],
                axis=2,
            )
        # Not wrapped, but the tail is unused zero padding from growth.
        return v[..., : self._idx, :]

    def update_and_fetch(self, keys, values):
        """Add ``keys``/``values`` to the cache and return the arrays to
        attend over.

        Args:
            keys: Array of shape ``(B, n_kv_heads, S, k_head_dim)``.
            values: Array of shape ``(B, n_kv_heads, S, v_head_dim)``.

        Returns:
            Tuple ``(keys, values)`` covering the cached context including
            the new entries.
        """
        prev = self.offset
        B, _, S = keys.shape[:3]

        # Prefill mode
        if S > 1:
            if self.keys is None:
                self.keys = keys
                self.values = values
            else:
                # Earlier single-token updates may have left unused zero
                # padding at the end of the buffer or wrapped the write
                # position. Restore oldest-to-newest order first so the trim
                # below evicts the oldest entries and no padding is kept.
                self.keys = self._temporal_order(self.keys)
                self.values = self._temporal_order(self.values)

                # The largest size is self.max_size + S - 1 to ensure
                # every token gets at least self.max_size context
                trim_size = self.keys.shape[2] - self.max_size + 1
                self.keys = self._trim(trim_size, self.keys, keys)
                self.values = self._trim(trim_size, self.values, values)
            self.offset += S
            self._idx = self.keys.shape[2]
            return self.keys, self.values

        # Generation mode
        # May not have hit the max size yet, so potentially
        # keep growing the cache
        if self.keys is None or (
            prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
        ):
            new_size = min(self.step, self.max_size - prev)
            k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
            v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v
            self._idx = prev

        # Trim if needed (a prefill may have left more than max_size entries)
        trim_size = self.keys.shape[2] - self.max_size
        if trim_size > 0:
            self.keys = self._trim(trim_size, self.keys)
            self.values = self._trim(trim_size, self.values)
            self._idx = self.max_size

        # Rotate: wrap the write position back to the first evictable slot
        if self._idx == self.max_size:
            self._idx = self.keep

        # Assign
        self.keys[..., self._idx : self._idx + 1, :] = keys
        self.values[..., self._idx : self._idx + 1, :] = values
        self.offset += 1
        self._idx += 1

        # If the buffer is not full, slice off the end
        if self.offset < self.max_size:
            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
        return self.keys, self.values

    def state(self):
        """Return the raw ``(keys, values)`` buffers."""
        return self.keys, self.values
@dataclass @dataclass
class BaseModelArgs: class BaseModelArgs:
@ -65,13 +161,17 @@ def create_additive_causal_mask(N: int, offset: int = 0):
return mask * -1e9 return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None): def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1] T = h.shape[1]
if T > 1: if T > 1:
# Input consists of multiple tokens, create a causal mask so that prior if cache is not None and cache[0] is not None:
# tokens do not give attention to later tokens. If a cache is in place c = cache[0]
# (because e.g. prompt reuse), offset the mask accordingly. if isinstance(c, RotatingKVCache):
offset = cache[0].offset if cache is not None and cache[0] is not None else 0 offset = min(c.max_size - 1, c.offset)
else:
offset = c.offset
else:
offset = 0
mask = create_additive_causal_mask(T, offset) mask = create_additive_causal_mask(T, offset)
mask = mask.astype(h.dtype) mask = mask.astype(h.dtype)
else: else:

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from sys import exit from sys import exit
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple from typing import Tuple

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import inspect import inspect
import math import math
from dataclasses import dataclass from dataclasses import dataclass

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple from typing import Tuple

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Literal, Optional from typing import List, Literal, Optional
@ -53,6 +55,9 @@ class RecurrentCache:
def update(self, conv_state, recurrent_state): def update(self, conv_state, recurrent_state):
self._cache = (conv_state, recurrent_state) self._cache = (conv_state, recurrent_state)
def state(self):
    """Return the stored cache.

    ``update`` stores it as a ``(conv_state, recurrent_state)`` tuple.
    """
    stored = self._cache
    return stored
class WindowKVCache: class WindowKVCache:
@ -80,6 +85,9 @@ class WindowKVCache:
self.values = _update(self.values, values) self.values = _update(self.values, values)
return self.keys, self.values return self.keys, self.values
def state(self):
    """Return the cached ``(keys, values)`` pair."""
    keys, values = self.keys, self.values
    return keys, values
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5): def __init__(self, dims: int, eps: float = 1e-5):

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple from typing import Tuple

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math import math
from typing import List, Union from typing import List, Union

View File

@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math import math
import mlx.core as mx import mlx.core as mx

View File

@ -19,7 +19,7 @@ from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
# Local imports # Local imports
from .models.base import KVCache from .models.base import KVCache, RotatingKVCache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers from .tuner.utils import apply_lora_layers
@ -136,6 +136,8 @@ def generate_step(
min_p: float = 0.0, min_p: float = 0.0,
min_tokens_to_keep: int = 1, min_tokens_to_keep: int = 1,
logit_bias: Optional[Dict[int, float]] = None, logit_bias: Optional[Dict[int, float]] = None,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@ -156,6 +158,9 @@ def generate_step(
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling. be filtered by min_p sampling.
logit_bias (dictionary, optional): Additive logit bias. logit_bias (dictionary, optional): Additive logit bias.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
Yields: Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@ -197,7 +202,13 @@ def generate_step(
if isinstance(model.n_kv_heads, int) if isinstance(model.n_kv_heads, int)
else model.n_kv_heads else model.n_kv_heads
) )
cache = [KVCache(model.head_dim, n) for n in kv_heads] if max_kv_size is not None:
cache = [
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
for n in kv_heads
]
else:
cache = [KVCache(model.head_dim, n) for n in kv_heads]
repetition_context = prompt.tolist() repetition_context = prompt.tolist()
@ -223,6 +234,11 @@ def generate_step(
repetition_context = repetition_context[-repetition_context_size:] repetition_context = repetition_context[-repetition_context_size:]
return y, logprobs.squeeze(0) return y, logprobs.squeeze(0)
# Process long prompts in chunks of ``prefill_step_size`` tokens so the
# lazily built graph (and peak memory) stays bounded.
while y.size > prefill_step_size:
    model(y[:prefill_step_size][None], cache=cache)
    # ``state`` is a method on the cache classes; it must be called so that
    # ``mx.eval`` receives the key/value arrays. Passing the bound method
    # gives ``mx.eval`` no array leaves, so nothing is evaluated.
    mx.eval([c.state() for c in cache])
    y = y[prefill_step_size:]
y, logprobs = _step(y) y, logprobs = _step(y)
mx.async_eval(y) mx.async_eval(y)
@ -343,8 +359,10 @@ def generate(
return return
prompt_tps = prompt_tokens.size / prompt_time prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time gen_tps = (token_count - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec") print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text return detokenizer.text

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
__version__ = "0.16.0" __version__ = "0.17.0"

View File

@ -4,7 +4,7 @@ import unittest
import mlx.core as mx import mlx.core as mx
from mlx.utils import tree_map from mlx.utils import tree_map
from mlx_lm.models.base import KVCache from mlx_lm.models.base import KVCache, RotatingKVCache
class TestModels(unittest.TestCase): class TestModels(unittest.TestCase):
@ -29,6 +29,64 @@ class TestModels(unittest.TestCase):
self.assertTrue(mx.array_equal(v_up, expected)) self.assertTrue(mx.array_equal(v_up, expected))
self.assertEqual(cache.offset, cache.step + 1) self.assertEqual(cache.offset, cache.step + 1)
def test_rotating_kv_cache(self):
    """Exercise RotatingKVCache with prefill and single-token updates,
    both without and with a kept prefix."""
    batch, heads, dim = 1, 2, 32
    max_size = 8

    def rand_kv(seq_len):
        # Keys are drawn before values, one call each.
        shape = (batch, heads, seq_len, dim)
        new_k = mx.random.uniform(shape=shape)
        new_v = mx.random.uniform(shape=shape)
        return new_k, new_v

    def assert_same(actual, expected):
        self.assertTrue(mx.array_equal(actual, expected))

    cache = RotatingKVCache(dim, heads, max_size=max_size, step=4)

    # First prefill: the cache returns exactly what was put in.
    k, v = rand_kv(2)
    k_up, v_up = cache.update_and_fetch(k, v)
    assert_same(k_up, k)
    assert_same(v_up, v)
    self.assertEqual(cache.offset, 2)

    # Second prefill is appended after the first two positions.
    k, v = rand_kv(5)
    k_up, v_up = cache.update_and_fetch(k, v)
    assert_same(k_up[..., 2:, :], k)
    assert_same(v_up[..., 2:, :], v)

    # Third prefill ends up in the last four positions.
    k, v = rand_kv(4)
    k_up, v_up = cache.update_and_fetch(k, v)
    assert_same(k_up[..., -4:, :], k)
    assert_same(v_up[..., -4:, :], v)

    # Single-token updates cycle through slots 0..max_size-1.
    for n in range(10):
        slot = n % max_size
        k, v = rand_kv(1)
        k_up, v_up = cache.update_and_fetch(k, v)
        assert_same(k_up[..., slot : slot + 1, :], k)
        assert_same(v_up[..., slot : slot + 1, :], v)

    # Try with nonzero keep
    keep = 2
    cache = RotatingKVCache(dim, heads, max_size=max_size, step=4, keep=keep)

    # Check a large update
    k, v = rand_kv(20)
    k_up, v_up = cache.update_and_fetch(k, v)
    assert_same(k_up, k)
    assert_same(v_up, v)

    # A bunch of small updates: slots cycle through keep..max_size-1.
    self.assertEqual(cache.offset, 20)
    ring = max_size - keep
    for i in range(10):
        slot = keep + i % ring
        k, v = rand_kv(1)
        k_up, v_up = cache.update_and_fetch(k, v)
        assert_same(k_up[..., slot : slot + 1, :], k)
        assert_same(v_up[..., slot : slot + 1, :], v)
        self.assertEqual(cache.offset, 21 + i)
def model_test_runner(self, model, model_type, vocab_size, num_layers): def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers) self.assertEqual(len(model.layers), num_layers)