fix test:

This commit is contained in:
Alex Barron 2024-12-18 13:59:02 -08:00
parent ef895f6e5b
commit eb9452beb9
2 changed files with 3 additions and 9 deletions

View File

@ -135,7 +135,6 @@ class QuantizedKVCache(_BaseCache):
self.values = None
self.offset = 0
self.step = 256
self.lengths = None
self.group_size = group_size
self.bits = bits
@ -218,7 +217,6 @@ class KVCache(_BaseCache):
self.values = None
self.offset = 0
self.step = 256
self.lengths = None
def update_and_fetch(self, keys, values):
prev = self.offset
@ -287,7 +285,6 @@ class RotatingKVCache(_BaseCache):
self.offset = 0
self.max_size = max_size
self.step = step
self.lengths = None
self._idx = 0
def _trim(self, trim_size, v, append=None):

View File

@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map
from mlx_lm.models import rope_utils
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.base import create_causal_mask
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
@ -129,17 +129,14 @@ class TestModels(unittest.TestCase):
self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def test_cache_lengths(self):
def test_causal_mask_lengths(self):
mx.random.seed(8)
B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
lengths = mx.array([1, 2, 3, 1])
h = mx.random.uniform(shape=(B, T_q, D))
q = mx.random.uniform(shape=(B, N_q, T_q, D))
k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
v = k
cache = [KVCache()]
cache[0].lengths = lengths
mask = create_attention_mask(h, cache)
mask = create_causal_mask(T_q, 0, lengths=lengths)
out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
q[1, :, 2:] = mx.ones_like(q[1, :, 2:])