This commit is contained in:
Alex Barron
2024-10-28 16:03:43 -07:00
parent 48655a7f83
commit 37a3723823
6 changed files with 197 additions and 90 deletions

View File

@@ -33,6 +33,9 @@ MODEL_REMAPPING = {
MAX_FILE_SIZE_GB = 5
DEFAULT_KV_GROUP_SIZE = 64
DEFAULT_KV_BITS = 8
class ModelNotFoundError(Exception):
def __init__(self, message):
@@ -159,6 +162,27 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
return logits
def check_quantized_kv_args(quantized_kv_start, kv_group_size, kv_bits):
if not quantized_kv_start and (
kv_group_size != DEFAULT_KV_GROUP_SIZE or kv_bits != DEFAULT_KV_BITS
):
raise ValueError(
"--kv-group-size and --kv-bits only apply when --quantized-kv-start is specified."
)
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
if (
quantized_kv_start
and prompt_cache[0].offset > quantized_kv_start
and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
):
return [
c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache
]
return prompt_cache
def generate_step(
prompt: mx.array,
model: nn.Module,
@@ -173,7 +197,7 @@ def generate_step(
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
quantized_kv: bool = False,
quantized_kv_start: Optional[int] = None,
kv_group_size: int = 64,
kv_bits: int = 8,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -261,14 +285,13 @@ def generate_step(
prompt_cache = cache.make_prompt_cache(
model,
max_kv_size=max_kv_size,
quantized_kv=quantized_kv,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
def _step(y):
nonlocal prompt_cache
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
@@ -279,6 +302,10 @@ def generate_step(
for processor in logits_processor:
logits = processor(tokens, logits)
prompt_cache = maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
y, logprobs = sample(logits)
return y, logprobs.squeeze(0)