diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 81b16af3..14026f0c 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -135,7 +135,6 @@ class QuantizedKVCache(_BaseCache):
         self.values = None
         self.offset = 0
         self.step = 256
-        self.lengths = None
         self.group_size = group_size
         self.bits = bits
@@ -218,7 +217,6 @@ class KVCache(_BaseCache):
         self.values = None
         self.offset = 0
         self.step = 256
-        self.lengths = None
 
     def update_and_fetch(self, keys, values):
         prev = self.offset
@@ -287,7 +285,6 @@ class RotatingKVCache(_BaseCache):
         self.offset = 0
         self.max_size = max_size
         self.step = step
-        self.lengths = None
         self._idx = 0
 
     def _trim(self, trim_size, v, append=None):
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index 61dd8c58..00c01436 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_map
 from mlx_lm.models import rope_utils
-from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.base import create_causal_mask
 from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
@@ -129,17 +129,14 @@ class TestModels(unittest.TestCase):
         self.assertEqual(cache.offset, 22)
         self.assertTrue(mx.allclose(x, k[..., -2:, :]))
 
-    def test_cache_lengths(self):
+    def test_causal_mask_lengths(self):
         mx.random.seed(8)
         B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
         lengths = mx.array([1, 2, 3, 1])
-        h = mx.random.uniform(shape=(B, T_q, D))
         q = mx.random.uniform(shape=(B, N_q, T_q, D))
         k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
         v = k
-        cache = [KVCache()]
-        cache[0].lengths = lengths
-        mask = create_attention_mask(h, cache)
+        mask = create_causal_mask(T_q, 0, lengths=lengths)
         out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
         q[1, :, 2:] = mx.ones_like(q[1, :, 2:])