From 83a7a17f84c9add71bf695b1d787e89a14e1b750 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Thu, 31 Oct 2024 16:24:40 -0700
Subject: [PATCH] support different k and v head dims

---
 llms/mlx_lm/models/cache.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index c11aa51f..1cd5289d 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -140,23 +140,23 @@ class QuantizedKVCache(_BaseCache):
 
     def update_and_fetch(self, keys, values):
         B, n_kv_heads, num_steps, k_head_dim = keys.shape
+        v_head_dim = values.shape[-1]
         prev = self.offset
 
         if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
             el_per_int = 8 * mx.uint32.size // self.bits
             new_steps = (self.step + num_steps - 1) // self.step * self.step
-            shape = (B, n_kv_heads, new_steps, k_head_dim // el_per_int)
-            group_shape = (B, n_kv_heads, new_steps, k_head_dim // self.group_size)
+            shape = (B, n_kv_heads, new_steps)
 
-            def init_quant():
+            def init_quant(dim):
                 return (
-                    mx.zeros(shape, dtype=mx.uint32),
-                    mx.zeros(group_shape, dtype=keys.dtype),
-                    mx.zeros(group_shape, dtype=keys.dtype),
+                    mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
+                    mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
+                    mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
                 )
 
             def expand_quant(x):
-                new_x = mx.zeros((B, n_kv_heads, new_steps, x.shape[-1]), dtype=x.dtype)
+                new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
                 return mx.concatenate([x, new_x], axis=-2)
 
             if self.keys is not None:
@@ -169,7 +169,7 @@ class QuantizedKVCache(_BaseCache):
                     expand_quant, (self.keys, self.values)
                 )
             else:
-                self.keys, self.values = init_quant(), init_quant()
+                self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
 
         self.offset += num_steps
 