Fix llama cache check (#763)

* fix llama cache check

* add test
This commit is contained in:
Awni Hannun
2024-05-08 08:35:54 -07:00
committed by GitHub
parent ee60e2a9d5
commit fad9598372
2 changed files with 5 additions and 1 deletions

View File

@@ -29,6 +29,10 @@ class TestModels(unittest.TestCase):
)
cache = [KVCache(model.head_dim, n) for n in kv_heads]
outputs = model(inputs, cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
self.assertEqual(outputs.shape, (1, 1, vocab_size))
self.assertEqual(outputs.dtype, t)