Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
[mlx-lm] Use top p in server (#1144)
* use top p in server
* couple other fixes
Parent: 19abf3dcaa
Commit: 2ba0e36683
@@ -190,7 +190,7 @@ def make_repetition_penalty(penalty: float, context_size: int = 20):
         Callable[[mx.array, List[int]], mx.array]:
             The repetition penalty processor.
     """
-    if penalty < 0 or not isinstance(penalty, float):
+    if penalty < 0 or not isinstance(penalty, (int, float)):
         raise ValueError(f"penalty must be a non-negative float, got {penalty}")

     def repetition_penalty_processor(tokens, logits):
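
This hunk relaxes the type guard: the old check rejected a repetition penalty passed as a Python int (e.g. a request body with "repetition_penalty": 1 deserializes to an int), even though an integer is a valid penalty value. A minimal illustration of the distinction (not part of the patch):

    # isinstance with a tuple of types accepts any of them.
    penalty = 1  # int, e.g. parsed from JSON `"repetition_penalty": 1`
    print(isinstance(penalty, float))         # False -> old check raised ValueError
    print(isinstance(penalty, (int, float)))  # True  -> new check accepts it
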
@@ -465,7 +465,7 @@ class APIHandler(BaseHTTPRequestHandler):

         text = ""
         tic = time.perf_counter()
-        sampler = make_sampler(self.temperature)
+        sampler = make_sampler(self.temperature, top_p=self.top_p)
         logits_processors = make_logits_processors(
             self.logit_bias, self.repetition_penalty, self.repetition_context_size
         )
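
This hunk is the headline change: the server's APIHandler was building its sampler from temperature alone, so the top_p value it had already parsed from the request was silently ignored. For readers unfamiliar with nucleus (top-p) sampling, here is a minimal, self-contained sketch in NumPy; it is illustrative only, since mlx_lm's make_sampler operates on mlx logits rather than this toy probability vector:

    import numpy as np

    def top_p_sample(probs, top_p, rng=np.random.default_rng()):
        """Sample from the smallest set of tokens whose cumulative
        probability reaches top_p (nucleus sampling)."""
        order = np.argsort(probs)[::-1]                  # most probable first
        cumulative = np.cumsum(probs[order])
        cutoff = np.searchsorted(cumulative, top_p) + 1  # keep at least one token
        keep = order[:cutoff]
        renormalized = probs[keep] / probs[keep].sum()
        return int(rng.choice(keep, p=renormalized))

    probs = np.array([0.5, 0.3, 0.1, 0.1])
    print(top_p_sample(probs, top_p=0.8))  # only tokens 0 and 1 can be drawn

Low top_p concentrates sampling on the head of the distribution; top_p=1.0 keeps every token, which is why passing the parameter through matters only when a client sets it below 1.
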
@@ -299,6 +299,9 @@ def generate_step(
         prompt_processed_tokens = 0
         while y.size > prefill_step_size:
             model(y[:prefill_step_size][None], cache=prompt_cache)
+            maybe_quantize_kv_cache(
+                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+            )
             mx.eval([c.state for c in prompt_cache])
             prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
             prompt_processed_tokens += prefill_step_size
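
The remaining fix touches chunked prefill: the loop processes the prompt in prefill_step_size chunks but was evaluating the cache after each chunk without applying the KV-quantization policy, so a long prompt could hold a full-precision cache through the entire prefill. The added call quantizes the cache as soon as it passes the configured start offset. The sketch below conveys only the gating idea; the attribute names (offset, to_quantized) and the exact behavior of mlx_lm's real maybe_quantize_kv_cache are assumptions to be checked against the source:

    # Hypothetical sketch of the gating idea (NOT mlx_lm's implementation):
    # convert each layer's KV cache to a quantized representation once it
    # has grown past `quantized_kv_start` tokens; do nothing when
    # quantization is disabled (kv_bits is None).

    def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
        if kv_bits is None:
            return  # KV-cache quantization disabled
        for i, layer_cache in enumerate(prompt_cache):
            # `offset` = tokens stored so far; `to_quantized` = assumed converter.
            already_quantized = getattr(layer_cache, "bits", None) is not None
            if not already_quantized and layer_cache.offset > quantized_kv_start:
                prompt_cache[i] = layer_cache.to_quantized(
                    group_size=kv_group_size, bits=kv_bits
                )
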