mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
fix tests
This commit is contained in:
parent
5f52882e32
commit
60c9794618
@ -398,7 +398,11 @@ class Griffin(nn.Module):
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
mask = create_attention_mask(x, cache)
|
||||
for i, block in enumerate(self.layers):
|
||||
if block.temporal_block_type != "recurrent":
|
||||
mask_cache = [cache[i]]
|
||||
|
||||
mask = create_attention_mask(x, mask_cache)
|
||||
|
||||
for i, block in enumerate(self.layers):
|
||||
x = block(x, mask=mask, cache=cache[i])
|
||||
|
@ -3,8 +3,7 @@ import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
from mlx_lm.models.base import KVCache, RotatingKVCache
|
||||
from mlx_lm.utils import make_kv_caches
|
||||
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||
|
||||
|
||||
class TestModels(unittest.TestCase):
|
||||
@ -140,7 +139,7 @@ class TestModels(unittest.TestCase):
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
||||
cache = make_kv_caches(model)
|
||||
cache = make_prompt_cache(model)
|
||||
outputs = model(inputs, cache)
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
Loading…
Reference in New Issue
Block a user