Mirror of https://github.com/ml-explore/mlx-examples.git
Refactoring of mlx_lm example (#501)
* Use NamedTuple from typing for type hints
* Add type hints
* Simplify expression
* Type hint fix
* Improve do_POST logic: use a map of endpoints to methods to reduce redundancy in the code
* Fix format
* Reduce redundancy: call the handler method dynamically instead of writing out all arguments twice
* Send the response instead of returning it
* Fix typo
* Revert change
* Mark adapter_file as Optional
* Mark formatter as Optional
* Format
* Create a message generator: store response data that stays static for the duration of the response inside the object (system_fingerprint, request_id, object_type, requested_model), and build messages dynamically from that metadata plus the data coming from the model pipeline
* Remove leftover
* Update parameters to reflect the new object structure: no longer pass all arguments between functions, but use the values stored inside the object
* Parse the body before calling request specific methods
* Call super init
* Update server.py
* Fix outdated documentation parameter name
* Add documentation
* Fix sending headers twice: during testing I found that when using the streaming option, headers were always sent twice; this should fix that
* Simplify streaming code by using guard clauses: don't wrap wfile writes in try blocks, the server class has its own try block to prevent crashing
* Bug fix
* Use the Content-Length header: let the completion type specific methods finish sending the headers, which allows sending the Content-Length header once the model returns a completion
* Update utils.py
* Add top_p documentation
* Type hint model and tokenizer as required
* Use a static system fingerprint: the system fingerprint now stays the same across requests
* Make type hint more specific
* Bug fix: supplying fewer than 2 models to merge would raise a ValueError that calls len on the unbound name "models"; it should be "model_paths" instead. Also mark upload_repo as Optional
* Move more of the shared code into do_POST: processing stop_id_sequences happens no matter the request endpoint or type, so it moves into the shared section; handle_ methods now just return the prompt as an mx.array
* Store stop_id_sequences as lists instead of numpy arrays: during testing I found that letting the tokenizer return Python lists and converting them to mlx arrays was around 20% faster than having the tokenizer convert them to numpy and then from numpy to mlx; numpy no longer needs to be imported
* Update stop_id_sequences docs
* Turn the if check non-inclusive: only continue if the buffer is smaller
* Documentation fix
* Clearer method names: instead of handle_stream and generate_completion, use handle_completion; instead of handle_completions and handle_chat_completions, use handle_text_completions and handle_chat_completions, since both are completions and calling the former text completions is more descriptive
* Make comment clearer
* Fix format
* Format
Parent: 710c552731
Commit: b8e5eda4fd
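Several of the items above (improved do_POST logic, calling the handler method dynamically, parsing the body before the request specific methods) describe one pattern: a dictionary that maps request paths to handler methods plus a single dispatch point. The sketch below is illustrative only and assumes a made-up EchoHandler class; it is not the mlx_lm server code, which appears in the diff further down.

import json
from http.server import BaseHTTPRequestHandler, HTTPServer


class EchoHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Map each supported path to the method that handles it.
        endpoints = {
            "/v1/completions": self.handle_text_completions,
            "/v1/chat/completions": self.handle_chat_completions,
        }
        if self.path not in endpoints:
            self.send_response(404)
            self.end_headers()
            self.wfile.write(b"Not Found")
            return

        # Shared parsing happens once, before the endpoint specific method runs.
        length = int(self.headers["Content-Length"])
        self.body = json.loads(self.rfile.read(length).decode())

        result = endpoints[self.path]()  # dynamic dispatch on the request path

        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()
        self.wfile.write(json.dumps(result).encode())

    def handle_text_completions(self) -> dict:
        return {"object": "text_completion", "echo": self.body.get("prompt", "")}

    def handle_chat_completions(self) -> dict:
        return {"object": "chat.completion", "echo": self.body.get("messages", [])}


if __name__ == "__main__":
    HTTPServer(("127.0.0.1", 8080), EchoHandler).serve_forever()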
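The NamedTuple and list-based stop sequence items can also be shown in isolation. This is a minimal sketch of the idea, assuming toy token values; it mirrors the structure of the stopping_criteria change in the diff below rather than copying it.

from typing import List, NamedTuple, Optional


class StopCondition(NamedTuple):
    stop_met: bool
    trim_length: int  # how many trailing tokens to drop before returning text


def stopping_criteria(
    tokens: List[int],
    stop_id_sequences: List[List[int]],
    eos_token_id: Optional[int],
) -> StopCondition:
    # An end-of-sequence token ends generation immediately, nothing to trim.
    if tokens and tokens[-1] == eos_token_id:
        return StopCondition(stop_met=True, trim_length=0)

    # Plain list slicing and comparison, no numpy needed.
    for stop_ids in stop_id_sequences:
        if len(tokens) >= len(stop_ids) and tokens[-len(stop_ids):] == stop_ids:
            return StopCondition(stop_met=True, trim_length=len(stop_ids))

    return StopCondition(stop_met=False, trim_length=0)


if __name__ == "__main__":
    # Example values only: pretend [13, 13] encodes a double newline stop word.
    print(stopping_criteria([5, 7, 13, 13], [[13, 13]], eos_token_id=2))
    # -> StopCondition(stop_met=True, trim_length=2)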
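The message generator and static system fingerprint items separate response data into process-wide, per-request, and per-chunk parts. This is a hedged sketch of that layering using a hypothetical ResponseBuilder class, not the actual APIHandler.generate_response from the diff.

import time
import uuid
from typing import Optional

SYSTEM_FINGERPRINT = f"fp_{uuid.uuid4()}"  # generated once, reused for every request


class ResponseBuilder:
    """Holds per-request metadata so each packet only needs the dynamic parts."""

    def __init__(self, request_id: str, object_type: str, model: str, stream: bool):
        self.request_id = request_id
        self.object_type = object_type
        self.model = model
        self.stream = stream
        self.created = int(time.time())

    def build(
        self,
        text: str,
        finish_reason: Optional[str],
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
    ) -> dict:
        response = {
            "id": self.request_id,
            "system_fingerprint": SYSTEM_FINGERPRINT,
            "object": self.object_type,
            "model": self.model,
            "created": self.created,
            "choices": [{"index": 0, "text": text, "finish_reason": finish_reason}],
        }
        # Usage only makes sense for non-streamed, complete responses.
        if not self.stream and prompt_tokens is not None and completion_tokens is not None:
            response["usage"] = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }
        return response


if __name__ == "__main__":
    builder = ResponseBuilder(f"cmpl-{uuid.uuid4()}", "text_completion", "default_model", stream=False)
    print(builder.build("Hello!", "stop", prompt_tokens=4, completion_tokens=2))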
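The streaming items (guard clauses, the stop sequence buffer, flushing chunks as they are confirmed) combine into the pattern sketched below. The token stream, decode, and write callables are stand-ins, and only the structure mirrors the handle_stream changes in the diff.

from typing import Callable, Iterable, List


def stream_chunks(
    token_stream: Iterable[int],
    stop_id_sequences: List[List[int]],
    decode: Callable[[List[int]], str],
    write: Callable[[str], None],
) -> None:
    """Buffer tokens so a stop sequence is never written, emitting only confirmed text."""
    max_stop_len = max((len(s) for s in stop_id_sequences), default=0)
    tokens: List[int] = []
    buffer: List[int] = []
    sent = 0  # length of decoded text already written

    for token in token_stream:
        tokens.append(token)
        buffer.append(token)

        # Guard clause: wait until the buffer could contain the longest stop sequence.
        if len(buffer) < max_stop_len:
            continue

        # Trim and stop as soon as a stop sequence shows up at the end.
        hit = next(
            (s for s in stop_id_sequences if len(tokens) >= len(s) and tokens[-len(s):] == s),
            None,
        )
        if hit is not None:
            tokens = tokens[: -len(hit)]
            break

        text = decode(tokens)
        write(text[sent:])
        sent = len(text)
        buffer = []

    # Flush any confirmed text that is still unsent.
    text = decode(tokens)
    if text[sent:]:
        write(text[sent:])


if __name__ == "__main__":
    # Toy "tokenizer": each token is a character code, [10, 10] is the stop sequence.
    stream_chunks(
        iter([72, 105, 10, 10]),
        stop_id_sequences=[[10, 10]],
        decode=lambda ts: "".join(map(chr, ts)),
        write=lambda s: print(repr(s)),
    )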
@@ -12,5 +12,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Shunta Saito: Added support for PLaMo models.
 - Gabrijel Boduljak: Implemented `CLIP`.
 - Markus Enzweiler: Added the `cvae` examples.
-- Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example
 - Prince Canuma: Helped add support for `Starcoder2` models.
@@ -5,6 +5,7 @@ import glob
 import json
 import shutil
 from pathlib import Path
+from typing import Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -109,7 +110,7 @@ def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
 def merge(
     config: str,
     mlx_path: str = "mlx_model",
-    upload_repo: str = None,
+    upload_repo: Optional[str] = None,
 ):
     with open(config, "r") as fid:
         merge_conf = yaml.safe_load(fid)
@@ -117,7 +118,7 @@ def merge(
 
     model_paths = merge_conf.get("models", [])
     if len(model_paths) < 2:
-        raise ValueError(f"Expected at least 2 models, got {len(models)}.")
+        raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
 
     # Load all models
     base_hf_path = model_paths[0]
@@ -125,9 +126,9 @@ def merge(
     base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
     models = []
     for mp in model_paths[1:]:
-        model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
+        model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
         base_type = base_config["model_type"]
-        model_type = config["model_type"]
+        model_type = model_config["model_type"]
         if base_type != model_type:
             raise ValueError(
                 f"Can only merge models of the same type,"
@@ -5,43 +5,39 @@ import json
 import time
 import uuid
 import warnings
-from collections import namedtuple
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Callable, List, Optional
+from typing import List, Literal, NamedTuple, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
 from transformers import PreTrainedTokenizer
 
 from .utils import generate_step, load
 
-_model: Optional[nn.Module] = None
-_tokenizer: Optional[PreTrainedTokenizer] = None
+MODEL: nn.Module
+TOKENIZER: PreTrainedTokenizer
+
+SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
 
 
-def load_model(model_path: str, adapter_file: Optional[str] = None):
-    global _model
-    global _tokenizer
-    _model, _tokenizer = load(model_path, adapter_file=adapter_file)
-
-
-StopCondition = namedtuple("StopCondition", ["stop_met", "trim_length"])
+class StopCondition(NamedTuple):
+    stop_met: bool
+    trim_length: int
 
 
 def stopping_criteria(
     tokens: List[int],
-    stop_id_sequences: List[np.ndarray],
-    eos_token_id: int,
+    stop_id_sequences: List[List[int]],
+    eos_token_id: Union[int, None],
 ) -> StopCondition:
     """
     Determines whether the token generation should stop based on predefined conditions.
 
     Args:
        tokens (List[int]): The current sequence of generated tokens.
-       stop_id_sequences (List[np.ndarray]): A list of numpy arrays, each representing a sequence of token IDs.
+       stop_id_sequences (List[List[[int]]): A list of integer lists, each representing a sequence of token IDs.
            If the end of the `tokens` list matches any of these sequences, the generation should stop.
-       eos_token_id (int): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
+       eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
            the generation should stop.
 
     Returns:
@@ -53,13 +49,13 @@ def stopping_criteria(
 
     for stop_ids in stop_id_sequences:
         if len(tokens) >= len(stop_ids):
-            if np.array_equal(tokens[-len(stop_ids) :], stop_ids):
+            if tokens[-len(stop_ids) :] == stop_ids:
                 return StopCondition(stop_met=True, trim_length=len(stop_ids))
 
     return StopCondition(stop_met=False, trim_length=0)
 
 
-def convert_chat(messages: any, role_mapping: Optional[dict] = None):
+def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     default_role_mapping = {
         "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
         "system": "ASSISTANT's RULE: ",
@@ -80,344 +76,324 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
-def create_chat_response(chat_id, requested_model, prompt, tokens, text):
-    response = {
-        "id": chat_id,
-        "object": "chat.completion",
-        "created": int(time.time()),
-        "model": requested_model,
-        "system_fingerprint": f"fp_{uuid.uuid4()}",
-        "choices": [
-            {
-                "index": 0,
-                "message": {
-                    "role": "assistant",
-                    "content": text,
-                },
-                "logprobs": None,
-                "finish_reason": None,
-            }
-        ],
-        "usage": {
-            "prompt_tokens": len(prompt),
-            "completion_tokens": len(tokens),
-            "total_tokens": len(prompt) + len(tokens),
-        },
-    }
-
-    return response
-
-
-def create_completion_response(completion_id, requested_model, prompt, tokens, text):
-    return {
-        "id": completion_id,
-        "object": "text_completion",
-        "created": int(time.time()),
-        "model": requested_model,
-        "system_fingerprint": f"fp_{uuid.uuid4()}",
-        "choices": [
-            {"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
-        ],
-        "usage": {
-            "prompt_tokens": len(prompt),
-            "completion_tokens": len(tokens),
-            "total_tokens": len(prompt) + len(tokens),
-        },
-    }
-
-
-def create_chat_chunk_response(chat_id, requested_model, next_chunk):
-    response = {
-        "id": chat_id,
-        "object": "chat.completion.chunk",
-        "created": int(time.time()),
-        "model": requested_model,
-        "system_fingerprint": f"fp_{uuid.uuid4()}",
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"role": "assistant", "content": next_chunk},
-                "logprobs": None,
-                "finish_reason": None,
-            }
-        ],
-    }
-    return response
-
-
-def create_completion_chunk_response(completion_id, requested_model, next_chunk):
-    return {
-        "id": completion_id,
-        "object": "text_completion",
-        "created": int(time.time()),
-        "choices": [
-            {"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
-        ],
-        "model": requested_model,
-        "system_fingerprint": f"fp_{uuid.uuid4()}",
-    }
-
-
 class APIHandler(BaseHTTPRequestHandler):
+    def __init__(self, *args, **kwargs):
+        """
+        Create static request specific metadata
+        """
+        self.created = int(time.time())
+        super().__init__(*args, **kwargs)
+
-    def _set_headers(self, status_code=200):
+    def _set_completion_headers(self, status_code: int = 200):
         self.send_response(status_code)
         self.send_header("Content-type", "application/json")
         self.send_header("Access-Control-Allow-Origin", "*")
         self.send_header("Access-Control-Allow-Methods", "*")
         self.send_header("Access-Control-Allow-Headers", "*")
-        self.end_headers()
 
+    def _set_stream_headers(self, status_code: int = 200):
+        self.send_response(status_code)
+        self.send_header("Content-type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
+
     def do_OPTIONS(self):
-        self._set_headers(204)
+        self._set_completion_headers(204)
+        self.end_headers()
 
     def do_POST(self):
-        if self.path == "/v1/chat/completions":
-            content_length = int(self.headers["Content-Length"])
-            post_data = self.rfile.read(content_length)
-            self._set_headers(200)
+        """
+        Respond to a POST request from a client
+        """
+        endpoints = {
+            "/v1/completions": self.handle_text_completions,
+            "/v1/chat/completions": self.handle_chat_completions,
+        }
 
-            response = self.handle_chat_completions(post_data)
 
-            self.wfile.write(json.dumps(response).encode("utf-8"))
-        elif self.path == "/v1/completions":
-            content_length = int(self.headers["Content-Length"])
-            post_data = self.rfile.read(content_length)
-            self._set_headers(200)
 
-            response = self.handle_completions(post_data)
 
-            self.wfile.write(json.dumps(response).encode("utf-8"))
-        else:
-            self._set_headers(404)
+        if self.path not in endpoints:
+            self._set_completion_headers(404)
+            self.end_headers()
             self.wfile.write(b"Not Found")
+            return
+
+        # Fetch and parse request body
+        content_length = int(self.headers["Content-Length"])
+        raw_body = self.rfile.read(content_length)
+        self.body = json.loads(raw_body.decode())
+        assert isinstance(
+            self.body, dict
+        ), f"Request should be dict, but got {type(self.body)}"
+
+        # Extract request parameters from the body
+        self.stream = self.body.get("stream", False)
+        self.requested_model = self.body.get("model", "default_model")
+        self.max_tokens = self.body.get("max_tokens", 100)
+        self.temperature = self.body.get("temperature", 1.0)
+        self.top_p = self.body.get("top_p", 1.0)
+        self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
+        self.repetition_context_size = self.body.get("repetition_context_size", 20)
+
+        # Get stop id sequences, if provided
+        stop_words = self.body.get("stop", [])
+        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
+        stop_id_sequences = [
+            TOKENIZER.encode(stop_word, add_special_tokens=False)
+            for stop_word in stop_words
+        ]
+
+        # Send header type
+        (
+            self._set_stream_headers(200)
+            if self.stream
+            else self._set_completion_headers(200)
+        )
+
+        # Call endpoint specific method
+        prompt = endpoints[self.path]()
+
+        # Call method based on response type
+        method = self.handle_stream if self.stream else self.handle_completion
+        method(prompt, stop_id_sequences)
 
     def generate_response(
         self,
-        prompt: mx.array,
-        response_id: str,
-        requested_model: str,
-        stop_id_sequences: List[np.ndarray],
-        eos_token_id: int,
-        max_tokens: int,
-        temperature: float,
-        top_p: float,
-        repetition_penalty: Optional[float],
-        repetition_context_size: Optional[int],
-        response_creator: Callable[[str, str, mx.array, List[int], str], dict],
+        text: str,
+        finish_reason: Union[Literal["length", "stop"], None],
+        prompt_token_count: Optional[int] = None,
+        completion_token_count: Optional[int] = None,
     ) -> dict:
+        """
+        Generate a single response packet based on response type (stream or not),
+        completion type and parameters
+
+        Args:
+            text (str): Text generated by model
+            finish_reason (Union[Literal["length", "stop"], None]):
+                The reason the response is being sent: "length", "stop" or None
+            prompt_token_count (Optional[int]):
+                The amount of tokens in the prompt,
+                used to populate the "usage" field (not used when stream)
+            completion_token_count (Optional[int]):
+                The amount of tokens in the response,
+                used to populate the "usage" field (not used when stream)
+
+        Returns:
+            dict: A dictionary containing the response, imitating OpenAI's API
+        """
+
+        # Static response
+        response = {
+            "id": self.request_id,
+            "system_fingerprint": SYSTEM_FINGERPRINT,
+            "object": self.object_type,
+            "model": self.requested_model,
+            "created": self.created,
+            "choices": [
+                {
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": finish_reason,
+                }
+            ],
+        }
+
+        if not self.stream:
+            if not (
+                isinstance(prompt_token_count, int)
+                and isinstance(completion_token_count, int)
+            ):
+                raise ValueError(
+                    "Response type is complete, but token counts not provided"
+                )
+
+            response["usage"] = {
+                "prompt_tokens": prompt_token_count,
+                "completion_tokens": completion_token_count,
+                "total_tokens": prompt_token_count + completion_token_count,
+            }
+
+        choice = response["choices"][0]
+
+        # Add dynamic response
+        if self.object_type.startswith("chat.completion"):
+            key_name = "delta" if self.stream else "message"
+            choice[key_name] = {"role": "assistant", "content": text}
+        elif self.object_type == "text_completion":
+            choice.update(text=text)
+        else:
+            ValueError(f"Unsupported response type: {self.object_type}")
+
+        return response
+
+    def handle_completion(
+        self,
+        prompt: mx.array,
+        stop_id_sequences: List[List[int]],
+    ):
+        """
+        Generate a response to a prompt and send it to the client in a single batch
+
+        Args:
+            prompt (mx.array): The prompt, in token form inside of a mlx array
+            stop_id_sequences (List[List[int]]):
+                A list of stop words passed to the stopping_criteria function
+        """
         tokens = []
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=_model,
-                temp=temperature,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-                repetition_context_size=repetition_context_size,
+                model=MODEL,
+                temp=self.temperature,
+                top_p=self.top_p,
+                repetition_penalty=self.repetition_penalty,
+                repetition_context_size=self.repetition_context_size,
             ),
-            range(max_tokens),
+            range(self.max_tokens),
         ):
             token = token.item()
            tokens.append(token)
-            stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
+            stop_condition = stopping_criteria(
+                tokens, stop_id_sequences, TOKENIZER.eos_token_id
+            )
             if stop_condition.stop_met:
                 if stop_condition.trim_length:
                     tokens = tokens[: -stop_condition.trim_length]
                 break
 
-        text = _tokenizer.decode(tokens)
-        return response_creator(response_id, requested_model, prompt, tokens, text)
+        text = TOKENIZER.decode(tokens)
+        response = self.generate_response(text, "stop", len(prompt), len(tokens))
+
+        response_json = json.dumps(response).encode()
+
+        # Send an additional Content-Length header when it is known
+        self.send_header("Content-Length", str(len(response_json)))
+        self.end_headers()
+
+        self.wfile.write(response_json)
+        self.wfile.flush()
 
     def handle_stream(
         self,
         prompt: mx.array,
-        response_id: str,
-        requested_model: str,
-        stop_id_sequences: List[np.ndarray],
-        eos_token_id: int,
-        max_tokens: int,
-        temperature: float,
-        top_p: float,
-        repetition_penalty: Optional[float],
-        repetition_context_size: Optional[int],
-        response_creator: Callable[[str, str, str], dict],
+        stop_id_sequences: List[List[int]],
     ):
-        self.send_response(200)
-        self.send_header("Content-type", "text/event-stream")
-        self.send_header("Cache-Control", "no-cache")
+        """
+        Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream
+
+        Args:
+            prompt (mx.array): The prompt, in token form inside of a mlx array
+            stop_id_sequences (List[List[int]]):
+                A list of stop words passed to the stopping_criteria function
+        """
+        # No additional headers are needed, call end_headers
         self.end_headers()
-        max_stop_id_sequence_len = (
-            max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
-        )
+
         tokens = []
         current_generated_text_index = 0
-        # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
+
+        max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
+        # Buffer to store the last `max_stop_id_sequence_len` tokens
+        # to check for stop conditions before writing to the stream.
         stop_sequence_buffer = []
-        REPLACEMENT_CHAR = "\ufffd"
+
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=_model,
-                temp=temperature,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-                repetition_context_size=repetition_context_size,
+                model=MODEL,
+                temp=self.temperature,
+                top_p=self.top_p,
+                repetition_penalty=self.repetition_penalty,
+                repetition_context_size=self.repetition_context_size,
            ),
-            range(max_tokens),
+            range(self.max_tokens),
         ):
             token = token.item()
             tokens.append(token)
             stop_sequence_buffer.append(token)
-            if len(stop_sequence_buffer) > max_stop_id_sequence_len:
-                if REPLACEMENT_CHAR in _tokenizer.decode(token):
+
+            # Continue generating tokens until buffer is as large as the longest stop_id_sequence
+            if len(stop_sequence_buffer) < max_stop_id_sequence_len:
                 continue
+
+            # "\ufffd" is used to indicate to the tokenizer, that subsequent characters
+            # should be combined into a single unicode character
+            if "\ufffd" in TOKENIZER.decode(token):
+                continue
+
             stop_condition = stopping_criteria(
                 tokens,
                 stop_id_sequences,
-                eos_token_id,
+                TOKENIZER.eos_token_id,
             )
             if stop_condition.stop_met:
                 if stop_condition.trim_length:
                     tokens = tokens[: -stop_condition.trim_length]
                 break
-                # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
-                generated_text = _tokenizer.decode(tokens)
-                next_chunk = generated_text[current_generated_text_index:]
+
+            # Workaround for llama tokenizer emitting spaces when decoding token by token.
+            generated_text = TOKENIZER.decode(tokens)
+            new_text = generated_text[current_generated_text_index:]
             current_generated_text_index = len(generated_text)
 
-                response = response_creator(response_id, requested_model, next_chunk)
-                try:
+            response = self.generate_response(new_text, None)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
             stop_sequence_buffer = []
-                except Exception as e:
-                    print(e)
-                    break
 
         # check is there any remaining text to send
         if stop_sequence_buffer:
-            generated_text = _tokenizer.decode(tokens)
+            generated_text = TOKENIZER.decode(tokens)
             next_chunk = generated_text[current_generated_text_index:]
-            response = response_creator(response_id, requested_model, next_chunk)
-            try:
+            response = self.generate_response(next_chunk, "length")
+
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
-            except Exception as e:
-                print(e)
 
-        self.wfile.write(f"data: [DONE]\n\n".encode())
+        self.wfile.write("data: [DONE]\n\n".encode())
         self.wfile.flush()
 
-    def handle_chat_completions(self, post_data: bytes):
-        body = json.loads(post_data.decode("utf-8"))
-        chat_id = f"chatcmpl-{uuid.uuid4()}"
-        if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
-            prompt = _tokenizer.apply_chat_template(
+    def handle_chat_completions(self) -> mx.array:
+        """
+        Handle a chat completion request
+
+        Returns:
+            mx.array: A mx.array of the tokenized prompt from the request body
+        """
+        body = self.body
+        assert "messages" in body, "Request did not contain messages"
+
+        # Determine response type
+        self.request_id = f"chatcmpl-{uuid.uuid4()}"
+        self.object_type = (
+            "chat.completions.chunk" if self.stream else "chat.completions"
+        )
+
+        if hasattr(TOKENIZER, "apply_chat_template") and TOKENIZER.chat_template:
+            prompt = TOKENIZER.apply_chat_template(
                 body["messages"],
                 tokenize=True,
                 add_generation_prompt=True,
                 return_tensors="np",
             )
         else:
             prompt = convert_chat(body["messages"], body.get("role_mapping"))
-            prompt = _tokenizer.encode(prompt, return_tensors="np")
+            prompt = TOKENIZER.encode(prompt)
 
-        prompt = mx.array(prompt[0])
-        stop_words = body.get("stop", [])
-        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
-        stop_id_sequences = [
-            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
-                0
-            ]
-            for stop_word in stop_words
-        ]
-        eos_token_id = _tokenizer.eos_token_id
-        max_tokens = body.get("max_tokens", 100)
-        stream = body.get("stream", False)
-        requested_model = body.get("model", "default_model")
-        temperature = body.get("temperature", 1.0)
-        top_p = body.get("top_p", 1.0)
-        repetition_penalty = body.get("repetition_penalty", 1.0)
-        repetition_context_size = body.get("repetition_context_size", 20)
-        if not stream:
-            return self.generate_response(
-                prompt,
-                chat_id,
-                requested_model,
-                stop_id_sequences,
-                eos_token_id,
-                max_tokens,
-                temperature,
-                top_p,
-                repetition_penalty,
-                repetition_context_size,
-                create_chat_response,
-            )
-        else:
-            self.handle_stream(
-                prompt,
-                chat_id,
-                requested_model,
-                stop_id_sequences,
-                eos_token_id,
-                max_tokens,
-                temperature,
-                top_p,
-                repetition_penalty,
-                repetition_context_size,
-                create_chat_chunk_response,
-            )
+        return mx.array(prompt)
 
-    def handle_completions(self, post_data: bytes):
-        body = json.loads(post_data.decode("utf-8"))
-        completion_id = f"cmpl-{uuid.uuid4()}"
-        prompt_text = body["prompt"]
-        prompt = _tokenizer.encode(prompt_text, return_tensors="np")
-        prompt = mx.array(prompt[0])
-        stop_words = body.get("stop", [])
-        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
-        stop_id_sequences = [
-            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
-                0
-            ]
-            for stop_word in stop_words
-        ]
-        eos_token_id = _tokenizer.eos_token_id
-        max_tokens = body.get("max_tokens", 100)
-        stream = body.get("stream", False)
-        requested_model = body.get("model", "default_model")
-        temperature = body.get("temperature", 1.0)
-        top_p = body.get("top_p", 1.0)
-        repetition_penalty = body.get("repetition_penalty", 1.0)
-        repetition_context_size = body.get("repetition_context_size", 20)
-        if not stream:
-            return self.generate_response(
-                prompt,
-                completion_id,
-                requested_model,
-                stop_id_sequences,
-                eos_token_id,
-                max_tokens,
-                temperature,
-                top_p,
-                repetition_penalty,
-                repetition_context_size,
-                create_completion_response,
-            )
-        else:
-            self.handle_stream(
-                prompt,
-                completion_id,
-                requested_model,
-                stop_id_sequences,
-                eos_token_id,
-                max_tokens,
-                temperature,
-                top_p,
-                repetition_penalty,
-                repetition_context_size,
-                create_completion_chunk_response,
-            )
+    def handle_text_completions(self) -> mx.array:
+        """
+        Handle a text completion request
+
+        Returns:
+            mx.array: A mx.array of the tokenized prompt from the request body
+        """
+        # Determine response type
+        self.request_id = f"cmpl-{uuid.uuid4()}"
+        self.object_type = "text_completion"
+
+        assert "prompt" in self.body, "Request did not contain a prompt"
+        prompt_text = self.body["prompt"]
+        prompt = TOKENIZER.encode(prompt_text)
+        return mx.array(prompt)
 
 
 def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
@@ -458,6 +434,6 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    load_model(args.model, adapter_file=args.adapter_file)
+    MODEL, TOKENIZER = load(args.model, adapter_file=args.adapter_file)
 
     run(args.host, args.port)
@@ -114,7 +114,7 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
-    temp: 0.0,
+    temp: float = 0.0,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
     top_p: float = 1.0,
@@ -128,6 +128,7 @@ def generate_step(
        temp (float): The temperature for sampling, if 0 the argmax is used.
        repetition_penalty (float, optional): The penalty factor for repeating tokens.
        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
+       top_p (float, optional): Nulceus sampling, higher means model considers more less likely words
 
     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
@@ -205,7 +206,7 @@ def generate(
     temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
-    formatter: Callable = None,
+    formatter: Optional[Callable] = None,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = None,
     top_p: float = 1.0,
@@ -357,14 +358,14 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
-    adapter_file: str = None,
+    adapter_file: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
 
     Args:
-        model_path (Path): The path or the huggingface repository to load the model from.
+        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
         adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.