Support different key and value head dimensions in QuantizedKVCache

This commit is contained in:
Alex Barron 2024-10-31 16:24:40 -07:00
parent 79075b7a21
commit 83a7a17f84

View File

@ -140,23 +140,23 @@ class QuantizedKVCache(_BaseCache):
def update_and_fetch(self, keys, values): def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, k_head_dim = keys.shape B, n_kv_heads, num_steps, k_head_dim = keys.shape
v_head_dim = values.shape[-1]
prev = self.offset prev = self.offset
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]: if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
el_per_int = 8 * mx.uint32.size // self.bits el_per_int = 8 * mx.uint32.size // self.bits
new_steps = (self.step + num_steps - 1) // self.step * self.step new_steps = (self.step + num_steps - 1) // self.step * self.step
shape = (B, n_kv_heads, new_steps, k_head_dim // el_per_int) shape = (B, n_kv_heads, new_steps)
group_shape = (B, n_kv_heads, new_steps, k_head_dim // self.group_size)
def init_quant(): def init_quant(dim):
return ( return (
mx.zeros(shape, dtype=mx.uint32), mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
mx.zeros(group_shape, dtype=keys.dtype), mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
mx.zeros(group_shape, dtype=keys.dtype), mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
) )
def expand_quant(x): def expand_quant(x):
new_x = mx.zeros((B, n_kv_heads, new_steps, x.shape[-1]), dtype=x.dtype) new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
return mx.concatenate([x, new_x], axis=-2) return mx.concatenate([x, new_x], axis=-2)
if self.keys is not None: if self.keys is not None:
@ -169,7 +169,7 @@ class QuantizedKVCache(_BaseCache):
expand_quant, (self.keys, self.values) expand_quant, (self.keys, self.values)
) )
else: else:
self.keys, self.values = init_quant(), init_quant() self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
self.offset += num_steps self.offset += num_steps