single sdpa function

This commit is contained in:
Alex Barron
2024-10-31 12:02:34 -07:00
parent 29f21e7fe4
commit 2e0690374e
31 changed files with 174 additions and 191 deletions

View File

@@ -33,9 +33,6 @@ MODEL_REMAPPING = {
MAX_FILE_SIZE_GB = 5
DEFAULT_KV_GROUP_SIZE = 64
DEFAULT_KV_BITS = 8
class ModelNotFoundError(Exception):
def __init__(self, message):
@@ -162,20 +159,11 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
return logits
def check_quantized_kv_args(quantized_kv_start, kv_group_size, kv_bits):
if not quantized_kv_start and (
kv_group_size != DEFAULT_KV_GROUP_SIZE or kv_bits != DEFAULT_KV_BITS
):
raise ValueError(
"--kv-group-size and --kv-bits only apply when --quantized-kv-start is specified."
)
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
if (
quantized_kv_start
and prompt_cache[0].offset > quantized_kv_start
kv_bits is not None
and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
and prompt_cache[0].offset > quantized_kv_start
):
return [
c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache