From 79075b7a212a0f3487c9e3098c72176fb9c099ab Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Thu, 31 Oct 2024 12:37:15 -0700 Subject: [PATCH] fix tests --- llms/mlx_lm/generate.py | 4 ++-- llms/mlx_lm/models/phi.py | 7 ++++++- llms/mlx_lm/models/phixtral.py | 7 ++++++- llms/mlx_lm/models/plamo.py | 1 + llms/mlx_lm/utils.py | 10 +++++++--- 5 files changed, 22 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index ed3ddd0c..0355ca29 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -258,9 +258,9 @@ def main(): top_p=args.top_p, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, - quantized_kv_start=args.quantized_kv_start, - kv_group_size=args.kv_group_size, kv_bits=args.kv_bits, + kv_group_size=args.kv_group_size, + quantized_kv_start=args.quantized_kv_start, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 5bd8603d..510025ea 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -94,7 +94,12 @@ class PhiAttention(nn.Module): scale = math.sqrt(1 / queries.shape[-1]) output = scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + queries.astype(mx.float32), + keys, + values, + cache=cache, + scale=scale, + mask=mask, ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 67084d20..42d647b0 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -72,7 +72,12 @@ class RoPEAttention(nn.Module): scale = math.sqrt(1 / queries.shape[-1]) output = scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + queries.astype(mx.float32), + keys, + values, + cache=cache, + scale=scale, + mask=mask, ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) diff --git 
a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index a87c6cac..c8e5bf50 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -96,6 +96,7 @@ class Attention(nn.Module): queries, keys, values, + cache=cache, scale=self.scale, mask=attention_mask, ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 71b85861..06784f10 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -185,9 +185,9 @@ def generate_step( prompt_cache: Optional[Any] = None, logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, - quantized_kv_start: Optional[int] = None, + kv_bits: Optional[int] = None, kv_group_size: int = 64, - kv_bits: int = 8, + quantized_kv_start: int = 0, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -216,6 +216,11 @@ def generate_step( logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed logits. Default: ``None``. + kv_bits (int, optional): Number of bits to use for KV cache quantization. + None implies no cache quantization. Default: ``None``. + kv_group_size (int): Group size for KV cache quantization. Default: ``64``. + quantized_kv_start (int): Step to begin using a quantized KV cache + when ``kv_bits`` is non-None. Default: ``0``. Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing @@ -279,7 +284,6 @@ def generate_step( def _step(y): - nonlocal prompt_cache logits = model(y[None], cache=prompt_cache) logits = logits[:, -1, :]