From 82f3f31d93d5d308e6bb8a863adea2e57c08301e Mon Sep 17 00:00:00 2001
From: Anchen
Date: Wed, 28 Feb 2024 01:25:24 +1100
Subject: [PATCH] chore(mlx-lm): refactor server.py to utilize generate_step
from utils for consistency (#491)
* chore(mlx-lm): refactor server.py to utilize generate_step from utils for consistency
* chore(mlx-lm): update server doc
* chore: remove unused generate func
---
llms/mlx_lm/SERVER.md | 2 +
llms/mlx_lm/server.py | 88 ++++++++++++++++---------------------------
2 files changed, 35 insertions(+), 55 deletions(-)
diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index 1176951d..e7dd5578 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -61,3 +61,5 @@ curl localhost:8080/v1/chat/completions \
- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
Defaults to `1.0`.
+- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens. Defaults to `1.0`.
+- `repetition_context_size`: (Optional) The size of the context window for applying repetition penalty. Defaults to `20`.
\ No newline at end of file
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index a679216c..e8f35325 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -11,7 +11,7 @@ import mlx.nn as nn
import numpy as np
from transformers import PreTrainedTokenizer
-from .utils import load
+from .utils import generate_step, load
_model: Optional[nn.Module] = None
_tokenizer: Optional[PreTrainedTokenizer] = None
@@ -56,50 +56,6 @@ def stopping_criteria(
return StopCondition(stop_met=False, trim_length=0)
-def generate(
- prompt: mx.array,
- model: nn.Module,
- temp: float = 0.0,
- top_p: float = 1.0,
-):
- def sample(logits):
- if temp == 0:
- return mx.argmax(logits, axis=-1)
- else:
- if top_p > 0 and top_p < 1.0:
- if (
- logits.dtype == mx.bfloat16
- ): # workdaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
- logits = logits.astype(mx.float32)
- probs = mx.softmax(logits / temp, axis=-1)
-
- sorted_probs = mx.sort(probs)[::-1]
- sorted_indices = mx.argsort(probs)[::-1]
- cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
-
- top_probs = mx.where(
- cumulative_probs > 1 - top_p,
- sorted_probs,
- mx.zeros_like(sorted_probs),
- )
- sorted_tok = mx.random.categorical(mx.log(top_probs))
- tok = sorted_indices.squeeze(0)[sorted_tok]
- return tok
- return mx.random.categorical(logits * (1 / temp))
-
- y = prompt
- cache = None
-
- while True:
- logits, cache = model(y[None], cache=cache)
- logits = logits[:, -1, :]
-
- y = sample(logits)
- token = y.item()
-
- yield token
-
-
def convert_chat(messages: any, role_mapping: Optional[dict] = None):
default_role_mapping = {
"system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
@@ -242,18 +198,23 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens: int,
temperature: float,
top_p: float,
+ repetition_penalty: Optional[float],
+ repetition_context_size: Optional[int],
response_creator: Callable[[str, str, mx.array, List[int], str], dict],
):
tokens = []
- for token, _ in zip(
- generate(
- prompt,
- _model,
- temperature,
+ for (token, _), _ in zip(
+ generate_step(
+ prompt=prompt,
+ model=_model,
+ temp=temperature,
top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ repetition_context_size=repetition_context_size,
),
range(max_tokens),
):
+ token = token.item()
tokens.append(token)
stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
if stop_condition.stop_met:
@@ -274,6 +235,8 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens: int,
temperature: float,
top_p: float,
+ repetition_penalty: Optional[float],
+ repetition_context_size: Optional[int],
response_creator: Callable[[str, str, str], dict],
):
self.send_response(200)
@@ -288,15 +251,18 @@ class APIHandler(BaseHTTPRequestHandler):
# Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
stop_sequence_buffer = []
REPLACEMENT_CHAR = "\ufffd"
- for token, _ in zip(
- generate(
- prompt,
- _model,
- temperature,
+ for (token, _), _ in zip(
+ generate_step(
+ prompt=prompt,
+ model=_model,
+ temp=temperature,
top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ repetition_context_size=repetition_context_size,
),
range(max_tokens),
):
+ token = token.item()
tokens.append(token)
stop_sequence_buffer.append(token)
if len(stop_sequence_buffer) > max_stop_id_sequence_len:
@@ -367,6 +333,8 @@ class APIHandler(BaseHTTPRequestHandler):
requested_model = body.get("model", "default_model")
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
+ repetition_penalty = body.get("repetition_penalty", 1.0)
+ repetition_context_size = body.get("repetition_context_size", 20)
if not stream:
return self.generate_response(
prompt,
@@ -377,6 +345,8 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens,
temperature,
top_p,
+ repetition_penalty,
+ repetition_context_size,
create_chat_response,
)
else:
@@ -389,6 +359,8 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens,
temperature,
top_p,
+ repetition_penalty,
+ repetition_context_size,
create_chat_chunk_response,
)
@@ -412,6 +384,8 @@ class APIHandler(BaseHTTPRequestHandler):
requested_model = body.get("model", "default_model")
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
+ repetition_penalty = body.get("repetition_penalty", 1.0)
+ repetition_context_size = body.get("repetition_context_size", 20)
if not stream:
return self.generate_response(
prompt,
@@ -422,6 +396,8 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens,
temperature,
top_p,
+ repetition_penalty,
+ repetition_context_size,
create_completion_response,
)
else:
@@ -434,6 +410,8 @@ class APIHandler(BaseHTTPRequestHandler):
max_tokens,
temperature,
top_p,
+ repetition_penalty,
+ repetition_context_size,
create_completion_chunk_response,
)