mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
support different k and v head dims
This commit is contained in:
parent
79075b7a21
commit
83a7a17f84
@ -140,23 +140,23 @@ class QuantizedKVCache(_BaseCache):
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
B, n_kv_heads, num_steps, k_head_dim = keys.shape
|
||||
v_head_dim = values.shape[-1]
|
||||
prev = self.offset
|
||||
|
||||
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
|
||||
el_per_int = 8 * mx.uint32.size // self.bits
|
||||
new_steps = (self.step + num_steps - 1) // self.step * self.step
|
||||
shape = (B, n_kv_heads, new_steps, k_head_dim // el_per_int)
|
||||
group_shape = (B, n_kv_heads, new_steps, k_head_dim // self.group_size)
|
||||
shape = (B, n_kv_heads, new_steps)
|
||||
|
||||
def init_quant():
|
||||
def init_quant(dim):
|
||||
return (
|
||||
mx.zeros(shape, dtype=mx.uint32),
|
||||
mx.zeros(group_shape, dtype=keys.dtype),
|
||||
mx.zeros(group_shape, dtype=keys.dtype),
|
||||
mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
|
||||
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||
)
|
||||
|
||||
def expand_quant(x):
|
||||
new_x = mx.zeros((B, n_kv_heads, new_steps, x.shape[-1]), dtype=x.dtype)
|
||||
new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
|
||||
return mx.concatenate([x, new_x], axis=-2)
|
||||
|
||||
if self.keys is not None:
|
||||
@ -169,7 +169,7 @@ class QuantizedKVCache(_BaseCache):
|
||||
expand_quant, (self.keys, self.values)
|
||||
)
|
||||
else:
|
||||
self.keys, self.values = init_quant(), init_quant()
|
||||
self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
|
||||
|
||||
self.offset += num_steps
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user