mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Refactoring of mlx_lm example (#501)
* Use named tuple from typing for typehints * Add type hints * Simplify expression * Type hint fix * Improved do_POST logic Use a map of endpoints to methods to reduce redundancy in code * Fix format * Improve redundancy Call method dynamically instead of writing out all arguments twice * Send response instead of returning * Fix typo * Revert change * Make adapter_file as Optional * Mark formatter as optional * format * Create message generator Store response data that stays static for the duration of the response inside of the object: system_fingerprint request_id object_type requested_model Created a message generator, that dynamically creates messages from the metadata stored inside of the object, and the data from the model pipeline * Remove leftover * Update parameters to reflect new object structure No longer pass all arguments between functions, but use the stores values inside of the object * Parse body before calling request specific methods * Call super init * Update server.py * Fixed outdated documentation parameter name * Add documentation * Fix sending headers twice During testing I found that when using the streaming option, headers have always been sent twice. This should fix that * Simplify streaming code by using guard clauses Don't wrap wfile writes in try blocks, the server class has its own try block to prevent crashing * Bug fix * Use Content-Length header Let the completion type specific methods finish sending the headers. This allows us to send the Content-Length header as the model returns a completion. * Update utils.py * Add top_p documentation * Type hint model and tokenizer as required * Use static system fingerprint System fingerprint now stays the same across requests * Make type hint more specific * Bug Fix Supplying less than 2 models to merge would raise ValueError and calls len on unbound "models". Should be "model_paths" instead. Mark upload_repo as optional * Move more of the shared code into do_POST Processing stop_id_sequences is done no matter the request endpoint or type, move it into the shared section. handle_ methods now just return the prompt in mx.array form. * Store stop_id_sequences as lists instead of np During testing I found that letting the tokenizer return values as python lists and converting them to mlx arrays was around 20% faster than having the tokenizer convert them to np, and from np to mlx. This allows makes it so numpy no longer needs to be imported. * Update stop_id_sequences docs * Turn if check to non-inclusive Only continue if buffer is smaller * Documentation fix * Cleared method names Instead of handle_stream and generate_competion, we should name it handle_completion. Instead of handle_completions and handle_chat_completions, we should name it handle_text_completions, since both are completions, calling it text completions should make it more descriptive * Make comment clearer * fix format * format
This commit is contained in:
@@ -12,5 +12,4 @@ MLX Examples was developed with contributions from the following individuals:
|
|||||||
- Shunta Saito: Added support for PLaMo models.
|
- Shunta Saito: Added support for PLaMo models.
|
||||||
- Gabrijel Boduljak: Implemented `CLIP`.
|
- Gabrijel Boduljak: Implemented `CLIP`.
|
||||||
- Markus Enzweiler: Added the `cvae` examples.
|
- Markus Enzweiler: Added the `cvae` examples.
|
||||||
- Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example
|
|
||||||
- Prince Canuma: Helped add support for `Starcoder2` models.
|
- Prince Canuma: Helped add support for `Starcoder2` models.
|
||||||
|
@@ -5,6 +5,7 @@ import glob
|
|||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@@ -109,7 +110,7 @@ def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
|
|||||||
def merge(
|
def merge(
|
||||||
config: str,
|
config: str,
|
||||||
mlx_path: str = "mlx_model",
|
mlx_path: str = "mlx_model",
|
||||||
upload_repo: str = None,
|
upload_repo: Optional[str] = None,
|
||||||
):
|
):
|
||||||
with open(config, "r") as fid:
|
with open(config, "r") as fid:
|
||||||
merge_conf = yaml.safe_load(fid)
|
merge_conf = yaml.safe_load(fid)
|
||||||
@@ -117,7 +118,7 @@ def merge(
|
|||||||
|
|
||||||
model_paths = merge_conf.get("models", [])
|
model_paths = merge_conf.get("models", [])
|
||||||
if len(model_paths) < 2:
|
if len(model_paths) < 2:
|
||||||
raise ValueError(f"Expected at least 2 models, got {len(models)}.")
|
raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
|
||||||
|
|
||||||
# Load all models
|
# Load all models
|
||||||
base_hf_path = model_paths[0]
|
base_hf_path = model_paths[0]
|
||||||
@@ -125,9 +126,9 @@ def merge(
|
|||||||
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
|
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
|
||||||
models = []
|
models = []
|
||||||
for mp in model_paths[1:]:
|
for mp in model_paths[1:]:
|
||||||
model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
|
model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
|
||||||
base_type = base_config["model_type"]
|
base_type = base_config["model_type"]
|
||||||
model_type = config["model_type"]
|
model_type = model_config["model_type"]
|
||||||
if base_type != model_type:
|
if base_type != model_type:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Can only merge models of the same type,"
|
f"Can only merge models of the same type,"
|
||||||
|
@@ -5,43 +5,39 @@ import json
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from collections import namedtuple
|
|
||||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
from typing import Callable, List, Optional
|
from typing import List, Literal, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
from .utils import generate_step, load
|
from .utils import generate_step, load
|
||||||
|
|
||||||
_model: Optional[nn.Module] = None
|
MODEL: nn.Module
|
||||||
_tokenizer: Optional[PreTrainedTokenizer] = None
|
TOKENIZER: PreTrainedTokenizer
|
||||||
|
|
||||||
|
SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path: str, adapter_file: Optional[str] = None):
|
class StopCondition(NamedTuple):
|
||||||
global _model
|
stop_met: bool
|
||||||
global _tokenizer
|
trim_length: int
|
||||||
_model, _tokenizer = load(model_path, adapter_file=adapter_file)
|
|
||||||
|
|
||||||
|
|
||||||
StopCondition = namedtuple("StopCondition", ["stop_met", "trim_length"])
|
|
||||||
|
|
||||||
|
|
||||||
def stopping_criteria(
|
def stopping_criteria(
|
||||||
tokens: List[int],
|
tokens: List[int],
|
||||||
stop_id_sequences: List[np.ndarray],
|
stop_id_sequences: List[List[int]],
|
||||||
eos_token_id: int,
|
eos_token_id: Union[int, None],
|
||||||
) -> StopCondition:
|
) -> StopCondition:
|
||||||
"""
|
"""
|
||||||
Determines whether the token generation should stop based on predefined conditions.
|
Determines whether the token generation should stop based on predefined conditions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tokens (List[int]): The current sequence of generated tokens.
|
tokens (List[int]): The current sequence of generated tokens.
|
||||||
stop_id_sequences (List[np.ndarray]): A list of numpy arrays, each representing a sequence of token IDs.
|
stop_id_sequences (List[List[[int]]): A list of integer lists, each representing a sequence of token IDs.
|
||||||
If the end of the `tokens` list matches any of these sequences, the generation should stop.
|
If the end of the `tokens` list matches any of these sequences, the generation should stop.
|
||||||
eos_token_id (int): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
|
eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
|
||||||
the generation should stop.
|
the generation should stop.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -53,13 +49,13 @@ def stopping_criteria(
|
|||||||
|
|
||||||
for stop_ids in stop_id_sequences:
|
for stop_ids in stop_id_sequences:
|
||||||
if len(tokens) >= len(stop_ids):
|
if len(tokens) >= len(stop_ids):
|
||||||
if np.array_equal(tokens[-len(stop_ids) :], stop_ids):
|
if tokens[-len(stop_ids) :] == stop_ids:
|
||||||
return StopCondition(stop_met=True, trim_length=len(stop_ids))
|
return StopCondition(stop_met=True, trim_length=len(stop_ids))
|
||||||
|
|
||||||
return StopCondition(stop_met=False, trim_length=0)
|
return StopCondition(stop_met=False, trim_length=0)
|
||||||
|
|
||||||
|
|
||||||
def convert_chat(messages: any, role_mapping: Optional[dict] = None):
|
def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
||||||
default_role_mapping = {
|
default_role_mapping = {
|
||||||
"system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
|
"system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
|
||||||
"system": "ASSISTANT's RULE: ",
|
"system": "ASSISTANT's RULE: ",
|
||||||
@@ -80,344 +76,324 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
|
|||||||
return prompt.rstrip()
|
return prompt.rstrip()
|
||||||
|
|
||||||
|
|
||||||
def create_chat_response(chat_id, requested_model, prompt, tokens, text):
|
|
||||||
response = {
|
|
||||||
"id": chat_id,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": int(time.time()),
|
|
||||||
"model": requested_model,
|
|
||||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": text,
|
|
||||||
},
|
|
||||||
"logprobs": None,
|
|
||||||
"finish_reason": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": len(prompt),
|
|
||||||
"completion_tokens": len(tokens),
|
|
||||||
"total_tokens": len(prompt) + len(tokens),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def create_completion_response(completion_id, requested_model, prompt, tokens, text):
|
|
||||||
return {
|
|
||||||
"id": completion_id,
|
|
||||||
"object": "text_completion",
|
|
||||||
"created": int(time.time()),
|
|
||||||
"model": requested_model,
|
|
||||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
|
||||||
"choices": [
|
|
||||||
{"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": len(prompt),
|
|
||||||
"completion_tokens": len(tokens),
|
|
||||||
"total_tokens": len(prompt) + len(tokens),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_chat_chunk_response(chat_id, requested_model, next_chunk):
|
|
||||||
response = {
|
|
||||||
"id": chat_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": int(time.time()),
|
|
||||||
"model": requested_model,
|
|
||||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"role": "assistant", "content": next_chunk},
|
|
||||||
"logprobs": None,
|
|
||||||
"finish_reason": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def create_completion_chunk_response(completion_id, requested_model, next_chunk):
|
|
||||||
return {
|
|
||||||
"id": completion_id,
|
|
||||||
"object": "text_completion",
|
|
||||||
"created": int(time.time()),
|
|
||||||
"choices": [
|
|
||||||
{"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
|
|
||||||
],
|
|
||||||
"model": requested_model,
|
|
||||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class APIHandler(BaseHTTPRequestHandler):
|
class APIHandler(BaseHTTPRequestHandler):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Create static request specific metadata
|
||||||
|
"""
|
||||||
|
self.created = int(time.time())
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _set_headers(self, status_code=200):
|
def _set_completion_headers(self, status_code: int = 200):
|
||||||
self.send_response(status_code)
|
self.send_response(status_code)
|
||||||
self.send_header("Content-type", "application/json")
|
self.send_header("Content-type", "application/json")
|
||||||
self.send_header("Access-Control-Allow-Origin", "*")
|
self.send_header("Access-Control-Allow-Origin", "*")
|
||||||
self.send_header("Access-Control-Allow-Methods", "*")
|
self.send_header("Access-Control-Allow-Methods", "*")
|
||||||
self.send_header("Access-Control-Allow-Headers", "*")
|
self.send_header("Access-Control-Allow-Headers", "*")
|
||||||
self.end_headers()
|
|
||||||
|
def _set_stream_headers(self, status_code: int = 200):
|
||||||
|
self.send_response(status_code)
|
||||||
|
self.send_header("Content-type", "text/event-stream")
|
||||||
|
self.send_header("Cache-Control", "no-cache")
|
||||||
|
|
||||||
def do_OPTIONS(self):
|
def do_OPTIONS(self):
|
||||||
self._set_headers(204)
|
self._set_completion_headers(204)
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
if self.path == "/v1/chat/completions":
|
"""
|
||||||
content_length = int(self.headers["Content-Length"])
|
Respond to a POST request from a client
|
||||||
post_data = self.rfile.read(content_length)
|
"""
|
||||||
self._set_headers(200)
|
endpoints = {
|
||||||
|
"/v1/completions": self.handle_text_completions,
|
||||||
|
"/v1/chat/completions": self.handle_chat_completions,
|
||||||
|
}
|
||||||
|
|
||||||
response = self.handle_chat_completions(post_data)
|
if self.path not in endpoints:
|
||||||
|
self._set_completion_headers(404)
|
||||||
self.wfile.write(json.dumps(response).encode("utf-8"))
|
self.end_headers()
|
||||||
elif self.path == "/v1/completions":
|
|
||||||
content_length = int(self.headers["Content-Length"])
|
|
||||||
post_data = self.rfile.read(content_length)
|
|
||||||
self._set_headers(200)
|
|
||||||
|
|
||||||
response = self.handle_completions(post_data)
|
|
||||||
|
|
||||||
self.wfile.write(json.dumps(response).encode("utf-8"))
|
|
||||||
else:
|
|
||||||
self._set_headers(404)
|
|
||||||
self.wfile.write(b"Not Found")
|
self.wfile.write(b"Not Found")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Fetch and parse request body
|
||||||
|
content_length = int(self.headers["Content-Length"])
|
||||||
|
raw_body = self.rfile.read(content_length)
|
||||||
|
self.body = json.loads(raw_body.decode())
|
||||||
|
assert isinstance(
|
||||||
|
self.body, dict
|
||||||
|
), f"Request should be dict, but got {type(self.body)}"
|
||||||
|
|
||||||
|
# Extract request parameters from the body
|
||||||
|
self.stream = self.body.get("stream", False)
|
||||||
|
self.requested_model = self.body.get("model", "default_model")
|
||||||
|
self.max_tokens = self.body.get("max_tokens", 100)
|
||||||
|
self.temperature = self.body.get("temperature", 1.0)
|
||||||
|
self.top_p = self.body.get("top_p", 1.0)
|
||||||
|
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
||||||
|
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
||||||
|
|
||||||
|
# Get stop id sequences, if provided
|
||||||
|
stop_words = self.body.get("stop", [])
|
||||||
|
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
||||||
|
stop_id_sequences = [
|
||||||
|
TOKENIZER.encode(stop_word, add_special_tokens=False)
|
||||||
|
for stop_word in stop_words
|
||||||
|
]
|
||||||
|
|
||||||
|
# Send header type
|
||||||
|
(
|
||||||
|
self._set_stream_headers(200)
|
||||||
|
if self.stream
|
||||||
|
else self._set_completion_headers(200)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call endpoint specific method
|
||||||
|
prompt = endpoints[self.path]()
|
||||||
|
|
||||||
|
# Call method based on response type
|
||||||
|
method = self.handle_stream if self.stream else self.handle_completion
|
||||||
|
method(prompt, stop_id_sequences)
|
||||||
|
|
||||||
def generate_response(
|
def generate_response(
|
||||||
self,
|
self,
|
||||||
prompt: mx.array,
|
text: str,
|
||||||
response_id: str,
|
finish_reason: Union[Literal["length", "stop"], None],
|
||||||
requested_model: str,
|
prompt_token_count: Optional[int] = None,
|
||||||
stop_id_sequences: List[np.ndarray],
|
completion_token_count: Optional[int] = None,
|
||||||
eos_token_id: int,
|
) -> dict:
|
||||||
max_tokens: int,
|
"""
|
||||||
temperature: float,
|
Generate a single response packet based on response type (stream or not),
|
||||||
top_p: float,
|
completion type and parameters
|
||||||
repetition_penalty: Optional[float],
|
|
||||||
repetition_context_size: Optional[int],
|
Args:
|
||||||
response_creator: Callable[[str, str, mx.array, List[int], str], dict],
|
text (str): Text generated by model
|
||||||
|
finish_reason (Union[Literal["length", "stop"], None]):
|
||||||
|
The reason the response is being sent: "length", "stop" or None
|
||||||
|
prompt_token_count (Optional[int]):
|
||||||
|
The amount of tokens in the prompt,
|
||||||
|
used to populate the "usage" field (not used when stream)
|
||||||
|
completion_token_count (Optional[int]):
|
||||||
|
The amount of tokens in the response,
|
||||||
|
used to populate the "usage" field (not used when stream)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary containing the response, imitating OpenAI's API
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Static response
|
||||||
|
response = {
|
||||||
|
"id": self.request_id,
|
||||||
|
"system_fingerprint": SYSTEM_FINGERPRINT,
|
||||||
|
"object": self.object_type,
|
||||||
|
"model": self.requested_model,
|
||||||
|
"created": self.created,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
if not self.stream:
|
||||||
|
if not (
|
||||||
|
isinstance(prompt_token_count, int)
|
||||||
|
and isinstance(completion_token_count, int)
|
||||||
):
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Response type is complete, but token counts not provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
response["usage"] = {
|
||||||
|
"prompt_tokens": prompt_token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": prompt_token_count + completion_token_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
choice = response["choices"][0]
|
||||||
|
|
||||||
|
# Add dynamic response
|
||||||
|
if self.object_type.startswith("chat.completion"):
|
||||||
|
key_name = "delta" if self.stream else "message"
|
||||||
|
choice[key_name] = {"role": "assistant", "content": text}
|
||||||
|
elif self.object_type == "text_completion":
|
||||||
|
choice.update(text=text)
|
||||||
|
else:
|
||||||
|
ValueError(f"Unsupported response type: {self.object_type}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def handle_completion(
|
||||||
|
self,
|
||||||
|
prompt: mx.array,
|
||||||
|
stop_id_sequences: List[List[int]],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate a response to a prompt and send it to the client in a single batch
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (mx.array): The prompt, in token form inside of a mlx array
|
||||||
|
stop_id_sequences (List[List[int]]):
|
||||||
|
A list of stop words passed to the stopping_criteria function
|
||||||
|
"""
|
||||||
tokens = []
|
tokens = []
|
||||||
for (token, _), _ in zip(
|
for (token, _), _ in zip(
|
||||||
generate_step(
|
generate_step(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=_model,
|
model=MODEL,
|
||||||
temp=temperature,
|
temp=self.temperature,
|
||||||
top_p=top_p,
|
top_p=self.top_p,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=self.repetition_penalty,
|
||||||
repetition_context_size=repetition_context_size,
|
repetition_context_size=self.repetition_context_size,
|
||||||
),
|
),
|
||||||
range(max_tokens),
|
range(self.max_tokens),
|
||||||
):
|
):
|
||||||
token = token.item()
|
token = token.item()
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
|
stop_condition = stopping_criteria(
|
||||||
|
tokens, stop_id_sequences, TOKENIZER.eos_token_id
|
||||||
|
)
|
||||||
if stop_condition.stop_met:
|
if stop_condition.stop_met:
|
||||||
if stop_condition.trim_length:
|
if stop_condition.trim_length:
|
||||||
tokens = tokens[: -stop_condition.trim_length]
|
tokens = tokens[: -stop_condition.trim_length]
|
||||||
break
|
break
|
||||||
|
|
||||||
text = _tokenizer.decode(tokens)
|
text = TOKENIZER.decode(tokens)
|
||||||
return response_creator(response_id, requested_model, prompt, tokens, text)
|
response = self.generate_response(text, "stop", len(prompt), len(tokens))
|
||||||
|
|
||||||
|
response_json = json.dumps(response).encode()
|
||||||
|
|
||||||
|
# Send an additional Content-Length header when it is known
|
||||||
|
self.send_header("Content-Length", str(len(response_json)))
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
self.wfile.write(response_json)
|
||||||
|
self.wfile.flush()
|
||||||
|
|
||||||
def handle_stream(
|
def handle_stream(
|
||||||
self,
|
self,
|
||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
response_id: str,
|
stop_id_sequences: List[List[int]],
|
||||||
requested_model: str,
|
|
||||||
stop_id_sequences: List[np.ndarray],
|
|
||||||
eos_token_id: int,
|
|
||||||
max_tokens: int,
|
|
||||||
temperature: float,
|
|
||||||
top_p: float,
|
|
||||||
repetition_penalty: Optional[float],
|
|
||||||
repetition_context_size: Optional[int],
|
|
||||||
response_creator: Callable[[str, str, str], dict],
|
|
||||||
):
|
):
|
||||||
self.send_response(200)
|
"""
|
||||||
self.send_header("Content-type", "text/event-stream")
|
Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream
|
||||||
self.send_header("Cache-Control", "no-cache")
|
|
||||||
|
Args:
|
||||||
|
prompt (mx.array): The prompt, in token form inside of a mlx array
|
||||||
|
stop_id_sequences (List[List[int]]):
|
||||||
|
A list of stop words passed to the stopping_criteria function
|
||||||
|
"""
|
||||||
|
# No additional headers are needed, call end_headers
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
max_stop_id_sequence_len = (
|
|
||||||
max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
|
|
||||||
)
|
|
||||||
tokens = []
|
tokens = []
|
||||||
current_generated_text_index = 0
|
current_generated_text_index = 0
|
||||||
# Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
|
|
||||||
|
max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
|
||||||
|
# Buffer to store the last `max_stop_id_sequence_len` tokens
|
||||||
|
# to check for stop conditions before writing to the stream.
|
||||||
stop_sequence_buffer = []
|
stop_sequence_buffer = []
|
||||||
REPLACEMENT_CHAR = "\ufffd"
|
|
||||||
for (token, _), _ in zip(
|
for (token, _), _ in zip(
|
||||||
generate_step(
|
generate_step(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=_model,
|
model=MODEL,
|
||||||
temp=temperature,
|
temp=self.temperature,
|
||||||
top_p=top_p,
|
top_p=self.top_p,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=self.repetition_penalty,
|
||||||
repetition_context_size=repetition_context_size,
|
repetition_context_size=self.repetition_context_size,
|
||||||
),
|
),
|
||||||
range(max_tokens),
|
range(self.max_tokens),
|
||||||
):
|
):
|
||||||
token = token.item()
|
token = token.item()
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
stop_sequence_buffer.append(token)
|
stop_sequence_buffer.append(token)
|
||||||
if len(stop_sequence_buffer) > max_stop_id_sequence_len:
|
|
||||||
if REPLACEMENT_CHAR in _tokenizer.decode(token):
|
# Continue generating tokens until buffer is as large as the longest stop_id_sequence
|
||||||
|
if len(stop_sequence_buffer) < max_stop_id_sequence_len:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# "\ufffd" is used to indicate to the tokenizer, that subsequent characters
|
||||||
|
# should be combined into a single unicode character
|
||||||
|
if "\ufffd" in TOKENIZER.decode(token):
|
||||||
|
continue
|
||||||
|
|
||||||
stop_condition = stopping_criteria(
|
stop_condition = stopping_criteria(
|
||||||
tokens,
|
tokens,
|
||||||
stop_id_sequences,
|
stop_id_sequences,
|
||||||
eos_token_id,
|
TOKENIZER.eos_token_id,
|
||||||
)
|
)
|
||||||
if stop_condition.stop_met:
|
if stop_condition.stop_met:
|
||||||
if stop_condition.trim_length:
|
if stop_condition.trim_length:
|
||||||
tokens = tokens[: -stop_condition.trim_length]
|
tokens = tokens[: -stop_condition.trim_length]
|
||||||
break
|
break
|
||||||
# This is a workaround because the llama tokenizer emits spaces when decoding token by token.
|
|
||||||
generated_text = _tokenizer.decode(tokens)
|
# Workaround for llama tokenizer emitting spaces when decoding token by token.
|
||||||
next_chunk = generated_text[current_generated_text_index:]
|
generated_text = TOKENIZER.decode(tokens)
|
||||||
|
new_text = generated_text[current_generated_text_index:]
|
||||||
current_generated_text_index = len(generated_text)
|
current_generated_text_index = len(generated_text)
|
||||||
|
|
||||||
response = response_creator(response_id, requested_model, next_chunk)
|
response = self.generate_response(new_text, None)
|
||||||
try:
|
|
||||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
stop_sequence_buffer = []
|
stop_sequence_buffer = []
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
break
|
|
||||||
# check is there any remaining text to send
|
# check is there any remaining text to send
|
||||||
if stop_sequence_buffer:
|
if stop_sequence_buffer:
|
||||||
generated_text = _tokenizer.decode(tokens)
|
generated_text = TOKENIZER.decode(tokens)
|
||||||
next_chunk = generated_text[current_generated_text_index:]
|
next_chunk = generated_text[current_generated_text_index:]
|
||||||
response = response_creator(response_id, requested_model, next_chunk)
|
response = self.generate_response(next_chunk, "length")
|
||||||
try:
|
|
||||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
self.wfile.write(f"data: [DONE]\n\n".encode())
|
self.wfile.write("data: [DONE]\n\n".encode())
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
|
||||||
def handle_chat_completions(self, post_data: bytes):
|
def handle_chat_completions(self) -> mx.array:
|
||||||
body = json.loads(post_data.decode("utf-8"))
|
"""
|
||||||
chat_id = f"chatcmpl-{uuid.uuid4()}"
|
Handle a chat completion request
|
||||||
if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
|
|
||||||
prompt = _tokenizer.apply_chat_template(
|
Returns:
|
||||||
|
mx.array: A mx.array of the tokenized prompt from the request body
|
||||||
|
"""
|
||||||
|
body = self.body
|
||||||
|
assert "messages" in body, "Request did not contain messages"
|
||||||
|
|
||||||
|
# Determine response type
|
||||||
|
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
||||||
|
self.object_type = (
|
||||||
|
"chat.completions.chunk" if self.stream else "chat.completions"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(TOKENIZER, "apply_chat_template") and TOKENIZER.chat_template:
|
||||||
|
prompt = TOKENIZER.apply_chat_template(
|
||||||
body["messages"],
|
body["messages"],
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
return_tensors="np",
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt = convert_chat(body["messages"], body.get("role_mapping"))
|
prompt = convert_chat(body["messages"], body.get("role_mapping"))
|
||||||
prompt = _tokenizer.encode(prompt, return_tensors="np")
|
prompt = TOKENIZER.encode(prompt)
|
||||||
|
|
||||||
prompt = mx.array(prompt[0])
|
return mx.array(prompt)
|
||||||
stop_words = body.get("stop", [])
|
|
||||||
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
|
||||||
stop_id_sequences = [
|
|
||||||
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
|
|
||||||
0
|
|
||||||
]
|
|
||||||
for stop_word in stop_words
|
|
||||||
]
|
|
||||||
eos_token_id = _tokenizer.eos_token_id
|
|
||||||
max_tokens = body.get("max_tokens", 100)
|
|
||||||
stream = body.get("stream", False)
|
|
||||||
requested_model = body.get("model", "default_model")
|
|
||||||
temperature = body.get("temperature", 1.0)
|
|
||||||
top_p = body.get("top_p", 1.0)
|
|
||||||
repetition_penalty = body.get("repetition_penalty", 1.0)
|
|
||||||
repetition_context_size = body.get("repetition_context_size", 20)
|
|
||||||
if not stream:
|
|
||||||
return self.generate_response(
|
|
||||||
prompt,
|
|
||||||
chat_id,
|
|
||||||
requested_model,
|
|
||||||
stop_id_sequences,
|
|
||||||
eos_token_id,
|
|
||||||
max_tokens,
|
|
||||||
temperature,
|
|
||||||
top_p,
|
|
||||||
repetition_penalty,
|
|
||||||
repetition_context_size,
|
|
||||||
create_chat_response,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.handle_stream(
|
|
||||||
prompt,
|
|
||||||
chat_id,
|
|
||||||
requested_model,
|
|
||||||
stop_id_sequences,
|
|
||||||
eos_token_id,
|
|
||||||
max_tokens,
|
|
||||||
temperature,
|
|
||||||
top_p,
|
|
||||||
repetition_penalty,
|
|
||||||
repetition_context_size,
|
|
||||||
create_chat_chunk_response,
|
|
||||||
)
|
|
||||||
|
|
||||||
def handle_completions(self, post_data: bytes):
|
def handle_text_completions(self) -> mx.array:
|
||||||
body = json.loads(post_data.decode("utf-8"))
|
"""
|
||||||
completion_id = f"cmpl-{uuid.uuid4()}"
|
Handle a text completion request
|
||||||
prompt_text = body["prompt"]
|
|
||||||
prompt = _tokenizer.encode(prompt_text, return_tensors="np")
|
Returns:
|
||||||
prompt = mx.array(prompt[0])
|
mx.array: A mx.array of the tokenized prompt from the request body
|
||||||
stop_words = body.get("stop", [])
|
"""
|
||||||
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
# Determine response type
|
||||||
stop_id_sequences = [
|
self.request_id = f"cmpl-{uuid.uuid4()}"
|
||||||
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
|
self.object_type = "text_completion"
|
||||||
0
|
|
||||||
]
|
assert "prompt" in self.body, "Request did not contain a prompt"
|
||||||
for stop_word in stop_words
|
prompt_text = self.body["prompt"]
|
||||||
]
|
prompt = TOKENIZER.encode(prompt_text)
|
||||||
eos_token_id = _tokenizer.eos_token_id
|
return mx.array(prompt)
|
||||||
max_tokens = body.get("max_tokens", 100)
|
|
||||||
stream = body.get("stream", False)
|
|
||||||
requested_model = body.get("model", "default_model")
|
|
||||||
temperature = body.get("temperature", 1.0)
|
|
||||||
top_p = body.get("top_p", 1.0)
|
|
||||||
repetition_penalty = body.get("repetition_penalty", 1.0)
|
|
||||||
repetition_context_size = body.get("repetition_context_size", 20)
|
|
||||||
if not stream:
|
|
||||||
return self.generate_response(
|
|
||||||
prompt,
|
|
||||||
completion_id,
|
|
||||||
requested_model,
|
|
||||||
stop_id_sequences,
|
|
||||||
eos_token_id,
|
|
||||||
max_tokens,
|
|
||||||
temperature,
|
|
||||||
top_p,
|
|
||||||
repetition_penalty,
|
|
||||||
repetition_context_size,
|
|
||||||
create_completion_response,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.handle_stream(
|
|
||||||
prompt,
|
|
||||||
completion_id,
|
|
||||||
requested_model,
|
|
||||||
stop_id_sequences,
|
|
||||||
eos_token_id,
|
|
||||||
max_tokens,
|
|
||||||
temperature,
|
|
||||||
top_p,
|
|
||||||
repetition_penalty,
|
|
||||||
repetition_context_size,
|
|
||||||
create_completion_chunk_response,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
|
def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
|
||||||
@@ -458,6 +434,6 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
load_model(args.model, adapter_file=args.adapter_file)
|
MODEL, TOKENIZER = load(args.model, adapter_file=args.adapter_file)
|
||||||
|
|
||||||
run(args.host, args.port)
|
run(args.host, args.port)
|
||||||
|
@@ -114,7 +114,7 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
|
|||||||
def generate_step(
|
def generate_step(
|
||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
temp: 0.0,
|
temp: float = 0.0,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
repetition_context_size: Optional[int] = 20,
|
repetition_context_size: Optional[int] = 20,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
@@ -128,6 +128,7 @@ def generate_step(
|
|||||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||||
repetition_penalty (float, optional): The penalty factor for repeating tokens.
|
repetition_penalty (float, optional): The penalty factor for repeating tokens.
|
||||||
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
|
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
|
||||||
|
top_p (float, optional): Nulceus sampling, higher means model considers more less likely words
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array]]: A generator producing
|
Generator[Tuple[mx.array, mx.array]]: A generator producing
|
||||||
@@ -205,7 +206,7 @@ def generate(
|
|||||||
temp: float = 0.0,
|
temp: float = 0.0,
|
||||||
max_tokens: int = 100,
|
max_tokens: int = 100,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
formatter: Callable = None,
|
formatter: Optional[Callable] = None,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
repetition_context_size: Optional[int] = None,
|
repetition_context_size: Optional[int] = None,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
@@ -357,14 +358,14 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
|
|||||||
def load(
|
def load(
|
||||||
path_or_hf_repo: str,
|
path_or_hf_repo: str,
|
||||||
tokenizer_config={},
|
tokenizer_config={},
|
||||||
adapter_file: str = None,
|
adapter_file: Optional[str] = None,
|
||||||
lazy: bool = False,
|
lazy: bool = False,
|
||||||
) -> Tuple[nn.Module, PreTrainedTokenizer]:
|
) -> Tuple[nn.Module, PreTrainedTokenizer]:
|
||||||
"""
|
"""
|
||||||
Load the model and tokenizer from a given path or a huggingface repository.
|
Load the model and tokenizer from a given path or a huggingface repository.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path (Path): The path or the huggingface repository to load the model from.
|
path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
|
||||||
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
|
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
|
||||||
Defaults to an empty dictionary.
|
Defaults to an empty dictionary.
|
||||||
adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
|
adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
|
||||||
|
Reference in New Issue
Block a user