single sdpa function

This commit is contained in:
Alex Barron 2024-10-31 12:02:34 -07:00
parent 29f21e7fe4
commit 2e0690374e
31 changed files with 174 additions and 191 deletions

View File

@ -8,13 +8,9 @@ import time
import mlx.core as mx import mlx.core as mx
from .models.cache import make_prompt_cache, save_prompt_cache from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import ( from .utils import load, maybe_quantize_kv_cache
DEFAULT_KV_BITS,
DEFAULT_KV_GROUP_SIZE, DEFAULT_QUANTIZED_KV_START = 5000
check_quantized_kv_args,
load,
maybe_quantize_kv_cache,
)
def setup_arg_parser(): def setup_arg_parser():
@ -77,22 +73,24 @@ def setup_arg_parser():
help="Message to be processed by the model ('-' reads from stdin)", help="Message to be processed by the model ('-' reads from stdin)",
) )
parser.add_argument( parser.add_argument(
"--quantized-kv-start", "--kv-bits",
help="Use a quantized KV cache from this step onwards.",
type=int, type=int,
help="Number of bits for KV cache quantization. "
"Defaults to no quantization.",
default=None, default=None,
) )
parser.add_argument( parser.add_argument(
"--kv-group-size", "--kv-group-size",
type=int, type=int,
help="Group size for kv cache quantization.", help="Group size for KV cache quantization.",
default=DEFAULT_KV_GROUP_SIZE, default=64,
) )
parser.add_argument( parser.add_argument(
"--kv-bits", "--quantized-kv-start",
help="When --kv-bits is set, start quantizing the KV cache "
"from this step onwards.",
type=int, type=int,
help="Number of bits for kv cache quantization.", default=DEFAULT_QUANTIZED_KV_START,
default=DEFAULT_KV_BITS,
) )
return parser return parser
@ -117,8 +115,6 @@ def main():
args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits)
if args.use_default_chat_template: if args.use_default_chat_template:
if tokenizer.chat_template is None: if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template tokenizer.chat_template = tokenizer.default_chat_template

View File

