mlx-examples/llms/mlx_lm/models/base.py
Awni Hannun ee60e2a9d5
Kv cache (#643)
* in place kv_cache

* fix

* fix kv cache size

* partially fix kv cache dtype

* step kv cache

* multiple of step size

* more tests + kv cache

* more kv cache

* update all models to use kv cache
2024-05-08 08:18:13 -07:00

47 lines
1.4 KiB
Python

import inspect
from dataclasses import dataclass
import mlx.core as mx
class KVCache:
    """Growable key/value cache for incremental attention.

    Keys and values are kept in pre-allocated buffers that are extended
    along axis 2 in chunks of ``step`` positions, so most updates write into
    existing storage instead of concatenating on every call.

    Args:
        head_dim: Per-head feature size. Kept as an attribute for callers;
            buffer shapes are taken from the incoming ``keys``/``values``.
        n_kv_heads: Number of key/value heads. Kept as an attribute for
            callers, as above.
        step: Growth granularity, in positions along axis 2. Defaults to 256.
    """

    def __init__(self, head_dim, n_kv_heads, step=256):
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.keys = None
        self.values = None
        # Number of positions along axis 2 that currently hold valid data.
        self.offset = 0
        self.step = step

    def update_and_fetch(self, keys, values):
        """Append ``keys``/``values`` to the cache and return everything cached.

        Args:
            keys: 4-D array appended along axis 2. Presumably laid out as
                (batch, n_kv_heads, positions, head_dim) — TODO confirm
                against callers.
            values: 4-D array appended alongside ``keys``; must have the same
                size as ``keys`` along axis 2.

        Returns:
            ``(keys, values)`` views of the first ``self.offset`` positions
            along axis 2 of the cache buffers.
        """
        prev = self.offset
        n_new = keys.shape[2]

        # Grow whenever the incoming chunk does not fit in the remaining
        # capacity. Testing only ``prev % step == 0`` misses a multi-position
        # chunk arriving while the buffer is partially full, which overflows.
        if self.keys is None or prev + n_new > self.keys.shape[2]:
            grow = ((self.step + n_new - 1) // self.step) * self.step
            k_shape = (*keys.shape[:2], grow, keys.shape[3])
            v_shape = (*values.shape[:2], grow, values.shape[3])
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is None:
                self.keys, self.values = new_k, new_v
            else:
                # Drop any unused tail so the new chunk starts right after the
                # valid data; ``grow >= n_new`` then guarantees it fits.
                self.keys = mx.concatenate(
                    [self.keys[..., :prev, :], new_k], axis=2
                )
                self.values = mx.concatenate(
                    [self.values[..., :prev, :], new_v], axis=2
                )

        self.offset += n_new
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@dataclass
class BaseModelArgs:
    """Base class for model-argument dataclasses."""

    @classmethod
    def from_dict(cls, params):
        """Construct an instance from ``params``, ignoring unrecognized keys.

        Only entries whose key names a parameter in ``cls``'s constructor
        signature are passed through; all other entries are silently dropped.

        Args:
            params: Mapping of argument names to values.

        Returns:
            A new ``cls`` instance.
        """
        # The signature is the same for every key, so compute it only once.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in accepted})