mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 07:30:06 +08:00
fix test:
This commit is contained in:
parent
ef895f6e5b
commit
eb9452beb9
@ -135,7 +135,6 @@ class QuantizedKVCache(_BaseCache):
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
self.lengths = None
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
|
||||
@ -218,7 +217,6 @@ class KVCache(_BaseCache):
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
self.lengths = None
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
@ -287,7 +285,6 @@ class RotatingKVCache(_BaseCache):
|
||||
self.offset = 0
|
||||
self.max_size = max_size
|
||||
self.step = step
|
||||
self.lengths = None
|
||||
self._idx = 0
|
||||
|
||||
def _trim(self, trim_size, v, append=None):
|
||||
|
@ -5,7 +5,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_map
|
||||
from mlx_lm.models import rope_utils
|
||||
from mlx_lm.models.base import create_attention_mask
|
||||
from mlx_lm.models.base import create_causal_mask
|
||||
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||
|
||||
|
||||
@ -129,17 +129,14 @@ class TestModels(unittest.TestCase):
|
||||
self.assertEqual(cache.offset, 22)
|
||||
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
||||
|
||||
def test_cache_lengths(self):
|
||||
def test_causal_mask_lengths(self):
|
||||
mx.random.seed(8)
|
||||
B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
|
||||
lengths = mx.array([1, 2, 3, 1])
|
||||
h = mx.random.uniform(shape=(B, T_q, D))
|
||||
q = mx.random.uniform(shape=(B, N_q, T_q, D))
|
||||
k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
|
||||
v = k
|
||||
cache = [KVCache()]
|
||||
cache[0].lengths = lengths
|
||||
mask = create_attention_mask(h, cache)
|
||||
mask = create_causal_mask(T_q, 0, lengths=lengths)
|
||||
|
||||
out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||
q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
|
||||
|
Loading…
Reference in New Issue
Block a user