@ -7,7 +7,7 @@ import sys
import mlx.core as mx import mlx.core as mx
from .models.cache import QuantizedKVCache, load_prompt_cache from .models.cache import QuantizedKVCache, load_prompt_cache
from .utils import check_quantized_kv_args, generate, load from .utils import generate, load
DEFAULT_PROMPT = "hello" DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100 DEFAULT_MAX_TOKENS = 100
@ -15,6 +15,7 @@ DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0 DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0 DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000
def str2bool(string): def str2bool(string):
@ -108,24 +109,24 @@ def setup_arg_parser():
help="A file containing saved KV caches to avoid recomputing them", help="A file containing saved KV caches to avoid recomputing them",
) )
parser.add_argument( parser.add_argument(
"--quantized-kv-start", "--kv-bits",
help="Use a quantized KV cache from this step onwards.",
type=int, type=int,
help="Number of bits for KV cache quantization. "
"Defaults to no quantization.",
default=None, default=None,
) )
parser.add_argument( parser.add_argument(
"--kv-group-size", "--kv-group-size",
type=int, type=int,
help="Group size for kv cache quantization. " help="Group size for KV cache quantization.",
"--quantized-kv-start must be provided to have an effect.",
default=64, default=64,
) )
parser.add_argument( parser.add_argument(
"--kv-bits", "--quantized-kv-start",
help="When --kv-bits is set, start quantizing the KV cache "
"from this step onwards.",
type=int, type=int,
help="Number of bits for kv cache quantization. " default=DEFAULT_QUANTIZED_KV_START,
"--quantized-kv-start must be provided to have an effect.",
default=8,
) )
return parser return parser
@ -173,13 +174,15 @@ def main():
args.prompt_cache_file, args.prompt_cache_file,
return_metadata=True, return_metadata=True,
) )
if args.quantized_kv_start and isinstance(prompt_cache[0], QuantizedKVCache): if isinstance(prompt_cache[0], QuantizedKVCache):
if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
raise ValueError( raise ValueError(
"Specified `--quantized-kv-start` but cache from " "--kv-bits does not match the kv cache loaded from --prompt-cache-file."
"`--prompt-cache-file` is already quantized." )
if args.kv_group_size != prompt_cache[0].group_size:
raise ValueError(
"--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
) )
check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits)
# Building tokenizer_config # Building tokenizer_config
tokenizer_config = ( tokenizer_config = (

View File

@ -5,6 +5,9 @@ from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional
import mlx.core as mx import mlx.core as mx
from mlx.utils import tree_map
from .cache import QuantizedKVCache
@dataclass @dataclass
@ -48,3 +51,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
else: else:
mask = None mask = None
return mask return mask
def quantized_scaled_dot_product_attention(
queries: mx.array,
q_keys: tuple[mx.array, mx.array, mx.array],
q_values: tuple[mx.array, mx.array, mx.array],
scale: float,
mask: Optional[mx.array],
group_size: int = 64,
bits: int = 8,
) -> mx.array:
B, n_q_heads, L, D = queries.shape
n_kv_heads = q_keys[0].shape[-3]
n_repeats = n_q_heads // n_kv_heads
queries *= scale
if n_repeats > 1:
queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
scores = mx.quantized_matmul(
queries, *q_keys, transpose=True, group_size=group_size, bits=bits
)
if mask is not None:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = mx.quantized_matmul(
scores, *q_values, transpose=False, group_size=group_size, bits=bits
)
if n_repeats > 1:
out = mx.reshape(out, (B, n_q_heads, L, D))
return out
def scaled_dot_product_attention(
queries,
keys,
values,
cache,
scale: float,
mask: Optional[mx.array],
) -> mx.array:
if isinstance(cache, QuantizedKVCache):
return quantized_scaled_dot_product_attention(
queries,
keys,
values,
scale=scale,
mask=mask,
group_size=cache.group_size,
bits=cache.bits,
)
else:
return mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=scale, mask=mask
)

View File

@ -129,42 +129,6 @@ class _BaseCache:
return False return False
def quantized_scaled_dot_product_attention(
queries: mx.array,
q_keys: tuple[mx.array, mx.array, mx.array],
q_values: tuple[mx.array, mx.array, mx.array],
scale: float,
mask: Optional[mx.array],
group_size: int = 64,
bits: int = 8,
) -> mx.array:
B, n_q_heads, L, D = queries.shape
n_kv_heads = q_keys[0].shape[-3]
n_repeats = n_q_heads // n_kv_heads
queries *= scale
if n_repeats > 1:
queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
scores = mx.quantized_matmul(
queries, *q_keys, transpose=True, group_size=group_size, bits=bits
)
if mask is not None:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = mx.quantized_matmul(
scores, *q_values, transpose=False, group_size=group_size, bits=bits
)
if n_repeats > 1:
out = mx.reshape(out, (B, n_q_heads, L, D))
return out
class QuantizedKVCache(_BaseCache): class QuantizedKVCache(_BaseCache):
def __init__(self, group_size: int = 64, bits: int = 8): def __init__(self, group_size: int = 64, bits: int = 8):
self.keys = None self.keys = None
@ -452,14 +416,7 @@ class RotatingKVCache(_BaseCache):
return n return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits) raise NotImplementedError("RotatingKVCache Quantization NYI")
quant_cache.offset = self.offset
if self.keys is not None:
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
quant_cache.values = mx.quantize(
self.values, group_size=group_size, bits=bits
)
return quant_cache
class MambaCache(_BaseCache): class MambaCache(_BaseCache):

View File

@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -93,8 +93,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -74,8 +74,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output) return self.out_proj(output)

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU from .switch_layers import SwitchGLU
@ -97,8 +97,8 @@ class DeepseekAttention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU from .switch_layers import SwitchGLU
@ -235,8 +235,8 @@ class DeepseekV2Attention(nn.Module):
queries = mx.concatenate([q_nope, q_pe], axis=-1) queries = mx.concatenate([q_nope, q_pe], axis=-1)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -79,8 +79,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -61,8 +61,8 @@ class Attention(nn.Module):
if cache is not None: if cache is not None:
keys, values = cache.update_and_fetch(keys, values) keys, values = cache.update_and_fetch(keys, values)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -74,8 +74,8 @@ class Attention(nn.Module):
if cache is not None: if cache is not None:
keys, values = cache.update_and_fetch(keys, values) keys, values = cache.update_and_fetch(keys, values)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output) return self.c_proj(output)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
# Based on the transformers implementation at: # Based on the transformers implementation at:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@ -79,8 +79,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -141,8 +141,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output) return self.wo(output)

View File

