Refactoring of mlx_lm example (#501)

* Use NamedTuple from typing for type hints
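
A minimal sketch of the change, assuming the same fields as the `StopCondition` tuple in the diff below: `typing.NamedTuple` replaces `collections.namedtuple` so the fields carry type hints.

```python
from typing import NamedTuple

class StopCondition(NamedTuple):
    stop_met: bool
    trim_length: int

# Behaves like the old namedtuple, but the fields are typed
condition = StopCondition(stop_met=False, trim_length=0)
assert condition.trim_length == 0
```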

* Add type hints

* Simplify expression

* Type hint fix

* Improved do_POST logic

Use a map of endpoints to handler methods to reduce code redundancy
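
Condensed from the new `do_POST` in the diff below; only the routing part is shown here.

```python
def do_POST(self):
    # Each supported endpoint maps to the method that parses its request body
    endpoints = {
        "/v1/completions": self.handle_text_completions,
        "/v1/chat/completions": self.handle_chat_completions,
    }

    if self.path not in endpoints:
        self._set_completion_headers(404)
        self.end_headers()
        self.wfile.write(b"Not Found")
        return
    # Shared request parsing and generation continue below in the full method
```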

* Fix format

* Reduce redundancy

Call the method dynamically instead of writing out all arguments twice

* Send response instead of returning

* Fix typo

* Revert change

* Mark adapter_file as Optional

* Mark formatter as optional

* format

* Create message generator

Store response data that stays static for the duration of the response inside the handler object:

system_fingerprint
request_id
object_type
requested_model

Created a message generator that dynamically builds messages from the metadata stored inside the object and the data from the model pipeline
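
An abridged version of the resulting `generate_response` method from the diff below (the non-streaming `usage` accounting is omitted):

```python
def generate_response(self, text: str, finish_reason) -> dict:
    # Static metadata lives on the handler / module instead of being passed around
    response = {
        "id": self.request_id,
        "system_fingerprint": SYSTEM_FINGERPRINT,
        "object": self.object_type,
        "model": self.requested_model,
        "created": self.created,
        "choices": [{"index": 0, "logprobs": None, "finish_reason": finish_reason}],
    }
    choice = response["choices"][0]
    # Only the text payload depends on the completion type and stream mode
    if self.object_type.startswith("chat.completion"):
        choice["delta" if self.stream else "message"] = {"role": "assistant", "content": text}
    else:
        choice["text"] = text
    return response
```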

* Remove leftover

* Update parameters to reflect new object structure

No longer pass all arguments between functions; use the values stored inside the object instead

* Parse body before calling request specific methods

* Call super init

* Update server.py

* Fixed outdated documentation parameter name

* Add documentation

* Fix sending headers twice

During testing I found that when using the streaming option, headers were always sent twice. This should fix that.

* Simplify streaming code by using guard clauses

Don't wrap wfile writes in try blocks; the server class has its own try block to prevent crashing
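
A minimal, self-contained sketch of the guard-clause shape; the function and parameter names here are illustrative, not the server's API.

```python
from typing import Callable, Iterable, List

def collect_stream_tokens(token_stream: Iterable[int],
                          decode: Callable[[int], str],
                          max_stop_len: int) -> List[int]:
    tokens: List[int] = []
    buffer: List[int] = []
    for token in token_stream:
        tokens.append(token)
        buffer.append(token)
        # Guard: wait until the buffer could hold the longest stop sequence
        if len(buffer) < max_stop_len:
            continue
        # Guard: "\ufffd" means the tokenizer needs more tokens to finish a character
        if "\ufffd" in decode(token):
            continue
        # ... decode the new text and write the SSE chunk here, with no try/except
        buffer.clear()
    return tokens
```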

* Bug fix

* Use Content-Length header

Let the completion-type-specific methods finish sending the headers. This allows us to send the Content-Length header once the model has returned a completion.
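
Condensed from the end of `handle_completion` in the diff below:

```python
response = self.generate_response(text, "stop", len(prompt), len(tokens))
response_json = json.dumps(response).encode()

# do_POST already sent the status line and the JSON content-type headers;
# the body length is only known here, so this method finishes the header block.
self.send_header("Content-Length", str(len(response_json)))
self.end_headers()
self.wfile.write(response_json)
self.wfile.flush()
```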

* Update utils.py

* Add top_p documentation

* Type hint model and tokenizer as required

* Use static system fingerprint

System fingerprint now stays the same across requests
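
The module-level constant from the diff below; previously a fresh uuid was generated for every response.

```python
import uuid

# Created once at import time, so all responses from this server process share it
SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
```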

* Make type hint more specific

* Bug Fix

Supplying fewer than 2 models to merge would raise a ValueError whose message calls len on the unbound "models" variable; it should be "model_paths" instead.

Mark upload_repo as optional
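
The corrected check in `merge()` (condensed from the diff below); the old message formatted `len(models)`, but `models` is only assigned further down in the function, so hitting this error path failed on the unbound name instead.

```python
model_paths = merge_conf.get("models", [])
if len(model_paths) < 2:
    raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
```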

* Move more of the shared code into do_POST

Processing stop_id_sequences is needed regardless of the request endpoint or type, so move it into the shared section. The handle_ methods now just return the prompt as an mx.array.
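
Condensed from the shared section of `do_POST` in the diff below:

```python
# Stop words are processed the same way for both endpoints
stop_words = self.body.get("stop", [])
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
    TOKENIZER.encode(stop_word, add_special_tokens=False)
    for stop_word in stop_words
]

# Endpoint-specific handlers only build the tokenized prompt (an mx.array) ...
prompt = endpoints[self.path]()
# ... and the response type picks the generation method
method = self.handle_stream if self.stream else self.handle_completion
method(prompt, stop_id_sequences)
```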

* Store stop_id_sequences as lists instead of np

During testing I found that letting the tokenizer return values as Python lists and converting them to mlx arrays was around 20% faster than having the tokenizer convert them to np and then from np to mlx. This also means numpy no longer needs to be imported.
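
A small illustrative helper (the name is hypothetical) showing why plain lists are enough: list slices compare with `==`, so `np.array_equal` and the numpy import can go.

```python
from typing import List

def ends_with_stop_sequence(tokens: List[int], stop_id_sequences: List[List[int]]) -> bool:
    # Same comparison stopping_criteria performs, with no numpy round trip
    return any(
        len(tokens) >= len(stop_ids) and tokens[-len(stop_ids):] == stop_ids
        for stop_ids in stop_id_sequences
    )
```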

* Update stop_id_sequences docs

* Turn the if check into a non-inclusive comparison

Only continue while the buffer is smaller than the longest stop sequence
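
A minimal sketch of the adjusted guard in the streaming loop (condensed from the diff below):

```python
# Old: only processed the buffer once it was strictly larger than the longest
# stop sequence. New: keep generating only while it is still smaller, so a
# buffer exactly as long as the longest stop sequence is checked too.
if len(stop_sequence_buffer) < max_stop_id_sequence_len:
    continue
```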

* Documentation fix

* Clearer method names

Instead of handle_stream and generate_completion, we should name it handle_completion.

Instead of handle_completions (alongside handle_chat_completions), we should name it handle_text_completions; since both endpoints serve completions, calling the first one text completions makes it more descriptive.

* Make comment clearer

* fix format

* format
Y4hL 2024-03-06 16:24:31 +02:00 committed by GitHub
parent 710c552731
commit b8e5eda4fd
4 changed files with 280 additions and 303 deletions


@ -12,5 +12,4 @@ MLX Examples was developed with contributions from the following individuals:
- Shunta Saito: Added support for PLaMo models. - Shunta Saito: Added support for PLaMo models.
- Gabrijel Boduljak: Implemented `CLIP`. - Gabrijel Boduljak: Implemented `CLIP`.
- Markus Enzweiler: Added the `cvae` examples. - Markus Enzweiler: Added the `cvae` examples.
- Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example
- Prince Canuma: Helped add support for `Starcoder2` models. - Prince Canuma: Helped add support for `Starcoder2` models.


@ -5,6 +5,7 @@ import glob
import json import json
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -109,7 +110,7 @@ def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
def merge( def merge(
config: str, config: str,
mlx_path: str = "mlx_model", mlx_path: str = "mlx_model",
upload_repo: str = None, upload_repo: Optional[str] = None,
): ):
with open(config, "r") as fid: with open(config, "r") as fid:
merge_conf = yaml.safe_load(fid) merge_conf = yaml.safe_load(fid)
@ -117,7 +118,7 @@ def merge(
model_paths = merge_conf.get("models", []) model_paths = merge_conf.get("models", [])
if len(model_paths) < 2: if len(model_paths) < 2:
raise ValueError(f"Expected at least 2 models, got {len(models)}.") raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
# Load all models # Load all models
base_hf_path = model_paths[0] base_hf_path = model_paths[0]
@ -125,9 +126,9 @@ def merge(
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True) base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
models = [] models = []
for mp in model_paths[1:]: for mp in model_paths[1:]:
model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True) model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
base_type = base_config["model_type"] base_type = base_config["model_type"]
model_type = config["model_type"] model_type = model_config["model_type"]
if base_type != model_type: if base_type != model_type:
raise ValueError( raise ValueError(
f"Can only merge models of the same type," f"Can only merge models of the same type,"


@ -5,43 +5,39 @@ import json
import time import time
import uuid import uuid
import warnings import warnings
from collections import namedtuple
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Callable, List, Optional from typing import List, Literal, NamedTuple, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from .utils import generate_step, load from .utils import generate_step, load
_model: Optional[nn.Module] = None MODEL: nn.Module
_tokenizer: Optional[PreTrainedTokenizer] = None TOKENIZER: PreTrainedTokenizer
SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
def load_model(model_path: str, adapter_file: Optional[str] = None): class StopCondition(NamedTuple):
global _model stop_met: bool
global _tokenizer trim_length: int
_model, _tokenizer = load(model_path, adapter_file=adapter_file)
StopCondition = namedtuple("StopCondition", ["stop_met", "trim_length"])
def stopping_criteria( def stopping_criteria(
tokens: List[int], tokens: List[int],
stop_id_sequences: List[np.ndarray], stop_id_sequences: List[List[int]],
eos_token_id: int, eos_token_id: Union[int, None],
) -> StopCondition: ) -> StopCondition:
""" """
Determines whether the token generation should stop based on predefined conditions. Determines whether the token generation should stop based on predefined conditions.
Args: Args:
tokens (List[int]): The current sequence of generated tokens. tokens (List[int]): The current sequence of generated tokens.
stop_id_sequences (List[np.ndarray]): A list of numpy arrays, each representing a sequence of token IDs. stop_id_sequences (List[List[int]]): A list of integer lists, each representing a sequence of token IDs.
If the end of the `tokens` list matches any of these sequences, the generation should stop. If the end of the `tokens` list matches any of these sequences, the generation should stop.
eos_token_id (int): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this, eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
the generation should stop. the generation should stop.
Returns: Returns:
@ -53,13 +49,13 @@ def stopping_criteria(
for stop_ids in stop_id_sequences: for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids): if len(tokens) >= len(stop_ids):
if np.array_equal(tokens[-len(stop_ids) :], stop_ids): if tokens[-len(stop_ids) :] == stop_ids:
return StopCondition(stop_met=True, trim_length=len(stop_ids)) return StopCondition(stop_met=True, trim_length=len(stop_ids))
return StopCondition(stop_met=False, trim_length=0) return StopCondition(stop_met=False, trim_length=0)
def convert_chat(messages: any, role_mapping: Optional[dict] = None): def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
default_role_mapping = { default_role_mapping = {
"system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.", "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
"system": "ASSISTANT's RULE: ", "system": "ASSISTANT's RULE: ",
@ -80,344 +76,324 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
return prompt.rstrip() return prompt.rstrip()
def create_chat_response(chat_id, requested_model, prompt, tokens, text):
response = {
"id": chat_id,
"object": "chat.completion",
"created": int(time.time()),
"model": requested_model,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": text,
},
"logprobs": None,
"finish_reason": None,
}
],
"usage": {
"prompt_tokens": len(prompt),
"completion_tokens": len(tokens),
"total_tokens": len(prompt) + len(tokens),
},
}
return response
def create_completion_response(completion_id, requested_model, prompt, tokens, text):
return {
"id": completion_id,
"object": "text_completion",
"created": int(time.time()),
"model": requested_model,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"choices": [
{"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
],
"usage": {
"prompt_tokens": len(prompt),
"completion_tokens": len(tokens),
"total_tokens": len(prompt) + len(tokens),
},
}
def create_chat_chunk_response(chat_id, requested_model, next_chunk):
response = {
"id": chat_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": requested_model,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": next_chunk},
"logprobs": None,
"finish_reason": None,
}
],
}
return response
def create_completion_chunk_response(completion_id, requested_model, next_chunk):
return {
"id": completion_id,
"object": "text_completion",
"created": int(time.time()),
"choices": [
{"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
],
"model": requested_model,
"system_fingerprint": f"fp_{uuid.uuid4()}",
}
class APIHandler(BaseHTTPRequestHandler): class APIHandler(BaseHTTPRequestHandler):
def __init__(self, *args, **kwargs):
"""
Create static request specific metadata
"""
self.created = int(time.time())
super().__init__(*args, **kwargs)
def _set_headers(self, status_code=200): def _set_completion_headers(self, status_code: int = 200):
self.send_response(status_code) self.send_response(status_code)
self.send_header("Content-type", "application/json") self.send_header("Content-type", "application/json")
self.send_header("Access-Control-Allow-Origin", "*") self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Methods", "*") self.send_header("Access-Control-Allow-Methods", "*")
self.send_header("Access-Control-Allow-Headers", "*") self.send_header("Access-Control-Allow-Headers", "*")
self.end_headers()
def _set_stream_headers(self, status_code: int = 200):
self.send_response(status_code)
self.send_header("Content-type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
def do_OPTIONS(self): def do_OPTIONS(self):
self._set_headers(204) self._set_completion_headers(204)
self.end_headers()
def do_POST(self): def do_POST(self):
if self.path == "/v1/chat/completions": """
content_length = int(self.headers["Content-Length"]) Respond to a POST request from a client
post_data = self.rfile.read(content_length) """
self._set_headers(200) endpoints = {
"/v1/completions": self.handle_text_completions,
"/v1/chat/completions": self.handle_chat_completions,
}
response = self.handle_chat_completions(post_data) if self.path not in endpoints:
self._set_completion_headers(404)
self.wfile.write(json.dumps(response).encode("utf-8")) self.end_headers()
elif self.path == "/v1/completions":
content_length = int(self.headers["Content-Length"])
post_data = self.rfile.read(content_length)
self._set_headers(200)
response = self.handle_completions(post_data)
self.wfile.write(json.dumps(response).encode("utf-8"))
else:
self._set_headers(404)
self.wfile.write(b"Not Found") self.wfile.write(b"Not Found")
return
# Fetch and parse request body
content_length = int(self.headers["Content-Length"])
raw_body = self.rfile.read(content_length)
self.body = json.loads(raw_body.decode())
assert isinstance(
self.body, dict
), f"Request should be dict, but got {type(self.body)}"
# Extract request parameters from the body
self.stream = self.body.get("stream", False)
self.requested_model = self.body.get("model", "default_model")
self.max_tokens = self.body.get("max_tokens", 100)
self.temperature = self.body.get("temperature", 1.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
self.repetition_context_size = self.body.get("repetition_context_size", 20)
# Get stop id sequences, if provided
stop_words = self.body.get("stop", [])
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
TOKENIZER.encode(stop_word, add_special_tokens=False)
for stop_word in stop_words
]
# Send header type
(
self._set_stream_headers(200)
if self.stream
else self._set_completion_headers(200)
)
# Call endpoint specific method
prompt = endpoints[self.path]()
# Call method based on response type
method = self.handle_stream if self.stream else self.handle_completion
method(prompt, stop_id_sequences)
def generate_response( def generate_response(
self, self,
prompt: mx.array, text: str,
response_id: str, finish_reason: Union[Literal["length", "stop"], None],
requested_model: str, prompt_token_count: Optional[int] = None,
stop_id_sequences: List[np.ndarray], completion_token_count: Optional[int] = None,
eos_token_id: int, ) -> dict:
max_tokens: int, """
temperature: float, Generate a single response packet based on response type (stream or not),
top_p: float, completion type and parameters
repetition_penalty: Optional[float],
repetition_context_size: Optional[int], Args:
response_creator: Callable[[str, str, mx.array, List[int], str], dict], text (str): Text generated by model
finish_reason (Union[Literal["length", "stop"], None]):
The reason the response is being sent: "length", "stop" or None
prompt_token_count (Optional[int]):
The amount of tokens in the prompt,
used to populate the "usage" field (not used when stream)
completion_token_count (Optional[int]):
The amount of tokens in the response,
used to populate the "usage" field (not used when stream)
Returns:
dict: A dictionary containing the response, imitating OpenAI's API
"""
# Static response
response = {
"id": self.request_id,
"system_fingerprint": SYSTEM_FINGERPRINT,
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
"choices": [
{
"index": 0,
"logprobs": None,
"finish_reason": finish_reason,
}
],
}
if not self.stream:
if not (
isinstance(prompt_token_count, int)
and isinstance(completion_token_count, int)
): ):
raise ValueError(
"Response type is complete, but token counts not provided"
)
response["usage"] = {
"prompt_tokens": prompt_token_count,
"completion_tokens": completion_token_count,
"total_tokens": prompt_token_count + completion_token_count,
}
choice = response["choices"][0]
# Add dynamic response
if self.object_type.startswith("chat.completion"):
key_name = "delta" if self.stream else "message"
choice[key_name] = {"role": "assistant", "content": text}
elif self.object_type == "text_completion":
choice.update(text=text)
else:
ValueError(f"Unsupported response type: {self.object_type}")
return response
def handle_completion(
self,
prompt: mx.array,
stop_id_sequences: List[List[int]],
):
"""
Generate a response to a prompt and send it to the client in a single batch
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
stop_id_sequences (List[List[int]]):
A list of stop words passed to the stopping_criteria function
"""
tokens = [] tokens = []
for (token, _), _ in zip( for (token, _), _ in zip(
generate_step( generate_step(
prompt=prompt, prompt=prompt,
model=_model, model=MODEL,
temp=temperature, temp=self.temperature,
top_p=top_p, top_p=self.top_p,
repetition_penalty=repetition_penalty, repetition_penalty=self.repetition_penalty,
repetition_context_size=repetition_context_size, repetition_context_size=self.repetition_context_size,
), ),
range(max_tokens), range(self.max_tokens),
): ):
token = token.item() token = token.item()
tokens.append(token) tokens.append(token)
stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id) stop_condition = stopping_criteria(
tokens, stop_id_sequences, TOKENIZER.eos_token_id
)
if stop_condition.stop_met: if stop_condition.stop_met:
if stop_condition.trim_length: if stop_condition.trim_length:
tokens = tokens[: -stop_condition.trim_length] tokens = tokens[: -stop_condition.trim_length]
break break
text = _tokenizer.decode(tokens) text = TOKENIZER.decode(tokens)
return response_creator(response_id, requested_model, prompt, tokens, text) response = self.generate_response(text, "stop", len(prompt), len(tokens))
response_json = json.dumps(response).encode()
# Send an additional Content-Length header when it is known
self.send_header("Content-Length", str(len(response_json)))
self.end_headers()
self.wfile.write(response_json)
self.wfile.flush()
def handle_stream( def handle_stream(
self, self,
prompt: mx.array, prompt: mx.array,
response_id: str, stop_id_sequences: List[List[int]],
requested_model: str,
stop_id_sequences: List[np.ndarray],
eos_token_id: int,
max_tokens: int,
temperature: float,
top_p: float,
repetition_penalty: Optional[float],
repetition_context_size: Optional[int],
response_creator: Callable[[str, str, str], dict],
): ):
self.send_response(200) """
self.send_header("Content-type", "text/event-stream") Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream
self.send_header("Cache-Control", "no-cache")
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
stop_id_sequences (List[List[int]]):
A list of stop words passed to the stopping_criteria function
"""
# No additional headers are needed, call end_headers
self.end_headers() self.end_headers()
max_stop_id_sequence_len = (
max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
)
tokens = [] tokens = []
current_generated_text_index = 0 current_generated_text_index = 0
# Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
# Buffer to store the last `max_stop_id_sequence_len` tokens
# to check for stop conditions before writing to the stream.
stop_sequence_buffer = [] stop_sequence_buffer = []
REPLACEMENT_CHAR = "\ufffd"
for (token, _), _ in zip( for (token, _), _ in zip(
generate_step( generate_step(
prompt=prompt, prompt=prompt,
model=_model, model=MODEL,
temp=temperature, temp=self.temperature,
top_p=top_p, top_p=self.top_p,
repetition_penalty=repetition_penalty, repetition_penalty=self.repetition_penalty,
repetition_context_size=repetition_context_size, repetition_context_size=self.repetition_context_size,
), ),
range(max_tokens), range(self.max_tokens),
): ):
token = token.item() token = token.item()
tokens.append(token) tokens.append(token)
stop_sequence_buffer.append(token) stop_sequence_buffer.append(token)
if len(stop_sequence_buffer) > max_stop_id_sequence_len:
if REPLACEMENT_CHAR in _tokenizer.decode(token): # Continue generating tokens until buffer is as large as the longest stop_id_sequence
if len(stop_sequence_buffer) < max_stop_id_sequence_len:
continue continue
# "\ufffd" is used to indicate to the tokenizer, that subsequent characters
# should be combined into a single unicode character
if "\ufffd" in TOKENIZER.decode(token):
continue
stop_condition = stopping_criteria( stop_condition = stopping_criteria(
tokens, tokens,
stop_id_sequences, stop_id_sequences,
eos_token_id, TOKENIZER.eos_token_id,
) )
if stop_condition.stop_met: if stop_condition.stop_met:
if stop_condition.trim_length: if stop_condition.trim_length:
tokens = tokens[: -stop_condition.trim_length] tokens = tokens[: -stop_condition.trim_length]
break break
# This is a workaround because the llama tokenizer emits spaces when decoding token by token.
generated_text = _tokenizer.decode(tokens) # Workaround for llama tokenizer emitting spaces when decoding token by token.
next_chunk = generated_text[current_generated_text_index:] generated_text = TOKENIZER.decode(tokens)
new_text = generated_text[current_generated_text_index:]
current_generated_text_index = len(generated_text) current_generated_text_index = len(generated_text)
response = response_creator(response_id, requested_model, next_chunk) response = self.generate_response(new_text, None)
try:
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush() self.wfile.flush()
stop_sequence_buffer = [] stop_sequence_buffer = []
except Exception as e:
print(e)
break
# check is there any remaining text to send # check is there any remaining text to send
if stop_sequence_buffer: if stop_sequence_buffer:
generated_text = _tokenizer.decode(tokens) generated_text = TOKENIZER.decode(tokens)
next_chunk = generated_text[current_generated_text_index:] next_chunk = generated_text[current_generated_text_index:]
response = response_creator(response_id, requested_model, next_chunk) response = self.generate_response(next_chunk, "length")
try:
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush() self.wfile.flush()
except Exception as e:
print(e)
self.wfile.write(f"data: [DONE]\n\n".encode()) self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush() self.wfile.flush()
def handle_chat_completions(self, post_data: bytes): def handle_chat_completions(self) -> mx.array:
body = json.loads(post_data.decode("utf-8")) """
chat_id = f"chatcmpl-{uuid.uuid4()}" Handle a chat completion request
if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
prompt = _tokenizer.apply_chat_template( Returns:
mx.array: A mx.array of the tokenized prompt from the request body
"""
body = self.body
assert "messages" in body, "Request did not contain messages"
# Determine response type
self.request_id = f"chatcmpl-{uuid.uuid4()}"
self.object_type = (
"chat.completions.chunk" if self.stream else "chat.completions"
)
if hasattr(TOKENIZER, "apply_chat_template") and TOKENIZER.chat_template:
prompt = TOKENIZER.apply_chat_template(
body["messages"], body["messages"],
tokenize=True, tokenize=True,
add_generation_prompt=True, add_generation_prompt=True,
return_tensors="np",
) )
else: else:
prompt = convert_chat(body["messages"], body.get("role_mapping")) prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = _tokenizer.encode(prompt, return_tensors="np") prompt = TOKENIZER.encode(prompt)
prompt = mx.array(prompt[0]) return mx.array(prompt)
stop_words = body.get("stop", [])
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
0
]
for stop_word in stop_words
]
eos_token_id = _tokenizer.eos_token_id
max_tokens = body.get("max_tokens", 100)
stream = body.get("stream", False)
requested_model = body.get("model", "default_model")
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
repetition_penalty = body.get("repetition_penalty", 1.0)
repetition_context_size = body.get("repetition_context_size", 20)
if not stream:
return self.generate_response(
prompt,
chat_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
repetition_penalty,
repetition_context_size,
create_chat_response,
)
else:
self.handle_stream(
prompt,
chat_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
repetition_penalty,
repetition_context_size,
create_chat_chunk_response,
)
def handle_completions(self, post_data: bytes): def handle_text_completions(self) -> mx.array:
body = json.loads(post_data.decode("utf-8")) """
completion_id = f"cmpl-{uuid.uuid4()}" Handle a text completion request
prompt_text = body["prompt"]
prompt = _tokenizer.encode(prompt_text, return_tensors="np") Returns:
prompt = mx.array(prompt[0]) mx.array: A mx.array of the tokenized prompt from the request body
stop_words = body.get("stop", []) """
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words # Determine response type
stop_id_sequences = [ self.request_id = f"cmpl-{uuid.uuid4()}"
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[ self.object_type = "text_completion"
0
] assert "prompt" in self.body, "Request did not contain a prompt"
for stop_word in stop_words prompt_text = self.body["prompt"]
] prompt = TOKENIZER.encode(prompt_text)
eos_token_id = _tokenizer.eos_token_id return mx.array(prompt)
max_tokens = body.get("max_tokens", 100)
stream = body.get("stream", False)
requested_model = body.get("model", "default_model")
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
repetition_penalty = body.get("repetition_penalty", 1.0)
repetition_context_size = body.get("repetition_context_size", 20)
if not stream:
return self.generate_response(
prompt,
completion_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
repetition_penalty,
repetition_context_size,
create_completion_response,
)
else:
self.handle_stream(
prompt,
completion_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
repetition_penalty,
repetition_context_size,
create_completion_chunk_response,
)
def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler): def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
@ -458,6 +434,6 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
load_model(args.model, adapter_file=args.adapter_file) MODEL, TOKENIZER = load(args.model, adapter_file=args.adapter_file)
run(args.host, args.port) run(args.host, args.port)


@ -114,7 +114,7 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
def generate_step( def generate_step(
prompt: mx.array, prompt: mx.array,
model: nn.Module, model: nn.Module,
temp: 0.0, temp: float = 0.0,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20, repetition_context_size: Optional[int] = 20,
top_p: float = 1.0, top_p: float = 1.0,
@ -128,6 +128,7 @@ def generate_step(
temp (float): The temperature for sampling, if 0 the argmax is used. temp (float): The temperature for sampling, if 0 the argmax is used.
repetition_penalty (float, optional): The penalty factor for repeating tokens. repetition_penalty (float, optional): The penalty factor for repeating tokens.
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20). repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
top_p (float, optional): Nucleus sampling, higher means model considers more less likely words
Yields: Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing Generator[Tuple[mx.array, mx.array]]: A generator producing
@ -205,7 +206,7 @@ def generate(
temp: float = 0.0, temp: float = 0.0,
max_tokens: int = 100, max_tokens: int = 100,
verbose: bool = False, verbose: bool = False,
formatter: Callable = None, formatter: Optional[Callable] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None, repetition_context_size: Optional[int] = None,
top_p: float = 1.0, top_p: float = 1.0,
@ -357,14 +358,14 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
def load( def load(
path_or_hf_repo: str, path_or_hf_repo: str,
tokenizer_config={}, tokenizer_config={},
adapter_file: str = None, adapter_file: Optional[str] = None,
lazy: bool = False, lazy: bool = False,
) -> Tuple[nn.Module, PreTrainedTokenizer]: ) -> Tuple[nn.Module, PreTrainedTokenizer]:
""" """
Load the model and tokenizer from a given path or a huggingface repository. Load the model and tokenizer from a given path or a huggingface repository.
Args: Args:
model_path (Path): The path or the huggingface repository to load the model from. path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
Defaults to an empty dictionary. Defaults to an empty dictionary.
adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model. adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.