mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
single sdpa function
This commit is contained in:
parent
29f21e7fe4
commit
2e0690374e
@ -8,13 +8,9 @@ import time
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from .models.cache import make_prompt_cache, save_prompt_cache
|
from .models.cache import make_prompt_cache, save_prompt_cache
|
||||||
from .utils import (
|
from .utils import load, maybe_quantize_kv_cache
|
||||||
DEFAULT_KV_BITS,
|
|
||||||
DEFAULT_KV_GROUP_SIZE,
|
DEFAULT_QUANTIZED_KV_START = 5000
|
||||||
check_quantized_kv_args,
|
|
||||||
load,
|
|
||||||
maybe_quantize_kv_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_arg_parser():
|
def setup_arg_parser():
|
||||||
@ -77,22 +73,24 @@ def setup_arg_parser():
|
|||||||
help="Message to be processed by the model ('-' reads from stdin)",
|
help="Message to be processed by the model ('-' reads from stdin)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--quantized-kv-start",
|
"--kv-bits",
|
||||||
help="Use a quantized KV cache from this step onwards.",
|
|
||||||
type=int,
|
type=int,
|
||||||
|
help="Number of bits for KV cache quantization. "
|
||||||
|
"Defaults to no quantization.",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-group-size",
|
"--kv-group-size",
|
||||||
type=int,
|
type=int,
|
||||||
help="Group size for kv cache quantization.",
|
help="Group size for KV cache quantization.",
|
||||||
default=DEFAULT_KV_GROUP_SIZE,
|
default=64,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-bits",
|
"--quantized-kv-start",
|
||||||
|
help="When --kv-bits is set, start quantizing the KV cache "
|
||||||
|
"from this step onwards.",
|
||||||
type=int,
|
type=int,
|
||||||
help="Number of bits for kv cache quantization.",
|
default=DEFAULT_QUANTIZED_KV_START,
|
||||||
default=DEFAULT_KV_BITS,
|
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -117,8 +115,6 @@ def main():
|
|||||||
|
|
||||||
args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
|
args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
|
||||||
|
|
||||||
check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits)
|
|
||||||
|
|
||||||
if args.use_default_chat_template:
|
if args.use_default_chat_template:
|
||||||
if tokenizer.chat_template is None:
|
if tokenizer.chat_template is None:
|
||||||
tokenizer.chat_template = tokenizer.default_chat_template
|
tokenizer.chat_template = tokenizer.default_chat_template
|
||||||
|
@ -7,7 +7,7 @@ import sys
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from .models.cache import QuantizedKVCache, load_prompt_cache
|
from .models.cache import QuantizedKVCache, load_prompt_cache
|
||||||
from .utils import check_quantized_kv_args, generate, load
|
from .utils import generate, load
|
||||||
|
|
||||||
DEFAULT_PROMPT = "hello"
|
DEFAULT_PROMPT = "hello"
|
||||||
DEFAULT_MAX_TOKENS = 100
|
DEFAULT_MAX_TOKENS = 100
|
||||||
@ -15,6 +15,7 @@ DEFAULT_TEMP = 0.0
|
|||||||
DEFAULT_TOP_P = 1.0
|
DEFAULT_TOP_P = 1.0
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = 0
|
||||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
|
DEFAULT_QUANTIZED_KV_START = 5000
|
||||||
|
|
||||||
|
|
||||||
def str2bool(string):
|
def str2bool(string):
|
||||||
@ -108,24 +109,24 @@ def setup_arg_parser():
|
|||||||
help="A file containing saved KV caches to avoid recomputing them",
|
help="A file containing saved KV caches to avoid recomputing them",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--quantized-kv-start",
|
"--kv-bits",
|
||||||
help="Use a quantized KV cache from this step onwards.",
|
|
||||||
type=int,
|
type=int,
|
||||||
|
help="Number of bits for KV cache quantization. "
|
||||||
|
"Defaults to no quantization.",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-group-size",
|
"--kv-group-size",
|
||||||
type=int,
|
type=int,
|
||||||
help="Group size for kv cache quantization. "
|
help="Group size for KV cache quantization.",
|
||||||
"--quantized-kv-start must be provided to have an effect.",
|
|
||||||
default=64,
|
default=64,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-bits",
|
"--quantized-kv-start",
|
||||||
|
help="When --kv-bits is set, start quantizing the KV cache "
|
||||||
|
"from this step onwards.",
|
||||||
type=int,
|
type=int,
|
||||||
help="Number of bits for kv cache quantization. "
|
default=DEFAULT_QUANTIZED_KV_START,
|
||||||
"--quantized-kv-start must be provided to have an effect.",
|
|
||||||
default=8,
|
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -173,13 +174,15 @@ def main():
|
|||||||
args.prompt_cache_file,
|
args.prompt_cache_file,
|
||||||
return_metadata=True,
|
return_metadata=True,
|
||||||
)
|
)
|
||||||
if args.quantized_kv_start and isinstance(prompt_cache[0], QuantizedKVCache):
|
if isinstance(prompt_cache[0], QuantizedKVCache):
|
||||||
raise ValueError(
|
if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
|
||||||
"Specified `--quantized-kv-start` but cache from "
|
raise ValueError(
|
||||||
"`--prompt-cache-file` is already quantized."
|
"--kv-bits does not match the kv cache loaded from --prompt-cache-file."
|
||||||
)
|
)
|
||||||
|
if args.kv_group_size != prompt_cache[0].group_size:
|
||||||
check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits)
|
raise ValueError(
|
||||||
|
"--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
|
||||||
|
)
|
||||||
|
|
||||||
# Building tokenizer_config
|
# Building tokenizer_config
|
||||||
tokenizer_config = (
|
tokenizer_config = (
|
||||||
|
@ -5,6 +5,9 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
from mlx.utils import tree_map
|
||||||
|
|
||||||
|
from .cache import QuantizedKVCache
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -48,3 +51,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
|||||||
else:
|
else:
|
||||||
mask = None
|
mask = None
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def quantized_scaled_dot_product_attention(
|
||||||
|
queries: mx.array,
|
||||||
|
q_keys: tuple[mx.array, mx.array, mx.array],
|
||||||
|
q_values: tuple[mx.array, mx.array, mx.array],
|
||||||
|
scale: float,
|
||||||
|
mask: Optional[mx.array],
|
||||||
|
group_size: int = 64,
|
||||||
|
bits: int = 8,
|
||||||
|
) -> mx.array:
|
||||||
|
B, n_q_heads, L, D = queries.shape
|
||||||
|
n_kv_heads = q_keys[0].shape[-3]
|
||||||
|
n_repeats = n_q_heads // n_kv_heads
|
||||||
|
|
||||||
|
queries *= scale
|
||||||
|
|
||||||
|
if n_repeats > 1:
|
||||||
|
queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
|
||||||
|
q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
|
||||||
|
q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
|
||||||
|
|
||||||
|
scores = mx.quantized_matmul(
|
||||||
|
queries, *q_keys, transpose=True, group_size=group_size, bits=bits
|
||||||
|
)
|
||||||
|
if mask is not None:
|
||||||
|
scores += mask
|
||||||
|
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||||
|
out = mx.quantized_matmul(
|
||||||
|
scores, *q_values, transpose=False, group_size=group_size, bits=bits
|
||||||
|
)
|
||||||
|
|
||||||
|
if n_repeats > 1:
|
||||||
|
out = mx.reshape(out, (B, n_q_heads, L, D))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(
|
||||||
|
queries,
|
||||||
|
keys,
|
||||||
|
values,
|
||||||
|
cache,
|
||||||
|
scale: float,
|
||||||
|
mask: Optional[mx.array],
|
||||||
|
) -> mx.array:
|
||||||
|
if isinstance(cache, QuantizedKVCache):
|
||||||
|
return quantized_scaled_dot_product_attention(
|
||||||
|
queries,
|
||||||
|
keys,
|
||||||
|
values,
|
||||||
|
scale=scale,
|
||||||
|
mask=mask,
|
||||||
|
group_size=cache.group_size,
|
||||||
|
bits=cache.bits,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return mx.fast.scaled_dot_product_attention(
|
||||||
|
queries, keys, values, scale=scale, mask=mask
|
||||||
|
)
|
||||||
|
@ -129,42 +129,6 @@ class _BaseCache:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def quantized_scaled_dot_product_attention(
|
|
||||||
queries: mx.array,
|
|
||||||
q_keys: tuple[mx.array, mx.array, mx.array],
|
|
||||||
q_values: tuple[mx.array, mx.array, mx.array],
|
|
||||||
scale: float,
|
|
||||||
mask: Optional[mx.array],
|
|
||||||
group_size: int = 64,
|
|
||||||
bits: int = 8,
|
|
||||||
) -> mx.array:
|
|
||||||
B, n_q_heads, L, D = queries.shape
|
|
||||||
n_kv_heads = q_keys[0].shape[-3]
|
|
||||||
n_repeats = n_q_heads // n_kv_heads
|
|
||||||
|
|
||||||
queries *= scale
|
|
||||||
|
|
||||||
if n_repeats > 1:
|
|
||||||
queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
|
|
||||||
q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
|
|
||||||
q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
|
|
||||||
|
|
||||||
scores = mx.quantized_matmul(
|
|
||||||
queries, *q_keys, transpose=True, group_size=group_size, bits=bits
|
|
||||||
)
|
|
||||||
if mask is not None:
|
|
||||||
scores += mask
|
|
||||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
|
||||||
out = mx.quantized_matmul(
|
|
||||||
scores, *q_values, transpose=False, group_size=group_size, bits=bits
|
|
||||||
)
|
|
||||||
|
|
||||||
if n_repeats > 1:
|
|
||||||
out = mx.reshape(out, (B, n_q_heads, L, D))
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class QuantizedKVCache(_BaseCache):
|
class QuantizedKVCache(_BaseCache):
|
||||||
def __init__(self, group_size: int = 64, bits: int = 8):
|
def __init__(self, group_size: int = 64, bits: int = 8):
|
||||||
self.keys = None
|
self.keys = None
|
||||||
@ -452,14 +416,7 @@ class RotatingKVCache(_BaseCache):
|
|||||||
return n
|
return n
|
||||||
|
|
||||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||||
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
raise NotImplementedError("RotatingKVCache Quantization NYI")
|
||||||
quant_cache.offset = self.offset
|
|
||||||
if self.keys is not None:
|
|
||||||
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
|
|
||||||
quant_cache.values = mx.quantize(
|
|
||||||
self.values, group_size=group_size, bits=bits
|
|
||||||
)
|
|
||||||
return quant_cache
|
|
||||||
|
|
||||||
|
|
||||||
class MambaCache(_BaseCache):
|
class MambaCache(_BaseCache):
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -93,8 +93,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -74,8 +74,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.out_proj(output)
|
return self.out_proj(output)
|
||||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, Optional
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
|
|
||||||
@ -97,8 +97,8 @@ class DeepseekAttention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
|
|
||||||
@ -235,8 +235,8 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
|
|
||||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -79,8 +79,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -61,8 +61,8 @@ class Attention(nn.Module):
|
|||||||
if cache is not None:
|
if cache is not None:
|
||||||
keys, values = cache.update_and_fetch(keys, values)
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -74,8 +74,8 @@ class Attention(nn.Module):
|
|||||||
if cache is not None:
|
if cache is not None:
|
||||||
keys, values = cache.update_and_fetch(keys, values)
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.c_proj(output)
|
return self.c_proj(output)
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
# Based on the transformers implementation at:
|
# Based on the transformers implementation at:
|
||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||||
@ -79,8 +79,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -141,8 +141,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.wo(output)
|
return self.wo(output)
|
||||||
|
@ -6,8 +6,7 @@ from typing import Any, Dict, Optional, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -191,20 +190,9 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
if isinstance(cache, QuantizedKVCache):
|
output = scaled_dot_product_attention(
|
||||||
output = quantized_scaled_dot_product_attention(
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
queries,
|
)
|
||||||
keys,
|
|
||||||
values,
|
|
||||||
scale=self.scale,
|
|
||||||
mask=mask,
|
|
||||||
group_size=cache.group_size,
|
|
||||||
bits=cache.bits,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
|
||||||
)
|
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -105,8 +105,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
attn_output = mx.fast.scaled_dot_product_attention(
|
attn_output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
|
|
||||||
@ -87,8 +87,8 @@ class MixtralAttention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -113,8 +113,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -107,8 +107,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ from typing import Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -93,7 +93,7 @@ class PhiAttention(nn.Module):
|
|||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
|
queries.astype(mx.float32), keys, values, cache=cache, scale=scale, mask=mask
|
||||||
).astype(values.dtype)
|
).astype(values.dtype)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .su_rope import SuScaledRotaryEmbedding
|
from .su_rope import SuScaledRotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
@ -107,8 +107,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -8,7 +8,7 @@ from typing import Any, Optional
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -188,8 +188,8 @@ class Attention(nn.Module):
|
|||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.dense(output)
|
return self.dense(output)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .su_rope import SuScaledRotaryEmbedding
|
from .su_rope import SuScaledRotaryEmbedding
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
@ -79,8 +79,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -8,7 +8,7 @@ from typing import Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import create_attention_mask
|
from .base import create_attention_mask, scaled_dot_product_attention
|
||||||
from .switch_layers import SwitchMLP
|
from .switch_layers import SwitchMLP
|
||||||
|
|
||||||
|
|
||||||
@ -71,7 +71,7 @@ class RoPEAttention(nn.Module):
|
|||||||
# Finally perform the attention computation
|
# Finally perform the attention computation
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
|
queries.astype(mx.float32), keys, values, cache=cache, scale=scale, mask=mask
|
||||||
).astype(values.dtype)
|
).astype(values.dtype)
|
||||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -92,7 +92,7 @@ class Attention(nn.Module):
|
|||||||
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
|
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
|
||||||
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
|
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries,
|
queries,
|
||||||
keys,
|
keys,
|
||||||
values,
|
values,
|
||||||
|
@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -64,8 +64,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rotary_emb(queries)
|
queries = self.rotary_emb(queries)
|
||||||
keys = self.rotary_emb(keys)
|
keys = self.rotary_emb(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
|
@ -6,8 +6,7 @@ from typing import Any, Dict, Optional, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -90,20 +89,9 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
if isinstance(cache, QuantizedKVCache):
|
output = scaled_dot_product_attention(
|
||||||
output = quantized_scaled_dot_product_attention(
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
queries,
|
)
|
||||||
keys,
|
|
||||||
values,
|
|
||||||
scale=self.scale,
|
|
||||||
mask=mask,
|
|
||||||
group_size=cache.group_size,
|
|
||||||
bits=cache.bits,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
|
||||||
)
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
|
|
||||||
@ -89,8 +89,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -7,7 +7,7 @@ from typing import List, Literal, Optional
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .cache import MambaCache, RotatingKVCache
|
from .cache import MambaCache, RotatingKVCache
|
||||||
|
|
||||||
|
|
||||||
@ -263,8 +263,8 @@ class LocalAttentionBlock(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -6,7 +6,7 @@ from dataclasses import dataclass
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -120,8 +120,8 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
# Finally perform the attention computation
|
# Finally perform the attention computation
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=scale, mask=mask
|
queries, keys, values, cache=cache, scale=scale, mask=mask
|
||||||
).astype(values.dtype)
|
).astype(values.dtype)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -64,8 +64,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -33,9 +33,6 @@ MODEL_REMAPPING = {
|
|||||||
|
|
||||||
MAX_FILE_SIZE_GB = 5
|
MAX_FILE_SIZE_GB = 5
|
||||||
|
|
||||||
DEFAULT_KV_GROUP_SIZE = 64
|
|
||||||
DEFAULT_KV_BITS = 8
|
|
||||||
|
|
||||||
|
|
||||||
class ModelNotFoundError(Exception):
|
class ModelNotFoundError(Exception):
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
@ -162,20 +159,11 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def check_quantized_kv_args(quantized_kv_start, kv_group_size, kv_bits):
|
|
||||||
if not quantized_kv_start and (
|
|
||||||
kv_group_size != DEFAULT_KV_GROUP_SIZE or kv_bits != DEFAULT_KV_BITS
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"--kv-group-size and --kv-bits only apply when --quantized-kv-start is specified."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
|
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
|
||||||
if (
|
if (
|
||||||
quantized_kv_start
|
kv_bits is not None
|
||||||
and prompt_cache[0].offset > quantized_kv_start
|
|
||||||
and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
|
and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
|
||||||
|
and prompt_cache[0].offset > quantized_kv_start
|
||||||
):
|
):
|
||||||
return [
|
return [
|
||||||
c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache
|
c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache
|
||||||
|
Loading…
Reference in New Issue
Block a user