chore(mlx-lm): refactor server.py to utilize generate_step from utils for consistency (#491)

* chore(mlx-lm): refactor server.py to utilize generate_step from utils for consistency

* chore(mlx-lm): update server doc

* chore: remove unused generate func
This commit is contained in:
Anchen 2024-02-28 01:25:24 +11:00 committed by GitHub
parent 19a21bfce4
commit 82f3f31d93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 55 deletions

View File

@ -61,3 +61,5 @@ curl localhost:8080/v1/chat/completions \
- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
Defaults to `1.0`.
- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens. Defaults to `1.0`.
- `repetition_context_size`: (Optional) The size of the context window for applying repetition penalty. Defaults to `20`.

View File

@ -11,7 +11,7 @@ import mlx.nn as nn
import numpy as np
from transformers import PreTrainedTokenizer
from .utils import load
from .utils import generate_step, load
_model: Optional[nn.Module] = None
_tokenizer: Optional[PreTrainedTokenizer] = None
@ -56,50 +56,6 @@ def stopping_criteria(
return StopCondition(stop_met=False, trim_length=0)
def generate(
prompt: mx.array,
model: nn.Module,
temp: float = 0.0,
top_p: float = 1.0,
):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
if (
logits.dtype == mx.bfloat16
): # workdaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
logits = logits.astype(mx.float32)
probs = mx.softmax(logits / temp, axis=-1)
sorted_probs = mx.sort(probs)[::-1]
sorted_indices = mx.argsort(probs)[::-1]
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
top_probs = mx.where(
cumulative_probs > 1 - top_p,
sorted_probs,
mx.zeros_like(sorted_probs),
)
sorted_tok = mx.random.categorical(mx.log(top_probs))
tok = sorted_indices.squeeze(0)[sorted_tok]
return tok
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
token = y.item()
yield token
def convert_chat(messages: any, role_mapping: Optional[dict] = None):
default_role_mapping = {
"system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
@ -242,18 +198,23 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens: int,
temperature: float,
top_p: float,
repetition_penalty: Optional[float],
repetition_context_size: Optional[int],
response_creator: Callable[[str, str, mx.array, List[int], str], dict],
):
tokens = []
for token, _ in zip(
generate(
prompt,
_model,
temperature,
for (token, _), _ in zip(
generate_step(
prompt=prompt,
model=_model,
temp=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
repetition_context_size=repetition_context_size,
),
range(max_tokens),
):
token = token.item()
tokens.append(token)
stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
if stop_condition.stop_met:
@ -274,6 +235,8 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens: int,
temperature: float,
top_p: float,
repetition_penalty: Optional[float],
repetition_context_size: Optional[int],
response_creator: Callable[[str, str, str], dict],
):
self.send_response(200)
@ -288,15 +251,18 @@ class APIHandler(BaseHTTPRequestHandler):
# Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
stop_sequence_buffer = []
REPLACEMENT_CHAR = "\ufffd"
for token, _ in zip(
generate(
prompt,
_model,
temperature,
for (token, _), _ in zip(
generate_step(
prompt=prompt,
model=_model,
temp=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
repetition_context_size=repetition_context_size,
),
range(max_tokens),
):
token = token.item()
tokens.append(token)
stop_sequence_buffer.append(token)
if len(stop_sequence_buffer) > max_stop_id_sequence_len:
@ -367,6 +333,8 @@ class APIHandler(BaseHTTPRequestHandler):
requested_model = body.get("model", "default_model")
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
repetition_penalty = body.get("repetition_penalty", 1.0)
repetition_context_size = body.get("repetition_context_size", 20)
if not stream:
return self.generate_response(
prompt,
@ -377,6 +345,8 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens,
temperature,
top_p,
repetition_penalty,
repetition_context_size,
create_chat_response,
)
else:
@ -389,6 +359,8 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens,
temperature,
top_p,
repetition_penalty,
repetition_context_size,
create_chat_chunk_response,
)
@ -412,6 +384,8 @@ class APIHandler(BaseHTTPRequestHandler):
requested_model = body.get("model", "default_model")
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
repetition_penalty = body.get("repetition_penalty", 1.0)
repetition_context_size = body.get("repetition_context_size", 20)
if not stream:
return self.generate_response(
prompt,
@ -422,6 +396,8 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens,
temperature,
top_p,
repetition_penalty,
repetition_context_size,
create_completion_response,
)
else:
@ -434,6 +410,8 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens,
temperature,
top_p,
repetition_penalty,
repetition_context_size,
create_completion_chunk_response,
)