mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:17:07 +08:00
fix test:
This commit is contained in:
parent
ef895f6e5b
commit
eb9452beb9
@ -135,7 +135,6 @@ class QuantizedKVCache(_BaseCache):
|
|||||||
self.values = None
|
self.values = None
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
self.step = 256
|
self.step = 256
|
||||||
self.lengths = None
|
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
|
|
||||||
@ -218,7 +217,6 @@ class KVCache(_BaseCache):
|
|||||||
self.values = None
|
self.values = None
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
self.step = 256
|
self.step = 256
|
||||||
self.lengths = None
|
|
||||||
|
|
||||||
def update_and_fetch(self, keys, values):
|
def update_and_fetch(self, keys, values):
|
||||||
prev = self.offset
|
prev = self.offset
|
||||||
@ -287,7 +285,6 @@ class RotatingKVCache(_BaseCache):
|
|||||||
self.offset = 0
|
self.offset = 0
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self.step = step
|
self.step = step
|
||||||
self.lengths = None
|
|
||||||
self._idx = 0
|
self._idx = 0
|
||||||
|
|
||||||
def _trim(self, trim_size, v, append=None):
|
def _trim(self, trim_size, v, append=None):
|
||||||
|
@ -5,7 +5,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
from mlx_lm.models import rope_utils
|
from mlx_lm.models import rope_utils
|
||||||
from mlx_lm.models.base import create_attention_mask
|
from mlx_lm.models.base import create_causal_mask
|
||||||
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||||
|
|
||||||
|
|
||||||
@ -129,17 +129,14 @@ class TestModels(unittest.TestCase):
|
|||||||
self.assertEqual(cache.offset, 22)
|
self.assertEqual(cache.offset, 22)
|
||||||
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
||||||
|
|
||||||
def test_cache_lengths(self):
|
def test_causal_mask_lengths(self):
|
||||||
mx.random.seed(8)
|
mx.random.seed(8)
|
||||||
B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
|
B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
|
||||||
lengths = mx.array([1, 2, 3, 1])
|
lengths = mx.array([1, 2, 3, 1])
|
||||||
h = mx.random.uniform(shape=(B, T_q, D))
|
|
||||||
q = mx.random.uniform(shape=(B, N_q, T_q, D))
|
q = mx.random.uniform(shape=(B, N_q, T_q, D))
|
||||||
k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
|
k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
|
||||||
v = k
|
v = k
|
||||||
cache = [KVCache()]
|
mask = create_causal_mask(T_q, 0, lengths=lengths)
|
||||||
cache[0].lengths = lengths
|
|
||||||
mask = create_attention_mask(h, cache)
|
|
||||||
|
|
||||||
out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||||
q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
|
q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
|
||||||
|
Loading…
Reference in New Issue
Block a user