Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-15 01:42:31 +08:00)
Refactoring of mlx_lm example (#501)
* Use named tuple from typing for type hints
* Add type hints
* Simplify expression
* Type hint fix
* Improved do_POST logic: use a map of endpoints to methods to reduce redundancy in the code (see the routing sketch after this list)
* Fix format
* Reduce redundancy: call the method dynamically instead of writing out all arguments twice
* Send response instead of returning
* Fix typo
* Revert change
* Make adapter_file Optional
* Mark formatter as optional
* Format
* Create message generator: store the response data that stays static for the duration of the response on the object (system_fingerprint, request_id, object_type, requested_model), and generate messages dynamically from that stored metadata plus the data coming out of the model pipeline (see the response-state sketch after this list)
* Remove leftover
* Update parameters to reflect the new object structure: no longer pass all arguments between functions, but use the values stored on the object
* Parse the body before calling request-specific methods
* Call super init
* Update server.py
* Fix outdated documentation parameter name
* Add documentation
* Fix sending headers twice: during testing I found that with the streaming option the headers were always sent twice; this should fix that
* Simplify streaming code by using guard clauses: don't wrap wfile writes in try blocks, since the server class has its own try block to prevent crashing
* Bug fix
* Use the Content-Length header: let the completion-type-specific methods finish sending the headers, which allows sending the Content-Length header once the model has returned a completion
* Update utils.py
* Add top_p documentation
* Type hint model and tokenizer as required
* Use a static system fingerprint: the system fingerprint now stays the same across requests
* Make type hint more specific
* Bug fix: supplying fewer than 2 models to merge would raise a ValueError and call len on the unbound name "models"; it should be "model_paths" instead. Also mark upload_repo as optional (see the guard-clause sketch after this list)
* Move more of the shared code into do_POST: processing stop_id_sequences happens regardless of the request endpoint or type, so move it into the shared section; handle_ methods now just return the prompt as an mx.array
* Store stop_id_sequences as lists instead of np arrays: during testing I found that letting the tokenizer return Python lists and converting those to mlx arrays was around 20% faster than having the tokenizer return np arrays and converting from np to mlx; this also means numpy no longer needs to be imported (see the tokenizer sketch after this list)
* Update stop_id_sequences docs
* Turn the if check non-inclusive: only continue if the buffer is smaller
* Documentation fix
* Clearer method names: rename handle_stream and generate_competion to handle_completion, and rename handle_completions to handle_text_completions; both endpoints produce completions, so "text completions" is the more descriptive name alongside handle_chat_completions
* Make comment clearer
* Fix format
* Format
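The endpoint-to-method map mentioned above could look roughly like the following. This is a minimal sketch, not the code from the PR: the class name, the stub handler bodies, and the 404 handling are assumptions, while the endpoint paths and the handle_text_completions / handle_chat_completions names come from the commit message.

```python
import json
from http.server import BaseHTTPRequestHandler


class APIHandler(BaseHTTPRequestHandler):
    """Sketch of routing POST requests through an endpoint-to-method map."""

    def handle_text_completions(self):
        # Assumed stub: would build the model prompt from self.body["prompt"].
        return self.body.get("prompt", "")

    def handle_chat_completions(self):
        # Assumed stub: would build the model prompt from self.body["messages"].
        return self.body.get("messages", [])

    def do_POST(self):
        # One map instead of repeating the request handling per endpoint.
        endpoints = {
            "/v1/completions": self.handle_text_completions,
            "/v1/chat/completions": self.handle_chat_completions,
        }
        if self.path not in endpoints:
            self.send_response(404)
            self.end_headers()
            return

        # Parse the body once, before calling the endpoint-specific method.
        content_length = int(self.headers["Content-Length"])
        self.body = json.loads(self.rfile.read(content_length))
        prompt = endpoints[self.path]()
        # ...the shared generation and response code would continue from here.
        self.send_response(200)
        self.end_headers()
        self.wfile.write(json.dumps({"echoed_prompt": prompt}).encode())
```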
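For the message generator and static system fingerprint items, a rough illustration of the idea follows. The field names come from the commit message; the payload shape is an assumption (loosely OpenAI-style) and is not taken from this diff.

```python
import time
import uuid

# Static fingerprint: generated once, reused across requests.
SYSTEM_FINGERPRINT = f"fp_{uuid.uuid4()}"


class ResponseState:
    """Holds the response metadata that stays fixed for one request."""

    def __init__(self, requested_model: str, object_type: str):
        self.system_fingerprint = SYSTEM_FINGERPRINT
        self.request_id = f"cmpl-{uuid.uuid4()}"
        self.object_type = object_type
        self.requested_model = requested_model

    def generate_response(self, text: str, finish_reason=None) -> dict:
        # Static metadata comes from the object; only text/finish_reason vary
        # as the model pipeline produces output.
        return {
            "id": self.request_id,
            "object": self.object_type,
            "model": self.requested_model,
            "system_fingerprint": self.system_fingerprint,
            "created": int(time.time()),
            "choices": [
                {"index": 0, "text": text, "finish_reason": finish_reason}
            ],
        }
```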
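The merge bug fix amounts to guarding on the bound parameter name before using it. A minimal guard-clause sketch; the function signature here is illustrative and the real merge logic is omitted.

```python
from typing import List, Optional


def merge(model_paths: List[str], upload_repo: Optional[str] = None) -> None:
    # Guard on model_paths (the bound parameter), not an unbound "models" name,
    # so passing fewer than two models fails with a clear error message.
    if len(model_paths) < 2:
        raise ValueError(f"Expected at least 2 models to merge, got {len(model_paths)}.")
    # ...actual merging, and an optional upload when upload_repo is set, would follow.


merge(["model_a", "model_b"])  # ok
```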
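For the stop_id_sequences change, the idea is to keep the tokenizer output as plain Python lists and convert to mx.array only where an array is needed, skipping the np detour. A minimal sketch assuming a Hugging Face tokenizer; the model id and stop strings are placeholders.

```python
import mlx.core as mx
from transformers import AutoTokenizer

# Placeholder model id; any repo or local path with a tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("mlx-community/Mistral-7B-Instruct-v0.2-4bit")

stop_words = ["</s>", "[INST]"]  # placeholder stop strings

# Keep the encoded sequences as Python lists (the tokenizer's default return type)
# instead of requesting numpy arrays and converting np -> mlx afterwards.
stop_id_sequences = [
    tokenizer.encode(stop, add_special_tokens=False) for stop in stop_words
]

# Convert to mx.array only at the point where an array is actually required.
stop_arrays = [mx.array(ids) for ids in stop_id_sequences]
```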
@@ -114,7 +114,7 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
-    temp: 0.0,
+    temp: float = 0.0,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
     top_p: float = 1.0,
@@ -128,6 +128,7 @@ def generate_step(
         temp (float): The temperature for sampling, if 0 the argmax is used.
         repetition_penalty (float, optional): The penalty factor for repeating tokens.
         repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
+        top_p (float, optional): Nulceus sampling, higher means model considers more less likely words

     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
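The added top_p docstring line refers to nucleus (top-p) sampling: tokens are sorted by probability and sampling is restricted to the smallest prefix whose cumulative probability reaches top_p. The snippet below is an illustrative sketch in mlx for a 1-D logits vector, not the implementation in this file.

```python
import mlx.core as mx


def sample_top_p(logits: mx.array, temp: float = 0.8, top_p: float = 0.9) -> mx.array:
    # temp == 0 falls back to greedy decoding (argmax), matching the docstring above.
    if temp == 0:
        return mx.argmax(logits, axis=-1)

    probs = mx.softmax(logits / temp, axis=-1)

    # Sort probabilities in descending order and accumulate them.
    sorted_idx = mx.argsort(-probs)
    sorted_probs = mx.take(probs, sorted_idx)
    cumulative = mx.cumsum(sorted_probs)

    # Keep the smallest set of tokens whose cumulative probability reaches
    # top_p; zero out everything else before sampling.
    keep = (cumulative - sorted_probs) < top_p
    filtered = mx.where(keep, sorted_probs, mx.zeros_like(sorted_probs))

    # categorical expects (unnormalized) log-probabilities; log(0) = -inf
    # removes the filtered tokens from consideration.
    choice_in_sorted = mx.random.categorical(mx.log(filtered))
    return mx.take(sorted_idx, choice_in_sorted)


# Example: with top_p = 1.0 this degenerates to plain temperature sampling.
token = sample_top_p(mx.array([2.0, 1.0, 0.5, 3.0]), temp=0.7, top_p=0.9)
```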
@@ -205,7 +206,7 @@ def generate(
     temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
-    formatter: Callable = None,
+    formatter: Optional[Callable] = None,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = None,
     top_p: float = 1.0,
@@ -357,14 +358,14 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
-    adapter_file: str = None,
+    adapter_file: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.

     Args:
-        model_path (Path): The path or the huggingface repository to load the model from.
+        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
         adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
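Given the load signature shown above, a typical call might look like the following. This is a hedged usage sketch: the repository id is only an example, and the generate call uses the parameters visible in the earlier hunk rather than this file's full API.

```python
from mlx_lm import load, generate

# The repo id below is just an example; any local path or Hugging Face repo
# with MLX-format weights is loaded the same way.
model, tokenizer = load(
    "mlx-community/Mistral-7B-Instruct-v0.2-4bit",
    tokenizer_config={"trust_remote_code": True},
    adapter_file=None,  # Optional[str]: pass a LoRA adapter path to apply adapters
)

text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about apples.",
    temp=0.0,        # 0 means argmax (greedy) decoding
    max_tokens=100,
    top_p=1.0,
)
print(text)
```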