caching in server

This commit is contained in:
Awni Hannun
2024-10-09 12:46:44 -07:00
parent 6c368f2124
commit cdba586b67
3 changed files with 134 additions and 25 deletions

View File

@@ -1,5 +1,6 @@
# Copyright © 2024 Apple Inc.
import copy
import os
import tempfile
import unittest
@@ -215,6 +216,28 @@ class TestPromptCache(unittest.TestCase):
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
)
def test_cache_copying(self):
cache = [KVCache()]
x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[0].update_and_fetch(x, x)
y = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(y, y)
old_cache = copy.deepcopy(cache)
trim_prompt_cache(cache, 1)
self.assertTrue(old_cache[0].offset, 11)
self.assertTrue(cache[0].offset, 10)
z = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(z, z)
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
if __name__ == "__main__":
unittest.main()