mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
support different k and v head dims
This commit is contained in:
parent
79075b7a21
commit
83a7a17f84
@ -140,23 +140,23 @@ class QuantizedKVCache(_BaseCache):
|
|||||||
|
|
||||||
def update_and_fetch(self, keys, values):
|
def update_and_fetch(self, keys, values):
|
||||||
B, n_kv_heads, num_steps, k_head_dim = keys.shape
|
B, n_kv_heads, num_steps, k_head_dim = keys.shape
|
||||||
|
v_head_dim = values.shape[-1]
|
||||||
prev = self.offset
|
prev = self.offset
|
||||||
|
|
||||||
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
|
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
|
||||||
el_per_int = 8 * mx.uint32.size // self.bits
|
el_per_int = 8 * mx.uint32.size // self.bits
|
||||||
new_steps = (self.step + num_steps - 1) // self.step * self.step
|
new_steps = (self.step + num_steps - 1) // self.step * self.step
|
||||||
shape = (B, n_kv_heads, new_steps, k_head_dim // el_per_int)
|
shape = (B, n_kv_heads, new_steps)
|
||||||
group_shape = (B, n_kv_heads, new_steps, k_head_dim // self.group_size)
|
|
||||||
|
|
||||||
def init_quant():
|
def init_quant(dim):
|
||||||
return (
|
return (
|
||||||
mx.zeros(shape, dtype=mx.uint32),
|
mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
|
||||||
mx.zeros(group_shape, dtype=keys.dtype),
|
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||||
mx.zeros(group_shape, dtype=keys.dtype),
|
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
def expand_quant(x):
|
def expand_quant(x):
|
||||||
new_x = mx.zeros((B, n_kv_heads, new_steps, x.shape[-1]), dtype=x.dtype)
|
new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
|
||||||
return mx.concatenate([x, new_x], axis=-2)
|
return mx.concatenate([x, new_x], axis=-2)
|
||||||
|
|
||||||
if self.keys is not None:
|
if self.keys is not None:
|
||||||
@ -169,7 +169,7 @@ class QuantizedKVCache(_BaseCache):
|
|||||||
expand_quant, (self.keys, self.values)
|
expand_quant, (self.keys, self.values)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.keys, self.values = init_quant(), init_quant()
|
self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
|
||||||
|
|
||||||
self.offset += num_steps
|
self.offset += num_steps
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user