Fix llama cache check (#763)

* fix llama cache check

* add test
Awni Hannun 2024-05-08 08:35:54 -07:00 committed by GitHub
parent ee60e2a9d5
commit fad9598372
2 changed files with 5 additions and 1 deletion


@@ -159,7 +159,7 @@ class LlamaModel(nn.Module):
         mask = None
         if h.shape[1] > 1:
             mask = create_additive_causal_mask(
-                h.shape[1], cache[0][0].shape[2] if cache is not None else 0
+                h.shape[1], cache[0].offset if cache is not None else 0
             )
             mask = mask.astype(h.dtype)

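Why the offset comes from the cache object rather than the stored keys' shape: with a step-allocated KV cache (similar in spirit to mlx_lm's KVCache), the key/value buffers are padded out past the write position, so the third dimension of the stored keys counts padding, not tokens. Below is a minimal sketch of that behavior with hypothetical names (StepCache, update); it is not the actual mlx_lm implementation.

import mlx.core as mx


class StepCache:
    """Hypothetical step-allocated KV cache (illustrative names, not the
    mlx_lm implementation): the key buffer is padded out to a multiple of
    `step`, so keys.shape[2] overshoots the number of cached tokens."""

    step = 256

    def __init__(self):
        self.keys = None
        self.offset = 0  # tokens actually written so far

    def update(self, keys):
        needed = self.offset + keys.shape[2]
        padded = ((needed + self.step - 1) // self.step) * self.step
        pad = mx.zeros(
            (keys.shape[0], keys.shape[1], padded - needed, keys.shape[3]),
            dtype=keys.dtype,
        )
        pieces = [self.keys[:, :, : self.offset, :]] if self.keys is not None else []
        self.keys = mx.concatenate(pieces + [keys, pad], axis=2)
        self.offset = needed
        return self.keys[:, :, : self.offset, :]


cache = StepCache()
cache.update(mx.zeros((1, 8, 5, 64)))      # 5 prompt tokens
print(cache.keys.shape[2], cache.offset)   # 256 vs. 5
# The causal mask must be shifted by cache.offset, not by keys.shape[2].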

@@ -29,6 +29,10 @@ class TestModels(unittest.TestCase):
             )
             cache = [KVCache(model.head_dim, n) for n in kv_heads]
             outputs = model(inputs, cache)
             self.assertEqual(outputs.shape, (1, 2, vocab_size))
             self.assertEqual(outputs.dtype, t)
+            outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
+            self.assertEqual(outputs.shape, (1, 1, vocab_size))
+            self.assertEqual(outputs.dtype, t)
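
The added assertions exercise the single-token decode step, where h.shape[1] == 1 and no mask is created, after a prefill that does hit the fixed cache check. A hedged sketch of the decode loop this corresponds to, reusing the setup names from the test above (model, kv_heads, vocab_size, KVCache); the prompt tokens and step count are arbitrary.

cache = [KVCache(model.head_dim, n) for n in kv_heads]
prompt = mx.array([[1, 2, 3]])              # (1, prompt_len) token ids
logits = model(prompt, cache)               # prefill: mask offset read from cache[0].offset
for _ in range(4):
    tok = mx.argmax(logits[0, -1:, :], keepdims=True)  # greedy next token, shape (1, 1)
    logits = model(tok, cache=cache)                    # decode: h.shape[1] == 1, no mask
    assert logits.shape == (1, 1, vocab_size)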