mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 21:01:32 +08:00
single sdpa function
This commit is contained in:
@@ -33,9 +33,6 @@ MODEL_REMAPPING = {
|
||||
|
||||
MAX_FILE_SIZE_GB = 5
|
||||
|
||||
DEFAULT_KV_GROUP_SIZE = 64
|
||||
DEFAULT_KV_BITS = 8
|
||||
|
||||
|
||||
class ModelNotFoundError(Exception):
|
||||
def __init__(self, message):
|
||||
@@ -162,20 +159,11 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
|
||||
return logits
|
||||
|
||||
|
||||
def check_quantized_kv_args(quantized_kv_start, kv_group_size, kv_bits):
|
||||
if not quantized_kv_start and (
|
||||
kv_group_size != DEFAULT_KV_GROUP_SIZE or kv_bits != DEFAULT_KV_BITS
|
||||
):
|
||||
raise ValueError(
|
||||
"--kv-group-size and --kv-bits only apply when --quantized-kv-start is specified."
|
||||
)
|
||||
|
||||
|
||||
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
|
||||
if (
|
||||
quantized_kv_start
|
||||
and prompt_cache[0].offset > quantized_kv_start
|
||||
kv_bits is not None
|
||||
and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
|
||||
and prompt_cache[0].offset > quantized_kv_start
|
||||
):
|
||||
return [
|
||||
c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache
|
||||
|
Reference in New Issue
Block a user