Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 12:49:50 +08:00
simplify
@@ -33,6 +33,9 @@ MODEL_REMAPPING = {

MAX_FILE_SIZE_GB = 5

DEFAULT_KV_GROUP_SIZE = 64
DEFAULT_KV_BITS = 8


class ModelNotFoundError(Exception):
    def __init__(self, message):
@@ -159,6 +162,27 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
    return logits


def check_quantized_kv_args(quantized_kv_start, kv_group_size, kv_bits):
    if not quantized_kv_start and (
        kv_group_size != DEFAULT_KV_GROUP_SIZE or kv_bits != DEFAULT_KV_BITS
    ):
        raise ValueError(
            "--kv-group-size and --kv-bits only apply when --quantized-kv-start is specified."
        )


def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    if (
        quantized_kv_start
        and prompt_cache[0].offset > quantized_kv_start
        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
    ):
        return [
            c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache
        ]
    return prompt_cache


def generate_step(
    prompt: mx.array,
    model: nn.Module,
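
A note on the two helpers added in the hunk above: check_quantized_kv_args rejects non-default --kv-group-size/--kv-bits values when --quantized-kv-start is not set, and maybe_quantize_kv_cache lazily converts each layer's cache to a QuantizedKVCache once the cached offset passes the threshold, leaving an already-quantized cache alone. Below is a minimal sketch of that gating using stand-in classes (FakeCache and this QuantizedKVCache stub are illustrations, not the mlx_lm API):

# Sketch of the lazy-conversion gate; FakeCache stands in for the real
# mlx_lm KV cache classes and is NOT the library API.
class QuantizedKVCache:
    pass


class FakeCache:
    def __init__(self, offset):
        self.offset = offset  # number of tokens already cached

    def to_quantized(self, group_size, bits):
        return QuantizedKVCache()


def maybe_quantize(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    # Same condition as the diff: convert only once the cache has grown past
    # the threshold, and only if it is not already quantized.
    if (
        quantized_kv_start
        and prompt_cache[0].offset > quantized_kv_start
        and not isinstance(prompt_cache[0], QuantizedKVCache)
    ):
        return [
            c.to_quantized(group_size=kv_group_size, bits=kv_bits)
            for c in prompt_cache
        ]
    return prompt_cache


caches = [FakeCache(offset=256)]
caches = maybe_quantize(caches, 512, 64, 8)  # offset <= 512: unchanged
assert isinstance(caches[0], FakeCache)
caches[0].offset = 513
caches = maybe_quantize(caches, 512, 64, 8)  # threshold crossed: converted once
assert isinstance(caches[0], QuantizedKVCache)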
@@ -173,7 +197,7 @@ def generate_step(
    prompt_cache: Optional[Any] = None,
    logit_bias: Optional[Dict[int, float]] = None,
    logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    quantized_kv: bool = False,
    quantized_kv_start: Optional[int] = None,
    kv_group_size: int = 64,
    kv_bits: int = 8,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
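
With the new signature, callers opt in by passing quantized_kv_start, with kv_group_size and kv_bits defaulting to 64 and 8. A hedged usage sketch; the model id is illustrative and the imports assume the installed mlx_lm package layout:

import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

# Illustrative model id; any mlx-community model works the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt = mx.array(tokenizer.encode("Write a haiku about caching."))

# Keep the first 512 cached tokens unquantized, then switch the KV cache
# to 8-bit with group size 64 (the defaults introduced in this commit).
for (token, _logprobs), _ in zip(
    generate_step(prompt, model, quantized_kv_start=512), range(100)
):
    # No EOS handling; sketch only.
    print(tokenizer.decode([token.item()]), end="", flush=True)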
@@ -261,14 +285,13 @@ def generate_step(
        prompt_cache = cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
            quantized_kv=quantized_kv,
            kv_group_size=kv_group_size,
            kv_bits=kv_bits,
        )
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    def _step(y):
        nonlocal prompt_cache
        logits = model(y[None], cache=prompt_cache)
        logits = logits[:, -1, :]

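
One subtlety in the hunk above: maybe_quantize_kv_cache (called from _step in the next hunk) returns a new list of caches rather than mutating the old one, so _step needs nonlocal prompt_cache to rebind the name in generate_step's scope. A generic illustration of that closure rule, unrelated to mlx itself:

def make_appender():
    items = []

    def append(x):
        # Without 'nonlocal', the assignment below would create a new local
        # variable instead of rebinding the enclosing 'items'.
        nonlocal items
        items = items + [x]  # rebinds, like prompt_cache = maybe_quantize_kv_cache(...)
        return items

    return append

append = make_appender()
assert append(1) == [1]
assert append(2) == [1, 2]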
@@ -279,6 +302,10 @@ def generate_step(
            for processor in logits_processor:
                logits = processor(tokens, logits)

        prompt_cache = maybe_quantize_kv_cache(
            prompt_cache, quantized_kv_start, kv_group_size, kv_bits
        )

        y, logprobs = sample(logits)
        return y, logprobs.squeeze(0)