[mlx-lm] Use top p in server (#1144)

* use top p in server

* couple other fixes
Awni Hannun 2024-12-12 11:12:21 -08:00 committed by GitHub
parent 19abf3dcaa
commit 2ba0e36683
3 changed files with 5 additions and 2 deletions


@@ -190,7 +190,7 @@ def make_repetition_penalty(penalty: float, context_size: int = 20):
         Callable[[mx.array, List[int]], mx.array]:
             The repetition penalty processor.
     """
-    if penalty < 0 or not isinstance(penalty, float):
+    if penalty < 0 or not isinstance(penalty, (int, float)):
         raise ValueError(f"penalty must be a non-negative float, got {penalty}")
 
     def repetition_penalty_processor(tokens, logits):

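The relaxed check above matters because an integer penalty (for example a request body containing "repetition_penalty": 1) deserializes to a Python int, which the old float-only isinstance test rejected. A standalone sketch of the check in isolation (hypothetical helper name, not the library code):

    # Hypothetical standalone sketch of the validation change, not mlx_lm's code.
    def check_penalty(penalty):
        # Old check: isinstance(penalty, float) rejected ints such as 1.
        # New check: any non-negative int or float is accepted.
        if penalty < 0 or not isinstance(penalty, (int, float)):
            raise ValueError(f"penalty must be a non-negative float, got {penalty}")

    check_penalty(1)     # accepted now; raised ValueError with the old check
    check_penalty(1.1)   # accepted by both
    check_penalty(-1.0)  # still raises ValueError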

@@ -465,7 +465,7 @@ class APIHandler(BaseHTTPRequestHandler):
         text = ""
         tic = time.perf_counter()
-        sampler = make_sampler(self.temperature)
+        sampler = make_sampler(self.temperature, top_p=self.top_p)
         logits_processors = make_logits_processors(
             self.logit_bias, self.repetition_penalty, self.repetition_context_size
         )

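This wires the request's top_p through to the sampler, so nucleus sampling is no longer silently ignored by the server. For illustration, a sketch of what top-p filtering does, written against mlx.core; it follows the standard nucleus-sampling formulation and is not necessarily mlx_lm's exact implementation (assumes temp > 0):

    import mlx.core as mx

    def top_p_sample(logits: mx.array, top_p: float, temp: float) -> mx.array:
        # Convert logits to probabilities at the given temperature.
        probs = mx.softmax(logits / temp, axis=-1)
        # Sort probabilities in ascending order and accumulate them.
        sorted_idx = mx.argsort(probs, axis=-1)
        sorted_probs = mx.take_along_axis(probs, sorted_idx, axis=-1)
        cum = mx.cumsum(sorted_probs, axis=-1)
        # Keep only the top tokens whose cumulative mass exceeds 1 - top_p.
        nucleus = mx.where(cum > 1 - top_p, sorted_probs, mx.zeros_like(sorted_probs))
        # Sample from the (implicitly renormalized) nucleus and map back to vocab ids.
        choice = mx.random.categorical(mx.log(nucleus), axis=-1)
        return mx.take_along_axis(sorted_idx, choice[..., None], axis=-1).squeeze(-1)

In the server path, make_sampler(self.temperature, top_p=self.top_p) returns a sampling callable that the generation loop then applies to the model's logits.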

@@ -299,6 +299,9 @@ def generate_step(
         prompt_processed_tokens = 0
         while y.size > prefill_step_size:
             model(y[:prefill_step_size][None], cache=prompt_cache)
+            maybe_quantize_kv_cache(
+                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+            )
             mx.eval([c.state for c in prompt_cache])
             prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
             prompt_processed_tokens += prefill_step_size
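
The third hunk, one of the "couple other fixes", calls maybe_quantize_kv_cache inside the chunked prefill loop, so a long prompt's KV cache can be moved to a quantized representation while the prompt is still being processed rather than only afterwards. A rough sketch of what such a helper does, assuming (as in mlx_lm's cache module) that cache objects expose an offset and a to_quantized() conversion; the real function may differ in detail:

    from mlx_lm.models.cache import QuantizedKVCache  # assumed import path

    def maybe_quantize_kv_cache_sketch(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
        # No-op unless KV cache quantization was requested.
        if kv_bits is None:
            return
        for i, c in enumerate(prompt_cache):
            # Once a layer's cache holds more than `quantized_kv_start` tokens,
            # convert it in place to a quantized cache to reduce memory use.
            if not isinstance(c, QuantizedKVCache) and c.offset > quantized_kv_start:
                prompt_cache[i] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)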