Refactoring of mlx_lm example (#501)

* Use NamedTuple from typing for type hints

* Add type hints

* Simplify expression

* Type hint fix

* Improved do_POST logic

Use a map of endpoints to methods to reduce redundancy in code
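A minimal sketch of the idea, assuming a BaseHTTPRequestHandler subclass; the endpoint paths and handler names are illustrative, not necessarily the ones used in server.py:

```python
import json
from http.server import BaseHTTPRequestHandler


class APIHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # One map from endpoint to handler replaces duplicated if/elif
        # branches that each repeated the request/response boilerplate.
        endpoints = {
            "/v1/completions": self.handle_text_completions,
            "/v1/chat/completions": self.handle_chat_completions,
        }

        if self.path not in endpoints:
            self.send_error(404, "Not Found")
            return

        # Parse the body once, then dispatch to the mapped handler.
        content_length = int(self.headers["Content-Length"])
        self.body = json.loads(self.rfile.read(content_length))
        endpoints[self.path]()

    def handle_text_completions(self):
        ...

    def handle_chat_completions(self):
        ...
```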

* Fix format

* Reduce redundancy

Call method dynamically instead of writing out all arguments twice

* Send response instead of returning

* Fix typo

* Revert change

* Mark adapter_file as Optional

* Mark formatter as optional

* format

* Create message generator

Store response data that stays static for the duration of the response inside the object:

system_fingerprint
request_id
object_type
requested_model

Created a message generator that dynamically creates messages from the metadata stored inside the object and the data from the model pipeline
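A rough sketch of the pattern with made-up names (the actual attribute and method names in server.py may differ): the static fields are set once, and a generator method merges them with the per-chunk data from the model pipeline.

```python
import time
import uuid
from typing import Optional


class ResponseMetadata:
    def __init__(self, requested_model: str, object_type: str, system_fingerprint: str):
        # Everything that stays constant for the whole response is stored once.
        self.requested_model = requested_model
        self.object_type = object_type
        self.system_fingerprint = system_fingerprint
        self.request_id = f"cmpl-{uuid.uuid4()}"
        self.created = int(time.time())

    def generate_message(self, text: str, finish_reason: Optional[str] = None) -> dict:
        # Only `text` and `finish_reason` come from the model pipeline;
        # the rest is read from the stored metadata.
        return {
            "id": self.request_id,
            "object": self.object_type,
            "model": self.requested_model,
            "system_fingerprint": self.system_fingerprint,
            "created": self.created,
            "choices": [{"index": 0, "text": text, "finish_reason": finish_reason}],
        }
```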

* Remove leftover

* Update parameters to reflect new object structure

No longer pass all arguments between functions, but use the values stored inside the object

* Parse body before calling request-specific methods

* Call super init

* Update server.py

* Fixed outdated documentation parameter name

* Add documentation

* Fix sending headers twice

During testing I found that when using the streaming option, headers were always sent twice. This should fix that.

* Simplify streaming code by using guard clauses

Don't wrap wfile writes in try blocks; the server class has its own try block to prevent crashing
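Illustratively (function and variable names here are assumptions, not the server's API), the streaming loop reduces to guard clauses plus a bare write:

```python
import json
from typing import Callable, IO, Iterable


def stream_chunks(wfile: IO[bytes], chunks: Iterable[str], build_message: Callable[[str], dict]) -> None:
    for text in chunks:
        # Guard clause: skip empty chunks instead of nesting the write logic.
        if not text:
            continue

        # No try/except around the write: the HTTP server class already wraps
        # request handling in its own try block, so a dropped connection
        # cannot crash the server.
        wfile.write(f"data: {json.dumps(build_message(text))}\n\n".encode())
        wfile.flush()
```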

* Bug fix

* Use Content-Length header

Let the completion-type-specific methods finish sending the headers. This allows us to send the Content-Length header once the model has returned a completion.
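A sketch of the non-streaming path under that split, assuming `handler` is the BaseHTTPRequestHandler instance and the status line plus common headers were already sent by the shared code:

```python
import json
from http.server import BaseHTTPRequestHandler


def send_completion(handler: BaseHTTPRequestHandler, response: dict) -> None:
    # The body is only known once the model has finished, so the
    # completion-specific method is the one that can set Content-Length
    # and close the headers.
    body = json.dumps(response).encode()
    handler.send_header("Content-Type", "application/json")
    handler.send_header("Content-Length", str(len(body)))
    handler.end_headers()
    handler.wfile.write(body)
```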

* Update utils.py

* Add top_p documentation

* Type hint model and tokenizer as required

* Use static system fingerprint

System fingerprint now stays the same across requests
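One way to get that behaviour, assuming the fingerprint only needs to be stable for the lifetime of the server process (the constant name is illustrative):

```python
import uuid

# Generated once at import time, so every request served by this process
# reports the same system_fingerprint.
SYSTEM_FINGERPRINT = f"fp_{uuid.uuid4()}"
```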

* Make type hint more specific

* Bug Fix

Supplying fewer than two models to merge would raise a ValueError that calls len on the unbound name "models". It should be "model_paths" instead.

Mark upload_repo as optional
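The shape of the fix, roughly (the function wrapper is hypothetical; only the length check and the name it uses follow the description above):

```python
from typing import List


def check_model_paths(model_paths: List[str]) -> None:
    # Use the list that is actually bound at this point; the old message
    # referenced the not-yet-defined name `models`.
    if len(model_paths) < 2:
        raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
```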

* Move more of the shared code into do_POST

Processing stop_id_sequences is done regardless of the request endpoint or type, so move it into the shared section. The handle_ methods now just return the prompt as an mx.array.
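Schematically, with the request body and tokenizer assumed to already be available on the handler (set up by the shared do_POST code), each endpoint handler now only builds the prompt:

```python
import mlx.core as mx
from http.server import BaseHTTPRequestHandler


class APIHandler(BaseHTTPRequestHandler):
    def handle_text_completions(self) -> mx.array:
        # Stop sequences, streaming, and response writing are handled by the
        # shared caller; this method only turns the request into a prompt.
        return mx.array(self.tokenizer.encode(self.body["prompt"]))

    def handle_chat_completions(self) -> mx.array:
        prompt = self.tokenizer.apply_chat_template(
            self.body["messages"], add_generation_prompt=True
        )
        return mx.array(prompt)
```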

* Store stop_id_sequences as lists instead of np

During testing I found that letting the tokenizer return values as Python lists and converting them to mlx arrays was around 20% faster than having the tokenizer convert them to numpy and then converting from numpy to mlx. This also means numpy no longer needs to be imported.
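A sketch of the faster path, assuming a Hugging Face tokenizer; the ~20% figure is from the author's testing, not reproduced here:

```python
from typing import List

import mlx.core as mx
from transformers import PreTrainedTokenizer


def encode_stop_words(tokenizer: PreTrainedTokenizer, stop_words: List[str]) -> List[List[int]]:
    # Keep the token ids as plain Python lists; no numpy round trip needed.
    return [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]


# When an mlx array is needed, it is built straight from the list:
#     mx.array(stop_id_sequence)
# instead of going tokenizer -> numpy -> mlx.
```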

* Update stop_id_sequences docs

* Make the if check non-inclusive

Only continue if the buffer is smaller

* Documentation fix

* Clearer method names

Instead of handle_stream and generate_completion, we should name it handle_completion.

Instead of handle_completions alongside handle_chat_completions, we should name it handle_text_completions; since both are completions, calling it text completions makes it more descriptive.

* Make comment clearer

* fix format

* format
Y4hL
2024-03-06 16:24:31 +02:00
committed by GitHub
parent 710c552731
commit b8e5eda4fd
4 changed files with 280 additions and 303 deletions


@@ -114,7 +114,7 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
-    temp: 0.0,
+    temp: float = 0.0,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
     top_p: float = 1.0,
@@ -128,6 +128,7 @@ def generate_step(
         temp (float): The temperature for sampling, if 0 the argmax is used.
         repetition_penalty (float, optional): The penalty factor for repeating tokens.
         repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
+        top_p (float, optional): Nulceus sampling, higher means model considers more less likely words
     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
@@ -205,7 +206,7 @@ def generate(
     temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
-    formatter: Callable = None,
+    formatter: Optional[Callable] = None,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = None,
     top_p: float = 1.0,
@@ -357,14 +358,14 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
-    adapter_file: str = None,
+    adapter_file: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
     Args:
-        model_path (Path): The path or the huggingface repository to load the model from.
+        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
         adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.