diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py
index b861b286..06a307a6 100644
--- a/llms/mlx_lm/models/recurrent_gemma.py
+++ b/llms/mlx_lm/models/recurrent_gemma.py
@@ -398,7 +398,13 @@ class Griffin(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        mask = create_attention_mask(x, cache)
+        mask_cache = None
+        for i, block in enumerate(self.layers):
+            if block.temporal_block_type != "recurrent":
+                mask_cache = [cache[i]]
+                break
+
+        mask = create_attention_mask(x, mask_cache)
 
         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index a839f797..1efde5ae 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -3,8 +3,7 @@ import unittest
 import mlx.core as mx
 from mlx.utils import tree_map
 
-from mlx_lm.models.base import KVCache, RotatingKVCache
-from mlx_lm.utils import make_kv_caches
+from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
 
 
 class TestModels(unittest.TestCase):
@@ -140,7 +139,7 @@
         self.assertEqual(outputs.shape, (1, 2, vocab_size))
         self.assertEqual(outputs.dtype, t)
 
-        cache = make_kv_caches(model)
+        cache = make_prompt_cache(model)
         outputs = model(inputs, cache)
         self.assertEqual(outputs.shape, (1, 2, vocab_size))
         self.assertEqual(outputs.dtype, t)