@ -6,8 +6,7 @@ from typing import Any, Dict, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention
@dataclass @dataclass
@ -191,19 +190,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
if isinstance(cache, QuantizedKVCache): output = scaled_dot_product_attention(
output = quantized_scaled_dot_product_attention( queries, keys, values, cache=cache, cache=cache, scale=self.scale, mask=mask
queries,
keys,
values,
scale=self.scale,
mask=mask,
group_size=cache.group_size,
bits=cache.bits,
)
else:
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -105,8 +105,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
attn_output = mx.fast.scaled_dot_product_attention( attn_output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1) attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU from .switch_layers import SwitchGLU
@ -87,8 +87,8 @@ class MixtralAttention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -113,8 +113,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -107,8 +107,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ from typing import Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -93,7 +93,7 @@ class PhiAttention(nn.Module):
keys = self.rope(keys) keys = self.rope(keys)
scale = math.sqrt(1 / queries.shape[-1]) scale = math.sqrt(1 / queries.shape[-1])
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask queries.astype(mx.float32), keys, values, scale=scale, mask=mask
).astype(values.dtype) ).astype(values.dtype)

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .su_rope import SuScaledRotaryEmbedding from .su_rope import SuScaledRotaryEmbedding
@ -107,8 +107,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -8,7 +8,7 @@ from typing import Any, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -188,8 +188,8 @@ class Attention(nn.Module):
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, scale=self.scale, mask=mask
) )
else: else:
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(output) return self.dense(output)

View File

@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .su_rope import SuScaledRotaryEmbedding from .su_rope import SuScaledRotaryEmbedding
from .switch_layers import SwitchGLU from .switch_layers import SwitchGLU
@ -79,8 +79,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -8,7 +8,7 @@ from typing import Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import create_attention_mask from .base import create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchMLP from .switch_layers import SwitchMLP
@ -71,7 +71,7 @@ class RoPEAttention(nn.Module):
# Finally perform the attention computation # Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1]) scale = math.sqrt(1 / queries.shape[-1])
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask queries.astype(mx.float32), keys, values, scale=scale, mask=mask
).astype(values.dtype) ).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1) output = output.moveaxis(2, 1).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -92,7 +92,7 @@ class Attention(nn.Module):
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1]) keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1]) values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, queries,
keys, keys,
values, values,

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -64,8 +64,8 @@ class Attention(nn.Module):
queries = self.rotary_emb(queries) queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys) keys = self.rotary_emb(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -6,8 +6,7 @@ from typing import Any, Dict, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention
@dataclass @dataclass
@ -90,19 +89,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
if isinstance(cache, QuantizedKVCache): output = scaled_dot_product_attention(
output = quantized_scaled_dot_product_attention( queries, keys, values, cache=cache, cache=cache, scale=self.scale, mask=mask
queries,
keys,
values,
scale=self.scale,
mask=mask,
group_size=cache.group_size,
bits=cache.bits,
)
else:
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU from .switch_layers import SwitchGLU
@ -89,8 +89,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -7,7 +7,7 @@ from typing import List, Literal, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import MambaCache, RotatingKVCache from .cache import MambaCache, RotatingKVCache
@ -263,8 +263,8 @@ class LocalAttentionBlock(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -6,7 +6,7 @@ from dataclasses import dataclass
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -120,8 +120,8 @@ class Attention(nn.Module):
# Finally perform the attention computation # Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1]) scale = math.sqrt(1 / queries.shape[-1])
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=scale, mask=mask queries, keys, values, cache=cache, scale=scale, mask=mask
).astype(values.dtype) ).astype(values.dtype)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -6,7 +6,7 @@ from typing import Any, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -64,8 +64,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -33,9 +33,6 @@ MODEL_REMAPPING = {
MAX_FILE_SIZE_GB = 5 MAX_FILE_SIZE_GB = 5
DEFAULT_KV_GROUP_SIZE = 64
DEFAULT_KV_BITS = 8
class ModelNotFoundError(Exception): class ModelNotFoundError(Exception):
def __init__(self, message): def __init__(self, message):
@ -162,20 +159,11 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
return logits return logits
def check_quantized_kv_args(quantized_kv_start, kv_group_size, kv_bits):
if not quantized_kv_start and (
kv_group_size != DEFAULT_KV_GROUP_SIZE or kv_bits != DEFAULT_KV_BITS
):
raise ValueError(
"--kv-group-size and --kv-bits only apply when --quantized-kv-start is specified."
)
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
if ( if (
quantized_kv_start kv_bits is not None
and prompt_cache[0].offset > quantized_kv_start
and not isinstance(prompt_cache[0], cache.QuantizedKVCache) and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
and prompt_cache[0].offset > quantized_kv_start
): ):
return [ return [
c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache