From 82f3f31d93d5d308e6bb8a863adea2e57c08301e Mon Sep 17 00:00:00 2001
From: Anchen
Date: Wed, 28 Feb 2024 01:25:24 +1100
Subject: [PATCH] chore(mlx-lm): refactor server.py to utilize generate_step
 from utils for consistency (#491)

* chore(mlx-lm): refactor server.py to utilize generate_step from utils for consistency

* chore(mlx-lm): update server doc

* chore: remove unused generate func
---
 llms/mlx_lm/SERVER.md |  2 +
 llms/mlx_lm/server.py | 88 ++++++++++++++++---------------------------
 2 files changed, 35 insertions(+), 55 deletions(-)

diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index 1176951d..e7dd5578 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -61,3 +61,5 @@ curl localhost:8080/v1/chat/completions \
 
 - `top_p`: (Optional) A float specifying the nucleus sampling parameter.
   Defaults to `1.0`.
+- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens. Defaults to `1.0`.
+- `repetition_context_size`: (Optional) The size of the context window for applying repetition penalty. Defaults to `20`.
\ No newline at end of file
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index a679216c..e8f35325 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -11,7 +11,7 @@ import mlx.nn as nn
 import numpy as np
 from transformers import PreTrainedTokenizer
 
-from .utils import load
+from .utils import generate_step, load
 
 _model: Optional[nn.Module] = None
 _tokenizer: Optional[PreTrainedTokenizer] = None
@@ -56,50 +56,6 @@ def stopping_criteria(
     return StopCondition(stop_met=False, trim_length=0)
 
 
-def generate(
-    prompt: mx.array,
-    model: nn.Module,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-):
-    def sample(logits):
-        if temp == 0:
-            return mx.argmax(logits, axis=-1)
-        else:
-            if top_p > 0 and top_p < 1.0:
-                if (
-                    logits.dtype == mx.bfloat16
-                ):  # workdaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
-                    logits = logits.astype(mx.float32)
-                probs = mx.softmax(logits / temp, axis=-1)
-
-                sorted_probs = mx.sort(probs)[::-1]
-                sorted_indices = mx.argsort(probs)[::-1]
-                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
-
-                top_probs = mx.where(
-                    cumulative_probs > 1 - top_p,
-                    sorted_probs,
-                    mx.zeros_like(sorted_probs),
-                )
-                sorted_tok = mx.random.categorical(mx.log(top_probs))
-                tok = sorted_indices.squeeze(0)[sorted_tok]
-                return tok
-            return mx.random.categorical(logits * (1 / temp))
-
-    y = prompt
-    cache = None
-
-    while True:
-        logits, cache = model(y[None], cache=cache)
-        logits = logits[:, -1, :]
-
-        y = sample(logits)
-        token = y.item()
-
-        yield token
-
-
 def convert_chat(messages: any, role_mapping: Optional[dict] = None):
     default_role_mapping = {
         "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
@@ -242,18 +198,23 @@ class APIHandler(BaseHTTPRequestHandler):
         max_tokens: int,
         temperature: float,
         top_p: float,
+        repetition_penalty: Optional[float],
+        repetition_context_size: Optional[int],
         response_creator: Callable[[str, str, mx.array, List[int], str], dict],
     ):
         tokens = []
-        for token, _ in zip(
-            generate(
-                prompt,
-                _model,
-                temperature,
+        for (token, _), _ in zip(
+            generate_step(
+                prompt=prompt,
+                model=_model,
+                temp=temperature,
                 top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                repetition_context_size=repetition_context_size,
             ),
             range(max_tokens),
         ):
+            token = token.item()
             tokens.append(token)
             stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
             if stop_condition.stop_met:
@@ -274,6 +235,8 @@ class APIHandler(BaseHTTPRequestHandler):
         max_tokens: int,
         temperature: float,
         top_p: float,
+        repetition_penalty: Optional[float],
+        repetition_context_size: Optional[int],
         response_creator: Callable[[str, str, str], dict],
     ):
         self.send_response(200)
@@ -288,15 +251,18 @@ class APIHandler(BaseHTTPRequestHandler):
         # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
         stop_sequence_buffer = []
         REPLACEMENT_CHAR = "\ufffd"
-        for token, _ in zip(
-            generate(
-                prompt,
-                _model,
-                temperature,
+        for (token, _), _ in zip(
+            generate_step(
+                prompt=prompt,
+                model=_model,
+                temp=temperature,
                 top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                repetition_context_size=repetition_context_size,
             ),
             range(max_tokens),
         ):
+            token = token.item()
             tokens.append(token)
             stop_sequence_buffer.append(token)
             if len(stop_sequence_buffer) > max_stop_id_sequence_len:
@@ -367,6 +333,8 @@ class APIHandler(BaseHTTPRequestHandler):
         requested_model = body.get("model", "default_model")
         temperature = body.get("temperature", 1.0)
         top_p = body.get("top_p", 1.0)
+        repetition_penalty = body.get("repetition_penalty", 1.0)
+        repetition_context_size = body.get("repetition_context_size", 20)
         if not stream:
             return self.generate_response(
                 prompt,
@@ -377,6 +345,8 @@
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_chat_response,
             )
         else:
@@ -389,6 +359,8 @@
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_chat_chunk_response,
             )
 
@@ -412,6 +384,8 @@ class APIHandler(BaseHTTPRequestHandler):
         requested_model = body.get("model", "default_model")
         temperature = body.get("temperature", 1.0)
         top_p = body.get("top_p", 1.0)
+        repetition_penalty = body.get("repetition_penalty", 1.0)
+        repetition_context_size = body.get("repetition_context_size", 20)
         if not stream:
             return self.generate_response(
                 prompt,
@@ -422,6 +396,8 @@
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_completion_response,
             )
         else:
@@ -434,6 +410,8 @@
                 max_tokens,
                 temperature,
                 top_p,
+                repetition_penalty,
+                repetition_context_size,
                 create_completion_chunk_response,
             )
 
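
Note on the loop shape above: `generate_step` from `mlx_lm.utils` yields `(token, prob)` pairs in which `token` is a scalar `mx.array`, which is why the handlers unpack `(token, _), _` against `range(max_tokens)` and call `token.item()` before buffering. The following standalone sketch mirrors that pattern outside the server; it is not part of this patch, and the model path, prompt text, and sampling values are placeholders for illustration only.

    # Minimal sketch (not part of this patch) of the generate_step usage above.
    # Model path, prompt text, and sampling values are illustrative assumptions.
    import mlx.core as mx

    from mlx_lm.utils import generate_step, load

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
    prompt = mx.array(tokenizer.encode("Say this is a test!"))

    max_tokens = 16
    tokens = []
    # generate_step yields (token, prob) pairs; zip with range() caps the length,
    # matching the loops in handle_completion / handle_stream.
    for (token, _), _ in zip(
        generate_step(
            prompt=prompt,
            model=model,
            temp=0.7,
            top_p=1.0,
            repetition_penalty=1.0,
            repetition_context_size=20,
        ),
        range(max_tokens),
    ):
        tokens.append(token.item())  # token is a scalar mx.array

    print(tokenizer.decode(tokens))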
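
The two request fields documented in SERVER.md can be exercised end to end once the server is running (the usual `python -m mlx_lm.server --model <path_to_model_or_hf_repo>` invocation from SERVER.md). Below is a hedged client-side sketch using only the Python standard library; the endpoint and field names come from this patch, while the host, port, and parameter values are assumptions for illustration.

    # Hedged client sketch: posts the repetition_penalty / repetition_context_size
    # fields added in this patch to a locally running mlx_lm server.
    # Host, port, and parameter values are illustrative assumptions.
    import json
    import urllib.request

    payload = {
        "messages": [{"role": "user", "content": "Say this is a test!"}],
        "temperature": 0.7,
        "repetition_penalty": 1.1,      # values > 1.0 penalize recently seen tokens
        "repetition_context_size": 20,  # how many recent tokens the penalty inspects
    }

    req = urllib.request.Request(
        "http://localhost:8080/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as response:
        print(json.load(response))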