mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Handle longer prompt/generation (#931)
* rebase * nits * nit * fix rotating cache with step prefill * update version
This commit is contained in:
parent
e196fa3208
commit
7be292c0c9
@ -76,7 +76,12 @@ def setup_arg_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Set the MLX cache limit in GB",
|
help="Set the MLX cache limit in GB",
|
||||||
required=False,
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-kv-size",
|
||||||
|
type=int,
|
||||||
|
default=1024,
|
||||||
|
help="Set the maximum key-value cache size",
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -154,6 +159,7 @@ def main():
|
|||||||
formatter=formatter,
|
formatter=formatter,
|
||||||
temp=args.temp,
|
temp=args.temp,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
|
max_kv_size=args.max_kv_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -44,6 +46,100 @@ class KVCache:
|
|||||||
self.values[..., prev : self.offset, :] = values
|
self.values[..., prev : self.offset, :] = values
|
||||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||||
|
|
||||||
|
def state(self):
|
||||||
|
return self.keys, self.values
|
||||||
|
|
||||||
|
|
||||||
|
class RotatingKVCache:
|
||||||
|
|
||||||
|
def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
|
||||||
|
self.n_kv_heads = n_kv_heads
|
||||||
|
if isinstance(head_dim, int):
|
||||||
|
self.k_head_dim = self.v_head_dim = head_dim
|
||||||
|
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
|
||||||
|
self.k_head_dim, self.v_head_dim = head_dim
|
||||||
|
else:
|
||||||
|
raise ValueError("head_dim must be an int or a tuple of two ints")
|
||||||
|
self.keep = keep
|
||||||
|
self.keys = None
|
||||||
|
self.values = None
|
||||||
|
self.offset = 0
|
||||||
|
self.max_size = max_size
|
||||||
|
self.step = step
|
||||||
|
self._idx = 0
|
||||||
|
|
||||||
|
def _trim(self, trim_size, v, append=None):
|
||||||
|
to_cat = []
|
||||||
|
if trim_size > 0:
|
||||||
|
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
|
||||||
|
else:
|
||||||
|
to_cat = [v]
|
||||||
|
if append is not None:
|
||||||
|
to_cat.append(append)
|
||||||
|
return mx.concatenate(to_cat, axis=2)
|
||||||
|
|
||||||
|
def update_and_fetch(self, keys, values):
|
||||||
|
prev = self.offset
|
||||||
|
B, _, S = keys.shape[:3]
|
||||||
|
|
||||||
|
# Prefill mode
|
||||||
|
if S > 1:
|
||||||
|
if self.keys is None:
|
||||||
|
self.keys = keys
|
||||||
|
self.values = values
|
||||||
|
else:
|
||||||
|
# The largest size is self.max_size + S - 1 to ensure
|
||||||
|
# every token gets at least self.max_size context
|
||||||
|
trim_size = self.keys.shape[2] - self.max_size + 1
|
||||||
|
self.keys = self._trim(trim_size, self.keys, keys)
|
||||||
|
self.values = self._trim(trim_size, self.values, values)
|
||||||
|
self.offset += S
|
||||||
|
self._idx = self.keys.shape[2]
|
||||||
|
return self.keys, self.values
|
||||||
|
|
||||||
|
# Generation mode
|
||||||
|
# May not have hit the max size yet, so potentially
|
||||||
|
# keep growing the cache
|
||||||
|
if self.keys is None or (
|
||||||
|
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
|
||||||
|
):
|
||||||
|
new_size = min(self.step, self.max_size - prev)
|
||||||
|
k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
|
||||||
|
v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
|
||||||
|
new_k = mx.zeros(k_shape, keys.dtype)
|
||||||
|
new_v = mx.zeros(v_shape, values.dtype)
|
||||||
|
if self.keys is not None:
|
||||||
|
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||||
|
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||||
|
else:
|
||||||
|
self.keys, self.values = new_k, new_v
|
||||||
|
self._idx = prev
|
||||||
|
|
||||||
|
# Trim if needed
|
||||||
|
trim_size = self.keys.shape[2] - self.max_size
|
||||||
|
if trim_size > 0:
|
||||||
|
self.keys = self._trim(trim_size, self.keys)
|
||||||
|
self.values = self._trim(trim_size, self.values)
|
||||||
|
self._idx = self.max_size
|
||||||
|
|
||||||
|
# Rotate
|
||||||
|
if self._idx == self.max_size:
|
||||||
|
self._idx = self.keep
|
||||||
|
|
||||||
|
# Assign
|
||||||
|
self.keys[..., self._idx : self._idx + 1, :] = keys
|
||||||
|
self.values[..., self._idx : self._idx + 1, :] = values
|
||||||
|
self.offset += 1
|
||||||
|
self._idx += 1
|
||||||
|
|
||||||
|
# If the buffer is not full, slice off the end
|
||||||
|
if self.offset < self.max_size:
|
||||||
|
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||||
|
return self.keys, self.values
|
||||||
|
|
||||||
|
def state(self):
|
||||||
|
return self.keys, self.values
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseModelArgs:
|
class BaseModelArgs:
|
||||||
@ -65,13 +161,17 @@ def create_additive_causal_mask(N: int, offset: int = 0):
|
|||||||
return mask * -1e9
|
return mask * -1e9
|
||||||
|
|
||||||
|
|
||||||
def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
|
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||||
T = h.shape[1]
|
T = h.shape[1]
|
||||||
if T > 1:
|
if T > 1:
|
||||||
# Input consists of multiple tokens, create a causal mask so that prior
|
if cache is not None and cache[0] is not None:
|
||||||
# tokens do not give attention to later tokens. If a cache is in place
|
c = cache[0]
|
||||||
# (because e.g. prompt reuse), offset the mask accordingly.
|
if isinstance(c, RotatingKVCache):
|
||||||
offset = cache[0].offset if cache is not None and cache[0] is not None else 0
|
offset = min(c.max_size - 1, c.offset)
|
||||||
|
else:
|
||||||
|
offset = c.offset
|
||||||
|
else:
|
||||||
|
offset = 0
|
||||||
mask = create_additive_causal_mask(T, offset)
|
mask = create_additive_causal_mask(T, offset)
|
||||||
mask = mask.astype(h.dtype)
|
mask = mask.astype(h.dtype)
|
||||||
else:
|
else:
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from sys import exit
|
from sys import exit
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Literal, Optional
|
from typing import List, Literal, Optional
|
||||||
@ -53,6 +55,9 @@ class RecurrentCache:
|
|||||||
def update(self, conv_state, recurrent_state):
|
def update(self, conv_state, recurrent_state):
|
||||||
self._cache = (conv_state, recurrent_state)
|
self._cache = (conv_state, recurrent_state)
|
||||||
|
|
||||||
|
def state(self):
|
||||||
|
return self._cache
|
||||||
|
|
||||||
|
|
||||||
class WindowKVCache:
|
class WindowKVCache:
|
||||||
|
|
||||||
@ -80,6 +85,9 @@ class WindowKVCache:
|
|||||||
self.values = _update(self.values, values)
|
self.values = _update(self.values, values)
|
||||||
return self.keys, self.values
|
return self.keys, self.values
|
||||||
|
|
||||||
|
def state(self):
|
||||||
|
return self.keys, self.values
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
def __init__(self, dims: int, eps: float = 1e-5):
|
def __init__(self, dims: int, eps: float = 1e-5):
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
@ -19,7 +19,7 @@ from mlx.utils import tree_flatten
|
|||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
from .models.base import KVCache
|
from .models.base import KVCache, RotatingKVCache
|
||||||
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
||||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||||
from .tuner.utils import apply_lora_layers
|
from .tuner.utils import apply_lora_layers
|
||||||
@ -136,6 +136,8 @@ def generate_step(
|
|||||||
min_p: float = 0.0,
|
min_p: float = 0.0,
|
||||||
min_tokens_to_keep: int = 1,
|
min_tokens_to_keep: int = 1,
|
||||||
logit_bias: Optional[Dict[int, float]] = None,
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
|
prefill_step_size: int = 512,
|
||||||
|
max_kv_size: Optional[int] = None,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
@ -156,6 +158,9 @@ def generate_step(
|
|||||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||||
be filtered by min_p sampling.
|
be filtered by min_p sampling.
|
||||||
logit_bias (dictionary, optional): Additive logit bias.
|
logit_bias (dictionary, optional): Additive logit bias.
|
||||||
|
prefill_step_size (int): Step size for processing the prompt.
|
||||||
|
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||||
|
entries (except the first 4 tokens) will be overwritten.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||||
@ -197,7 +202,13 @@ def generate_step(
|
|||||||
if isinstance(model.n_kv_heads, int)
|
if isinstance(model.n_kv_heads, int)
|
||||||
else model.n_kv_heads
|
else model.n_kv_heads
|
||||||
)
|
)
|
||||||
cache = [KVCache(model.head_dim, n) for n in kv_heads]
|
if max_kv_size is not None:
|
||||||
|
cache = [
|
||||||
|
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
|
||||||
|
for n in kv_heads
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
cache = [KVCache(model.head_dim, n) for n in kv_heads]
|
||||||
|
|
||||||
repetition_context = prompt.tolist()
|
repetition_context = prompt.tolist()
|
||||||
|
|
||||||
@ -223,6 +234,11 @@ def generate_step(
|
|||||||
repetition_context = repetition_context[-repetition_context_size:]
|
repetition_context = repetition_context[-repetition_context_size:]
|
||||||
return y, logprobs.squeeze(0)
|
return y, logprobs.squeeze(0)
|
||||||
|
|
||||||
|
while y.size > prefill_step_size:
|
||||||
|
model(y[:prefill_step_size][None], cache=cache)
|
||||||
|
mx.eval([c.state for c in cache])
|
||||||
|
y = y[prefill_step_size:]
|
||||||
|
|
||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.async_eval(y)
|
mx.async_eval(y)
|
||||||
@ -343,8 +359,10 @@ def generate(
|
|||||||
return
|
return
|
||||||
prompt_tps = prompt_tokens.size / prompt_time
|
prompt_tps = prompt_tokens.size / prompt_time
|
||||||
gen_tps = (token_count - 1) / gen_time
|
gen_tps = (token_count - 1) / gen_time
|
||||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
|
||||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
||||||
|
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||||
|
print(f"Peak memory: {peak_mem:.3f} GB")
|
||||||
|
|
||||||
return detokenizer.text
|
return detokenizer.text
|
||||||
|
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.16.0"
|
__version__ = "0.17.0"
|
||||||
|
@ -4,7 +4,7 @@ import unittest
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
from mlx_lm.models.base import KVCache
|
from mlx_lm.models.base import KVCache, RotatingKVCache
|
||||||
|
|
||||||
|
|
||||||
class TestModels(unittest.TestCase):
|
class TestModels(unittest.TestCase):
|
||||||
@ -29,6 +29,64 @@ class TestModels(unittest.TestCase):
|
|||||||
self.assertTrue(mx.array_equal(v_up, expected))
|
self.assertTrue(mx.array_equal(v_up, expected))
|
||||||
self.assertEqual(cache.offset, cache.step + 1)
|
self.assertEqual(cache.offset, cache.step + 1)
|
||||||
|
|
||||||
|
def test_rotating_kv_cache(self):
|
||||||
|
b, h, d = 1, 2, 32
|
||||||
|
cache = RotatingKVCache(d, h, max_size=8, step=4)
|
||||||
|
|
||||||
|
k = mx.random.uniform(shape=(b, h, 2, d))
|
||||||
|
v = mx.random.uniform(shape=(b, h, 2, d))
|
||||||
|
|
||||||
|
k_up, v_up = cache.update_and_fetch(k, v)
|
||||||
|
self.assertTrue(mx.array_equal(k_up, k))
|
||||||
|
self.assertTrue(mx.array_equal(v_up, v))
|
||||||
|
self.assertEqual(cache.offset, 2)
|
||||||
|
|
||||||
|
k = mx.random.uniform(shape=(b, h, 5, d))
|
||||||
|
v = mx.random.uniform(shape=(b, h, 5, d))
|
||||||
|
k_up, v_up = cache.update_and_fetch(k, v)
|
||||||
|
self.assertTrue(mx.array_equal(k_up[..., 2:, :], k))
|
||||||
|
self.assertTrue(mx.array_equal(v_up[..., 2:, :], v))
|
||||||
|
|
||||||
|
k = mx.random.uniform(shape=(b, h, 4, d))
|
||||||
|
v = mx.random.uniform(shape=(b, h, 4, d))
|
||||||
|
k_up, v_up = cache.update_and_fetch(k, v)
|
||||||
|
self.assertTrue(mx.array_equal(k_up[..., -4:, :], k))
|
||||||
|
self.assertTrue(mx.array_equal(v_up[..., -4:, :], v))
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
for _ in range(10):
|
||||||
|
k = mx.random.uniform(shape=(b, h, 1, d))
|
||||||
|
v = mx.random.uniform(shape=(b, h, 1, d))
|
||||||
|
k_up, v_up = cache.update_and_fetch(k, v)
|
||||||
|
self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
|
||||||
|
self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
|
||||||
|
idx += 1
|
||||||
|
idx %= 8
|
||||||
|
|
||||||
|
# Try with nonzero keep
|
||||||
|
cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
|
||||||
|
|
||||||
|
# Check a large update
|
||||||
|
k = mx.random.uniform(shape=(b, h, 20, d))
|
||||||
|
v = mx.random.uniform(shape=(b, h, 20, d))
|
||||||
|
k_up, v_up = cache.update_and_fetch(k, v)
|
||||||
|
self.assertTrue(mx.array_equal(k_up, k))
|
||||||
|
self.assertTrue(mx.array_equal(v_up, v))
|
||||||
|
|
||||||
|
# A bunch of small updates
|
||||||
|
self.assertEqual(cache.offset, 20)
|
||||||
|
idx = 2
|
||||||
|
for i in range(10):
|
||||||
|
k = mx.random.uniform(shape=(b, h, 1, d))
|
||||||
|
v = mx.random.uniform(shape=(b, h, 1, d))
|
||||||
|
k_up, v_up = cache.update_and_fetch(k, v)
|
||||||
|
self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
|
||||||
|
self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
|
||||||
|
self.assertEqual(cache.offset, 21 + i)
|
||||||
|
idx += 1
|
||||||
|
if idx >= 8:
|
||||||
|
idx = 2
|
||||||
|
|
||||||
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
||||||
|
|
||||||
self.assertEqual(len(model.layers), num_layers)
|
self.assertEqual(len(model.layers), num_layers)
|
||||||
|
Loading…
Reference in New Issue
Block a user