diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 381033b3..efdbaf13 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -8,13 +8,9 @@ import time import mlx.core as mx from .models.cache import make_prompt_cache, save_prompt_cache -from .utils import ( - DEFAULT_KV_BITS, - DEFAULT_KV_GROUP_SIZE, - check_quantized_kv_args, - load, - maybe_quantize_kv_cache, -) +from .utils import load, maybe_quantize_kv_cache + +DEFAULT_QUANTIZED_KV_START = 5000 def setup_arg_parser(): @@ -77,22 +73,24 @@ def setup_arg_parser(): help="Message to be processed by the model ('-' reads from stdin)", ) parser.add_argument( - "--quantized-kv-start", - help="Use a quantized KV cache from this step onwards.", + "--kv-bits", type=int, + help="Number of bits for KV cache quantization. " + "Defaults to no quantization.", default=None, ) parser.add_argument( "--kv-group-size", type=int, - help="Group size for kv cache quantization.", - default=DEFAULT_KV_GROUP_SIZE, + help="Group size for KV cache quantization.", + default=64, ) parser.add_argument( - "--kv-bits", + "--quantized-kv-start", + help="When --kv-bits is set, start quantizing the KV cache " + "from this step onwards.", type=int, - help="Number of bits for kv cache quantization.", - default=DEFAULT_KV_BITS, + default=DEFAULT_QUANTIZED_KV_START, ) return parser @@ -117,8 +115,6 @@ def main(): args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt - check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits) - if args.use_default_chat_template: if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 9b84e7e6..ed3ddd0c 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -7,7 +7,7 @@ import sys import mlx.core as mx from .models.cache import QuantizedKVCache, load_prompt_cache -from .utils import check_quantized_kv_args, 
generate, load +from .utils import generate, load DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 @@ -15,6 +15,7 @@ DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" +DEFAULT_QUANTIZED_KV_START = 5000 def str2bool(string): @@ -108,24 +109,24 @@ def setup_arg_parser(): help="A file containing saved KV caches to avoid recomputing them", ) parser.add_argument( - "--quantized-kv-start", - help="Use a quantized KV cache from this step onwards.", + "--kv-bits", type=int, + help="Number of bits for KV cache quantization. " + "Defaults to no quantization.", default=None, ) parser.add_argument( "--kv-group-size", type=int, - help="Group size for kv cache quantization. " - "--quantized-kv-start must be provided to have an effect.", + help="Group size for KV cache quantization.", default=64, ) parser.add_argument( - "--kv-bits", + "--quantized-kv-start", + help="When --kv-bits is set, start quantizing the KV cache " + "from this step onwards.", type=int, - help="Number of bits for kv cache quantization. " - "--quantized-kv-start must be provided to have an effect.", - default=8, + default=DEFAULT_QUANTIZED_KV_START, ) return parser @@ -173,13 +174,15 @@ def main(): args.prompt_cache_file, return_metadata=True, ) - if args.quantized_kv_start and isinstance(prompt_cache[0], QuantizedKVCache): - raise ValueError( - "Specified `--quantized-kv-start` but cache from " - "`--prompt-cache-file` is already quantized." - ) - - check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits) + if isinstance(prompt_cache[0], QuantizedKVCache): + if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits: + raise ValueError( + "--kv-bits does not match the kv cache loaded from --prompt-cache-file." + ) + if args.kv_group_size != prompt_cache[0].group_size: + raise ValueError( + "--kv-group-size does not match the kv cache loaded from --prompt-cache-file." 
+ ) # Building tokenizer_config tokenizer_config = ( diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 3628a808..cda41c79 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -5,6 +5,9 @@ from dataclasses import dataclass from typing import Any, Optional import mlx.core as mx +from mlx.utils import tree_map + +from .cache import QuantizedKVCache @dataclass @@ -48,3 +51,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): else: mask = None return mask + + +def quantized_scaled_dot_product_attention( + queries: mx.array, + q_keys: tuple[mx.array, mx.array, mx.array], + q_values: tuple[mx.array, mx.array, mx.array], + scale: float, + mask: Optional[mx.array], + group_size: int = 64, + bits: int = 8, +) -> mx.array: + B, n_q_heads, L, D = queries.shape + n_kv_heads = q_keys[0].shape[-3] + n_repeats = n_q_heads // n_kv_heads + + queries *= scale + + if n_repeats > 1: + queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) + q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) + q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) + + scores = mx.quantized_matmul( + queries, *q_keys, transpose=True, group_size=group_size, bits=bits + ) + if mask is not None: + scores += mask + scores = mx.softmax(scores, axis=-1, precise=True) + out = mx.quantized_matmul( + scores, *q_values, transpose=False, group_size=group_size, bits=bits + ) + + if n_repeats > 1: + out = mx.reshape(out, (B, n_q_heads, L, D)) + + return out + + +def scaled_dot_product_attention( + queries, + keys, + values, + cache, + scale: float, + mask: Optional[mx.array], +) -> mx.array: + if isinstance(cache, QuantizedKVCache): + return quantized_scaled_dot_product_attention( + queries, + keys, + values, + scale=scale, + mask=mask, + group_size=cache.group_size, + bits=cache.bits, + ) + else: + return mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=scale, mask=mask + ) diff --git 
a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index f4efe41e..c11aa51f 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -129,42 +129,6 @@ class _BaseCache: return False -def quantized_scaled_dot_product_attention( - queries: mx.array, - q_keys: tuple[mx.array, mx.array, mx.array], - q_values: tuple[mx.array, mx.array, mx.array], - scale: float, - mask: Optional[mx.array], - group_size: int = 64, - bits: int = 8, -) -> mx.array: - B, n_q_heads, L, D = queries.shape - n_kv_heads = q_keys[0].shape[-3] - n_repeats = n_q_heads // n_kv_heads - - queries *= scale - - if n_repeats > 1: - queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) - q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) - q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) - - scores = mx.quantized_matmul( - queries, *q_keys, transpose=True, group_size=group_size, bits=bits - ) - if mask is not None: - scores += mask - scores = mx.softmax(scores, axis=-1, precise=True) - out = mx.quantized_matmul( - scores, *q_values, transpose=False, group_size=group_size, bits=bits - ) - - if n_repeats > 1: - out = mx.reshape(out, (B, n_q_heads, L, D)) - - return out - - class QuantizedKVCache(_BaseCache): def __init__(self, group_size: int = 64, bits: int = 8): self.keys = None @@ -452,14 +416,7 @@ class RotatingKVCache(_BaseCache): return n def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: - quant_cache = QuantizedKVCache(group_size=group_size, bits=bits) - quant_cache.offset = self.offset - if self.keys is not None: - quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits) - quant_cache.values = mx.quantize( - self.values, group_size=group_size, bits=bits - ) - return quant_cache + raise NotImplementedError("RotatingKVCache Quantization NYI") class MambaCache(_BaseCache): diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 057c816d..7e002b0c 100644 --- 
a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -93,8 +93,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index 3b7e83d7..7be274cc 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -74,8 +74,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.out_proj(output) diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index 03cb3b1a..b7b24dba 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -97,8 +97,8 @@ class DeepseekAttention(nn.Module): queries = 
self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index bb3e5184..444813b9 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -235,8 +235,8 @@ class DeepseekV2Attention(nn.Module): queries = mx.concatenate([q_nope, q_pe], axis=-1) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index 61de781e..3f384c3f 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, 
-1) diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 97d9a8ff..52076a34 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -61,8 +61,8 @@ class Attention(nn.Module): if cache is not None: keys, values = cache.update_and_fetch(keys, values) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 068046ea..23e86e20 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -74,8 +74,8 @@ class Attention(nn.Module): if cache is not None: keys, values = cache.update_and_fetch(keys, values) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.c_proj(output) diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index 9f662491..ccb0b28b 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, 
scaled_dot_product_attention # Based on the transformers implementation at: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index 5264cb57..f5ce057e 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -141,8 +141,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.wo(output) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 673d322a..6f72dd6e 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -6,8 +6,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask -from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -191,20 +190,9 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - 
if isinstance(cache, QuantizedKVCache): - output = quantized_scaled_dot_product_attention( - queries, - keys, - values, - scale=self.scale, - mask=mask, - group_size=cache.group_size, - bits=cache.bits, - ) - else: - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask - ) + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index 4ac3c3b4..907beb2a 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -105,8 +105,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - attn_output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + attn_output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 20944fe3..dd94d1f4 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -87,8 +87,8 @@ class MixtralAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask
+ output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index 3ea06e27..f73c0277 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -113,8 +113,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 090e21c6..408802f4 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -107,8 +107,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 56b383b2..5bd8603d 100644 --- a/llms/mlx_lm/models/phi.py +++ 
b/llms/mlx_lm/models/phi.py @@ -7,7 +7,7 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -93,7 +93,7 @@ class PhiAttention(nn.Module): keys = self.rope(keys) scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries.astype(mx.float32), keys, values, cache=cache, scale=scale, mask=mask ).astype(values.dtype) diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 9ef76f04..ee6efc49 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .su_rope import SuScaledRotaryEmbedding @@ -107,8 +107,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 6b0759b4..53e1a638 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -8,7 +8,7 @@ from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -188,8 +188,8 @@ class Attention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) else: - output =
mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.dense(output) diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index ca20a388..f42a6dd0 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .su_rope import SuScaledRotaryEmbedding from .switch_layers import SwitchGLU @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 865d0d8e..67084d20 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -8,7 +8,7 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .base import create_attention_mask +from .base import create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchMLP @@ -71,7 +71,7 @@ class RoPEAttention(nn.Module): # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries.astype(mx.float32), keys, values, cache=cache, scale=scale, mask=mask ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/plamo.py
b/llms/mlx_lm/models/plamo.py index b0fd1a6c..a87c6cac 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -92,7 +92,8 @@ class Attention(nn.Module): keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1]) values = mx.tile(values, [1, self.config.n_shared_head, 1, 1]) - output = mx.fast.scaled_dot_product_attention( + output = scaled_dot_product_attention( queries, keys, values, + cache=cache, diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 2b69d5ec..8145a890 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -64,8 +64,8 @@ class Attention(nn.Module): queries = self.rotary_emb(queries) keys = self.rotary_emb(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index 0b88e71e..468ffb43 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -6,8 +6,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask -from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -90,20 +89,9 @@ class Attention(nn.Module): queries =
self.rope(queries) keys = self.rope(keys) - if isinstance(cache, QuantizedKVCache): - output = quantized_scaled_dot_product_attention( - queries, - keys, - values, - scale=self.scale, - mask=mask, - group_size=cache.group_size, - bits=cache.bits, - ) - else: - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask - ) + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index d199116f..167fc5dd 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -89,8 +89,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 5595d311..49e4bb8f 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -7,7 +7,7 @@ from typing import List, Literal, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .cache import MambaCache, RotatingKVCache @@ -263,8 +263,8 @@ class LocalAttentionBlock(nn.Module):
queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 11202b02..482bb324 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -6,7 +6,7 @@ from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -120,8 +120,8 @@ class Attention(nn.Module): # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=scale, mask=mask ).astype(values.dtype) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index ce0a2ec5..d7e626f2 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -6,7 +6,7 @@ from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -64,8 +64,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 
3).reshape(B, L, -1) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index b28e90ce..1f5dd405 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -33,9 +33,6 @@ MODEL_REMAPPING = { MAX_FILE_SIZE_GB = 5 -DEFAULT_KV_GROUP_SIZE = 64 -DEFAULT_KV_BITS = 8 - class ModelNotFoundError(Exception): def __init__(self, message): @@ -162,20 +159,11 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float) return logits -def check_quantized_kv_args(quantized_kv_start, kv_group_size, kv_bits): - if not quantized_kv_start and ( - kv_group_size != DEFAULT_KV_GROUP_SIZE or kv_bits != DEFAULT_KV_BITS - ): - raise ValueError( - "--kv-group-size and --kv-bits only apply when --quantized-kv-start is specified." - ) - - def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): if ( - quantized_kv_start - and prompt_cache[0].offset > quantized_kv_start + kv_bits is not None and not isinstance(prompt_cache[0], cache.QuantizedKVCache) + and prompt_cache[0].offset > quantized_kv_start ): return [ c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache