chore(mlx-lm): refactor server.py to utilize generate_step from utils for consistency (#491)

* chore(mlx-lm): refactor server.py to utilize generate_step from utils for consistency

* chore(mlx-lm): update server doc

* chore: remove unused generate func
Anchen 2024-02-28 01:25:24 +11:00 committed by GitHub
parent 19a21bfce4
commit 82f3f31d93
2 changed files with 35 additions and 55 deletions


@@ -61,3 +61,5 @@ curl localhost:8080/v1/chat/completions \
 - `top_p`: (Optional) A float specifying the nucleus sampling parameter.
   Defaults to `1.0`.
+- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens. Defaults to `1.0`.
+- `repetition_context_size`: (Optional) The size of the context window for applying the repetition penalty. Defaults to `20`.
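For reference, a request that exercises the two new options might look like the sketch below (Python with the third-party `requests` package; the endpoint and field names come from the documentation above, while the prompt content and chosen values are arbitrary):

```python
import requests

# Illustrative request against a locally running mlx_lm server (see the curl
# example in the server doc). The two new fields override their defaults.
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Write a short haiku about the sea."}],
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.3,      # values above 1.0 discourage repeated tokens
        "repetition_context_size": 20,  # how many recent tokens the penalty considers
    },
)
print(resp.json())
```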


@@ -11,7 +11,7 @@ import mlx.nn as nn
 import numpy as np
 from transformers import PreTrainedTokenizer
 
-from .utils import load
+from .utils import generate_step, load
 
 _model: Optional[nn.Module] = None
 _tokenizer: Optional[PreTrainedTokenizer] = None
@@ -56,50 +56,6 @@ def stopping_criteria(
     return StopCondition(stop_met=False, trim_length=0)
-def generate(
-    prompt: mx.array,
-    model: nn.Module,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-):
-    def sample(logits):
-        if temp == 0:
-            return mx.argmax(logits, axis=-1)
-        else:
-            if top_p > 0 and top_p < 1.0:
-                if (
-                    logits.dtype == mx.bfloat16
-                ):  # workaround for being unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
-                    logits = logits.astype(mx.float32)
-                probs = mx.softmax(logits / temp, axis=-1)
-                sorted_probs = mx.sort(probs)[::-1]
-                sorted_indices = mx.argsort(probs)[::-1]
-                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
-                top_probs = mx.where(
-                    cumulative_probs > 1 - top_p,
-                    sorted_probs,
-                    mx.zeros_like(sorted_probs),
-                )
-                sorted_tok = mx.random.categorical(mx.log(top_probs))
-                tok = sorted_indices.squeeze(0)[sorted_tok]
-                return tok
-            return mx.random.categorical(logits * (1 / temp))
-    y = prompt
-    cache = None
-    while True:
-        logits, cache = model(y[None], cache=cache)
-        logits = logits[:, -1, :]
-        y = sample(logits)
-        token = y.item()
-        yield token
 def convert_chat(messages: any, role_mapping: Optional[dict] = None):
     default_role_mapping = {
         "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
@@ -242,18 +198,23 @@ class APIHandler(BaseHTTPRequestHandler):
         max_tokens: int,
         temperature: float,
         top_p: float,
+        repetition_penalty: Optional[float],
+        repetition_context_size: Optional[int],
         response_creator: Callable[[str, str, mx.array, List[int], str], dict],
     ):
         tokens = []
-        for token, _ in zip(
-            generate(
-                prompt,
-                _model,
-                temperature,
+        for (token, _), _ in zip(
+            generate_step(
+                prompt=prompt,
+                model=_model,
+                temp=temperature,
                 top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                repetition_context_size=repetition_context_size,
             ),
             range(max_tokens),
         ):
+            token = token.item()
             tokens.append(token)
             stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
             if stop_condition.stop_met:
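The rewritten loop reflects the `generate_step` interface: it yields `(token, probability)` pairs rather than bare tokens (the second element is unused here), each token arrives as a scalar `mx.array` (hence the added `token.item()`), and zipping the generator against `range(max_tokens)` caps an otherwise unbounded stream. A minimal sketch of that pattern, with a hypothetical stand-in generator in place of `generate_step`:

```python
def fake_generate_step():
    """Hypothetical stand-in for mlx_lm.utils.generate_step: never stops on its own."""
    token = 0
    while True:
        yield token, 1.0  # (token, probability) pair, mirroring generate_step's output
        token += 1

max_tokens = 5
tokens = []
# zip() stops as soon as range(max_tokens) is exhausted, bounding generation.
for (token, _), _ in zip(fake_generate_step(), range(max_tokens)):
    tokens.append(token)

print(tokens)  # [0, 1, 2, 3, 4]
```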
@@ -274,6 +235,8 @@ class APIHandler(BaseHTTPRequestHandler):
         max_tokens: int,
         temperature: float,
         top_p: float,
+        repetition_penalty: Optional[float],
+        repetition_context_size: Optional[int],
         response_creator: Callable[[str, str, str], dict],
     ):
         self.send_response(200)
@@ -288,15 +251,18 @@ class APIHandler(BaseHTTPRequestHandler):
         # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
         stop_sequence_buffer = []
         REPLACEMENT_CHAR = "\ufffd"
-        for token, _ in zip(
-            generate(
-                prompt,
-                _model,
-                temperature,
+        for (token, _), _ in zip(
+            generate_step(
+                prompt=prompt,
+                model=_model,
+                temp=temperature,
                 top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                repetition_context_size=repetition_context_size,
             ),
             range(max_tokens),
         ):
+            token = token.item()
             tokens.append(token)
             stop_sequence_buffer.append(token)
             if len(stop_sequence_buffer) > max_stop_id_sequence_len:
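One more detail is visible in the streaming hunk's context lines: output is not written the moment a token arrives. The most recent tokens are held in `stop_sequence_buffer` so that a stop sequence, or a partial one still forming, is never sent to the client. The sketch below illustrates only that buffering idea, using hypothetical names (`MAX_STOP_LEN`, `ends_with_stop`, `stream_tokens`) rather than the server's exact control flow:

```python
from typing import Iterable, Iterator, List

MAX_STOP_LEN = 3  # length of the longest stop sequence, in tokens (hypothetical value)


def ends_with_stop(tokens: List[int], stop_sequences: List[List[int]]) -> bool:
    # Hypothetical stand-in for the server's stopping_criteria check.
    return any(
        len(seq) <= len(tokens) and tokens[-len(seq):] == seq for seq in stop_sequences
    )


def stream_tokens(
    token_iter: Iterable[int], stop_sequences: List[List[int]]
) -> Iterator[int]:
    generated: List[int] = []
    buffer: List[int] = []  # tokens generated but not yet written to the stream
    for token in token_iter:
        generated.append(token)
        buffer.append(token)
        if ends_with_stop(generated, stop_sequences):
            matched = next(s for s in stop_sequences if generated[-len(s):] == s)
            # Flush what is safe, but never the stop sequence itself.
            if len(matched) < len(buffer):
                yield from buffer[: -len(matched)]
            return
        if len(buffer) > MAX_STOP_LEN:
            # The oldest buffered token can no longer be part of a stop
            # sequence that is still forming, so it is safe to emit.
            yield buffer.pop(0)
    yield from buffer  # generation ended without hitting a stop sequence


# Stop on [7, 8]: tokens 1, 2, 3 are streamed; 7 and 8 are held back and dropped.
print(list(stream_tokens([1, 2, 3, 7, 8], [[7, 8]])))  # -> [1, 2, 3]
```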
@@ -367,6 +333,8 @@ class APIHandler(BaseHTTPRequestHandler):
         requested_model = body.get("model", "default_model")
         temperature = body.get("temperature", 1.0)
         top_p = body.get("top_p", 1.0)
+        repetition_penalty = body.get("repetition_penalty", 1.0)
+        repetition_context_size = body.get("repetition_context_size", 20)
         if not stream:
             return self.generate_response(
                 prompt,
@@ -377,6 +345,8 @@ class APIHandler(BaseHTTPRequestHandler):
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_chat_response,
             )
         else:
@@ -389,6 +359,8 @@ class APIHandler(BaseHTTPRequestHandler):
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_chat_chunk_response,
             )
@@ -412,6 +384,8 @@ class APIHandler(BaseHTTPRequestHandler):
         requested_model = body.get("model", "default_model")
         temperature = body.get("temperature", 1.0)
         top_p = body.get("top_p", 1.0)
+        repetition_penalty = body.get("repetition_penalty", 1.0)
+        repetition_context_size = body.get("repetition_context_size", 20)
         if not stream:
             return self.generate_response(
                 prompt,
@@ -422,6 +396,8 @@ class APIHandler(BaseHTTPRequestHandler):
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_completion_response,
             )
         else:
@@ -434,6 +410,8 @@ class APIHandler(BaseHTTPRequestHandler):
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_completion_chunk_response,
             )