Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-08 18:06:37 +08:00
chore(mlx-lm): refactor server.py to utilize generate_step from utils for consistency (#491)
* chore(mlx-lm): refactor server.py to utilize generate_step from utils for consistency
* chore(mlx-lm): update server doc
* chore: remove unused generate func
parent 19a21bfce4
commit 82f3f31d93
Server documentation:

@@ -61,3 +61,5 @@ curl localhost:8080/v1/chat/completions \
 - `top_p`: (Optional) A float specifying the nucleus sampling parameter.
   Defaults to `1.0`.
+- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens. Defaults to `1.0`.
+- `repetition_context_size`: (Optional) The size of the context window for applying repetition penalty. Defaults to `20`.
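As an aside, a request that exercises these parameters could look like the sketch below. This is hypothetical client code: it assumes the server is already running on localhost:8080 as in the curl example above, that the `requests` package is available, and that the response follows the OpenAI-style `choices[0].message.content` shape.

    # Hypothetical client-side sketch; the server must already be running on port 8080.
    import requests

    response = requests.post(
        "http://localhost:8080/v1/chat/completions",
        json={
            "messages": [{"role": "user", "content": "Write a haiku about the ocean."}],
            "max_tokens": 128,
            "temperature": 0.7,
            "top_p": 0.9,                   # nucleus sampling, defaults to 1.0
            "repetition_penalty": 1.1,      # penalty on repeated tokens, defaults to 1.0
            "repetition_context_size": 20,  # window for the penalty, defaults to 20
        },
    )
    print(response.json()["choices"][0]["message"]["content"])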
server.py:

@@ -11,7 +11,7 @@ import mlx.nn as nn
 import numpy as np
 from transformers import PreTrainedTokenizer
 
-from .utils import load
+from .utils import generate_step, load
 
 _model: Optional[nn.Module] = None
 _tokenizer: Optional[PreTrainedTokenizer] = None
@@ -56,50 +56,6 @@ def stopping_criteria(
     return StopCondition(stop_met=False, trim_length=0)
 
 
-def generate(
-    prompt: mx.array,
-    model: nn.Module,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-):
-    def sample(logits):
-        if temp == 0:
-            return mx.argmax(logits, axis=-1)
-        else:
-            if top_p > 0 and top_p < 1.0:
-                if (
-                    logits.dtype == mx.bfloat16
-                ):  # workaround for being unable to load the kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
-                    logits = logits.astype(mx.float32)
-                probs = mx.softmax(logits / temp, axis=-1)
-
-                sorted_probs = mx.sort(probs)[::-1]
-                sorted_indices = mx.argsort(probs)[::-1]
-                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
-
-                top_probs = mx.where(
-                    cumulative_probs > 1 - top_p,
-                    sorted_probs,
-                    mx.zeros_like(sorted_probs),
-                )
-                sorted_tok = mx.random.categorical(mx.log(top_probs))
-                tok = sorted_indices.squeeze(0)[sorted_tok]
-                return tok
-            return mx.random.categorical(logits * (1 / temp))
-
-    y = prompt
-    cache = None
-
-    while True:
-        logits, cache = model(y[None], cache=cache)
-        logits = logits[:, -1, :]
-
-        y = sample(logits)
-        token = y.item()
-
-        yield token
-
-
 def convert_chat(messages: any, role_mapping: Optional[dict] = None):
     default_role_mapping = {
         "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
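The densest part of the deleted helper is its nucleus (top-p) sampling branch. As a reading aid only, here is a minimal standalone sketch of the same idea, not the repository's exact code; the bfloat16 cast in the deleted lines works around a missing cumulative-sum kernel and is omitted here.

    import mlx.core as mx

    def sample_top_p(logits: mx.array, temp: float, top_p: float) -> mx.array:
        # Greedy decoding when temperature is zero, mirroring the removed helper.
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        probs = mx.softmax(logits / temp, axis=-1)
        # Sort probabilities ascending and accumulate them; tokens near the end
        # of the sorted order are the most likely ones.
        sorted_indices = mx.argsort(probs, axis=-1)
        sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
        cumulative = mx.cumsum(sorted_probs, axis=-1)
        # Keep the smallest set of top tokens whose total mass reaches top_p;
        # everything else gets zero probability.
        kept = mx.where(cumulative > 1 - top_p, sorted_probs, mx.zeros_like(sorted_probs))
        # Sample within the kept set, then map back to vocabulary indices.
        choice = mx.random.categorical(mx.log(kept), axis=-1)
        return mx.take_along_axis(sorted_indices, choice[..., None], axis=-1).squeeze(-1)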
@@ -242,18 +198,23 @@ class APIHandler(BaseHTTPRequestHandler):
         max_tokens: int,
         temperature: float,
         top_p: float,
+        repetition_penalty: Optional[float],
+        repetition_context_size: Optional[int],
         response_creator: Callable[[str, str, mx.array, List[int], str], dict],
     ):
         tokens = []
-        for token, _ in zip(
-            generate(
-                prompt,
-                _model,
-                temperature,
+        for (token, _), _ in zip(
+            generate_step(
+                prompt=prompt,
+                model=_model,
+                temp=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                repetition_context_size=repetition_context_size,
             ),
             range(max_tokens),
         ):
             token = token.item()
             tokens.append(token)
             stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
             if stop_condition.stop_met:
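For reference, the same pattern works outside the server. The sketch below is a rough illustration: the model path is only an example, and the keyword arguments mirror the ones the server now passes (`generate_step` yielded `(token, prob)` pairs at the time of this change).

    import mlx.core as mx
    from mlx_lm.utils import generate_step, load

    # Example model path; any model supported by mlx-lm's load() should work.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
    prompt = mx.array(tokenizer.encode("Write a haiku about the ocean."))

    max_tokens = 100
    tokens = []
    # generate_step is an unbounded generator of (token, prob) pairs, so zipping
    # it with range(max_tokens) bounds the loop, exactly as the server does.
    for (token, _), _ in zip(
        generate_step(
            prompt=prompt,
            model=model,
            temp=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            repetition_context_size=20,
        ),
        range(max_tokens),
    ):
        token = token.item()
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token)

    print(tokenizer.decode(tokens))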
@@ -274,6 +235,8 @@ class APIHandler(BaseHTTPRequestHandler):
         max_tokens: int,
         temperature: float,
         top_p: float,
+        repetition_penalty: Optional[float],
+        repetition_context_size: Optional[int],
         response_creator: Callable[[str, str, str], dict],
     ):
         self.send_response(200)
@@ -288,15 +251,18 @@ class APIHandler(BaseHTTPRequestHandler):
         # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
         stop_sequence_buffer = []
         REPLACEMENT_CHAR = "\ufffd"
-        for token, _ in zip(
-            generate(
-                prompt,
-                _model,
-                temperature,
+        for (token, _), _ in zip(
+            generate_step(
+                prompt=prompt,
+                model=_model,
+                temp=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                repetition_context_size=repetition_context_size,
             ),
             range(max_tokens),
         ):
             token = token.item()
             tokens.append(token)
             stop_sequence_buffer.append(token)
             if len(stop_sequence_buffer) > max_stop_id_sequence_len:
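The stream handler keeps a `stop_sequence_buffer` so that a stop sequence arriving token by token is never partially written to the client. Below is a toy, self-contained illustration of that buffering idea; the token IDs and the input stream are made up, and the real server does this inside the generate_step loop.

    # Toy illustration of stop-sequence buffering with plain Python lists.
    stop_id_sequences = [[42, 7]]                      # hypothetical stop sequence
    max_stop_id_sequence_len = max(len(s) for s in stop_id_sequences)

    stop_sequence_buffer = []
    streamed = []                                      # stands in for writing to the wire

    for token in [1, 5, 42, 7, 9]:                     # stands in for generated tokens
        stop_sequence_buffer.append(token)
        # Stop as soon as the buffered tail matches a stop sequence; the buffered
        # tokens are dropped instead of being streamed.
        if any(
            len(stop_sequence_buffer) >= len(s) and stop_sequence_buffer[-len(s):] == s
            for s in stop_id_sequences
        ):
            break
        # Anything older than the longest possible stop sequence can never become
        # part of one, so it is safe to flush.
        if len(stop_sequence_buffer) > max_stop_id_sequence_len:
            streamed.append(stop_sequence_buffer.pop(0))

    # Prints [1]; roughly speaking, the server would then flush the remaining safe
    # token 5 and discard the matched stop sequence 42, 7.
    print(streamed)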
@@ -367,6 +333,8 @@ class APIHandler(BaseHTTPRequestHandler):
         requested_model = body.get("model", "default_model")
         temperature = body.get("temperature", 1.0)
         top_p = body.get("top_p", 1.0)
+        repetition_penalty = body.get("repetition_penalty", 1.0)
+        repetition_context_size = body.get("repetition_context_size", 20)
         if not stream:
             return self.generate_response(
                 prompt,
@@ -377,6 +345,8 @@ class APIHandler(BaseHTTPRequestHandler):
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_chat_response,
             )
         else:
@@ -389,6 +359,8 @@ class APIHandler(BaseHTTPRequestHandler):
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_chat_chunk_response,
             )
 
@@ -412,6 +384,8 @@ class APIHandler(BaseHTTPRequestHandler):
         requested_model = body.get("model", "default_model")
         temperature = body.get("temperature", 1.0)
         top_p = body.get("top_p", 1.0)
+        repetition_penalty = body.get("repetition_penalty", 1.0)
+        repetition_context_size = body.get("repetition_context_size", 20)
         if not stream:
             return self.generate_response(
                 prompt,
@@ -422,6 +396,8 @@ class APIHandler(BaseHTTPRequestHandler):
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_completion_response,
             )
         else:
@@ -434,6 +410,8 @@ class APIHandler(BaseHTTPRequestHandler):
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_completion_chunk_response,
             )
 