diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index ada05d0f..b49d0419 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -159,7 +159,7 @@ class LlamaModel(nn.Module):
         mask = None
         if h.shape[1] > 1:
             mask = create_additive_causal_mask(
-                h.shape[1], cache[0][0].shape[2] if cache is not None else 0
+                h.shape[1], cache[0].offset if cache is not None else 0
             )
             mask = mask.astype(h.dtype)
 
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index 225e9d27..eb5a0625 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -29,6 +29,10 @@ class TestModels(unittest.TestCase):
         )
         cache = [KVCache(model.head_dim, n) for n in kv_heads]
 
+        outputs = model(inputs, cache)
+        self.assertEqual(outputs.shape, (1, 2, vocab_size))
+        self.assertEqual(outputs.dtype, t)
+
        outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
         self.assertEqual(outputs.shape, (1, 1, vocab_size))
         self.assertEqual(outputs.dtype, t)
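
Why `cache[0].offset` rather than `cache[0][0].shape[2]`: the old expression dates from when the cache was a list of (keys, values) tuples, so the keys' sequence dimension equaled the number of cached tokens. The KVCache class instead preallocates its key/value buffers in fixed-size chunks, so the sequence axis of the stored keys reflects allocated capacity, while the `offset` attribute tracks how many tokens have actually been written, which is the past length the causal mask needs. Below is a minimal sketch of that distinction; `ToyKVCache` is an illustrative stand-in, and the step size of 256 and the growth logic are assumptions modeled on mlx_lm's KVCache, not the exact library code.

# Sketch: why the buffer's sequence axis is not the number of cached tokens.
import mlx.core as mx


class ToyKVCache:
    step = 256  # assumed growth quantum; buffers grow in chunks of this size

    def __init__(self, head_dim, n_kv_heads):
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.keys = None
        self.values = None
        self.offset = 0  # tokens actually written so far

    def update_and_fetch(self, keys, values):
        n_new = keys.shape[2]
        if self.keys is None or self.offset + n_new > self.keys.shape[2]:
            # Allocate (or grow) to the next multiple of `step` tokens, so the
            # buffer's sequence axis overshoots the filled region.
            n_alloc = ((self.offset + n_new) // self.step + 1) * self.step
            shape = (1, self.n_kv_heads, n_alloc, self.head_dim)
            new_k = mx.zeros(shape, keys.dtype)
            new_v = mx.zeros(shape, values.dtype)
            if self.keys is not None:
                new_k[..., : self.offset, :] = self.keys[..., : self.offset, :]
                new_v[..., : self.offset, :] = self.values[..., : self.offset, :]
            self.keys, self.values = new_k, new_v
        self.keys[..., self.offset : self.offset + n_new, :] = keys
        self.values[..., self.offset : self.offset + n_new, :] = values
        self.offset += n_new
        return (
            self.keys[..., : self.offset, :],
            self.values[..., : self.offset, :],
        )


cache = ToyKVCache(head_dim=64, n_kv_heads=8)
k = mx.zeros((1, 8, 2, 64))  # a 2-token prompt
cache.update_and_fetch(k, k)
print(cache.keys.shape[2])  # 256: allocated capacity -- what the old code measured
print(cache.offset)         # 2: tokens actually cached -- the correct mask offset

After a two-token prompt the buffer's sequence axis already reads 256, so a mask sized from the buffer shape would be far too long; `offset` gives the true past length, which is what the new test exercises by running a 2-token prefill before the single-token decode step.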