From b8e5eda4fdbddfa8f2fdf740b7defbb56f58db1f Mon Sep 17 00:00:00 2001
From: Y4hL <43219534+Y4hL@users.noreply.github.com>
Date: Wed, 6 Mar 2024 16:24:31 +0200
Subject: [PATCH] Refactoring of mlx_lm example (#501)

* Use named tuple from typing for type hints

* Add type hints

* Simplify expression

* Type hint fix

* Improved do_POST logic

  Use a map of endpoints to methods to reduce redundancy in the code.

* Fix format

* Reduce redundancy

  Call the method dynamically instead of writing out all arguments twice.

* Send response instead of returning

* Fix typo

* Revert change

* Mark adapter_file as Optional

* Mark formatter as optional

* format

* Create message generator

  Store response data that stays static for the duration of the response
  inside of the object: system_fingerprint, request_id, object_type and
  requested_model.

  Created a message generator that dynamically creates messages from the
  metadata stored inside of the object and the data from the model pipeline.

* Remove leftover

* Update parameters to reflect new object structure

  No longer pass all arguments between functions, but use the stored values
  inside of the object.

* Parse body before calling request specific methods

* Call super init

* Update server.py

* Fixed outdated documentation parameter name

* Add documentation

* Fix sending headers twice

  During testing I found that when using the streaming option, headers were
  always sent twice. This should fix that.

* Simplify streaming code by using guard clauses

  Don't wrap wfile writes in try blocks; the server class has its own try
  block to prevent crashing.

* Bug fix

* Use Content-Length header

  Let the completion type specific methods finish sending the headers. This
  allows us to send the Content-Length header as the model returns a
  completion.

* Update utils.py

* Add top_p documentation

* Type hint model and tokenizer as required

* Use static system fingerprint

  The system fingerprint now stays the same across requests.

* Make type hint more specific

* Bug fix

  Supplying fewer than 2 models to merge would raise a ValueError whose
  message calls len() on the unbound name "models"; it should be
  "model_paths" instead. Also mark upload_repo as optional.

* Move more of the shared code into do_POST

  Processing stop_id_sequences is done no matter the request endpoint or
  type, so move it into the shared section. handle_ methods now just return
  the prompt in mx.array form.

* Store stop_id_sequences as lists instead of np

  During testing I found that letting the tokenizer return values as Python
  lists and converting them to mlx arrays was around 20% faster than having
  the tokenizer convert them to np, and from np to mlx. This also means
  numpy no longer needs to be imported.

* Update stop_id_sequences docs

* Make the if check non-inclusive

  Only continue if the buffer is smaller.

* Documentation fix

* Clearer method names

  Instead of handle_stream and generate_completion, we should name it
  handle_completion. Instead of handle_completions and
  handle_chat_completions, we should name it handle_text_completions;
  since both endpoints are completions, calling it text completions is
  more descriptive.

* Make comment clearer

* fix format

* format
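The first bullet above is about moving from collections.namedtuple to typing.NamedTuple. As a rough illustration of the difference (a sketch, not taken verbatim from the diff below):

    from collections import namedtuple
    from typing import NamedTuple

    # Before: field names only, no type information for checkers
    StopConditionUntyped = namedtuple("StopCondition", ["stop_met", "trim_length"])

    # After: fields carry annotations and are visible to static analysis
    class StopCondition(NamedTuple):
        stop_met: bool
        trim_length: int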
---
 ACKNOWLEDGMENTS.md    |   1 -
 llms/mlx_lm/merge.py  |   9 +-
 llms/mlx_lm/server.py | 564 ++++++++++++++++++++----------------------
 llms/mlx_lm/utils.py  |   9 +-
 4 files changed, 280 insertions(+), 303 deletions(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index f9528f38..8d8557ca 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -12,5 +12,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Shunta Saito: Added support for PLaMo models.
 - Gabrijel Boduljak: Implemented `CLIP`.
 - Markus Enzweiler: Added the `cvae` examples.
-- Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example
 - Prince Canuma: Helped add support for `Starcoder2` models.
diff --git a/llms/mlx_lm/merge.py b/llms/mlx_lm/merge.py
index 46fb87a8..affd034c 100644
--- a/llms/mlx_lm/merge.py
+++ b/llms/mlx_lm/merge.py
@@ -5,6 +5,7 @@ import glob
 import json
 import shutil
 from pathlib import Path
+from typing import Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -109,7 +110,7 @@ def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
 def merge(
     config: str,
     mlx_path: str = "mlx_model",
-    upload_repo: str = None,
+    upload_repo: Optional[str] = None,
 ):
     with open(config, "r") as fid:
         merge_conf = yaml.safe_load(fid)
@@ -117,7 +118,7 @@ def merge(
 
     model_paths = merge_conf.get("models", [])
     if len(model_paths) < 2:
-        raise ValueError(f"Expected at least 2 models, got {len(models)}.")
+        raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
 
     # Load all models
     base_hf_path = model_paths[0]
@@ -125,9 +126,9 @@
     base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
     models = []
     for mp in model_paths[1:]:
-        model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
+        model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
         base_type = base_config["model_type"]
-        model_type = config["model_type"]
+        model_type = model_config["model_type"]
         if base_type != model_type:
             raise ValueError(
                 f"Can only merge models of the same type,"
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 561438c7..aa395fef 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -5,43 +5,39 @@ import json
 import time
 import uuid
 import warnings
-from collections import namedtuple
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Callable, List, Optional
+from typing import List, Literal, NamedTuple, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
 from transformers import PreTrainedTokenizer
 
 from .utils import generate_step, load
 
-_model: Optional[nn.Module] = None
-_tokenizer: Optional[PreTrainedTokenizer] = None
+MODEL: nn.Module
+TOKENIZER: PreTrainedTokenizer
+
+SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
 
-def load_model(model_path: str, adapter_file: Optional[str] = None):
-    global _model
-    global _tokenizer
-    _model, _tokenizer = load(model_path, adapter_file=adapter_file)
-
-
-StopCondition = namedtuple("StopCondition", ["stop_met", "trim_length"])
+class StopCondition(NamedTuple):
+    stop_met: bool
+    trim_length: int
 
 
 def stopping_criteria(
     tokens: List[int],
-    stop_id_sequences: List[np.ndarray],
-    eos_token_id: int,
+    stop_id_sequences: List[List[int]],
+    eos_token_id: Union[int, None],
 ) -> StopCondition:
     """
     Determines whether the token generation should stop based on predefined conditions.
 
     Args:
         tokens (List[int]): The current sequence of generated tokens.
-        stop_id_sequences (List[np.ndarray]): A list of numpy arrays, each representing a sequence of token IDs.
+        stop_id_sequences (List[List[int]]): A list of integer lists, each representing a sequence of token IDs.
            If the end of the `tokens` list matches any of these sequences, the generation should stop.
-        eos_token_id (int): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
+        eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
            the generation should stop.
 
     Returns:
@@ -53,13 +49,13 @@ def stopping_criteria(
 
     for stop_ids in stop_id_sequences:
         if len(tokens) >= len(stop_ids):
-            if np.array_equal(tokens[-len(stop_ids) :], stop_ids):
+            if tokens[-len(stop_ids) :] == stop_ids:
                 return StopCondition(stop_met=True, trim_length=len(stop_ids))
 
     return StopCondition(stop_met=False, trim_length=0)
 
 
-def convert_chat(messages: any, role_mapping: Optional[dict] = None):
+def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     default_role_mapping = {
         "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
         "system": "ASSISTANT's RULE: ",
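For reference, the list-based stopping_criteria introduced above can be exercised on its own. A small sketch with made-up token ids:

    # Illustrative values; in the server these come from the tokenizer
    tokens = [17, 42, 7, 8, 9]
    stop_id_sequences = [[8, 9], [99]]

    condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id=2)
    if condition.stop_met and condition.trim_length:
        tokens = tokens[: -condition.trim_length]  # drop the matched stop sequence
    # tokens is now [17, 42, 7]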
@@ -80,344 +76,324 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
-def create_chat_response(chat_id, requested_model, prompt, tokens, text):
-    response = {
-        "id": chat_id,
-        "object": "chat.completion",
-        "created": int(time.time()),
-        "model": requested_model,
-        "system_fingerprint": f"fp_{uuid.uuid4()}",
-        "choices": [
-            {
-                "index": 0,
-                "message": {
-                    "role": "assistant",
-                    "content": text,
-                },
-                "logprobs": None,
-                "finish_reason": None,
-            }
-        ],
-        "usage": {
-            "prompt_tokens": len(prompt),
-            "completion_tokens": len(tokens),
-            "total_tokens": len(prompt) + len(tokens),
-        },
-    }
-
-    return response
-
-
-def create_completion_response(completion_id, requested_model, prompt, tokens, text):
-    return {
-        "id": completion_id,
-        "object": "text_completion",
-        "created": int(time.time()),
-        "model": requested_model,
-        "system_fingerprint": f"fp_{uuid.uuid4()}",
-        "choices": [
-            {"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
-        ],
-        "usage": {
-            "prompt_tokens": len(prompt),
-            "completion_tokens": len(tokens),
-            "total_tokens": len(prompt) + len(tokens),
-        },
-    }
-
-
-def create_chat_chunk_response(chat_id, requested_model, next_chunk):
-    response = {
-        "id": chat_id,
-        "object": "chat.completion.chunk",
-        "created": int(time.time()),
-        "model": requested_model,
-        "system_fingerprint": f"fp_{uuid.uuid4()}",
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"role": "assistant", "content": next_chunk},
-                "logprobs": None,
-                "finish_reason": None,
-            }
-        ],
-    }
-    return response
-
-
-def create_completion_chunk_response(completion_id, requested_model, next_chunk):
-    return {
-        "id": completion_id,
-        "object": "text_completion",
-        "created": int(time.time()),
-        "choices": [
-            {"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
-        ],
-        "model": requested_model,
-        "system_fingerprint": f"fp_{uuid.uuid4()}",
-    }
-
-
 class APIHandler(BaseHTTPRequestHandler):
+    def __init__(self, *args, **kwargs):
+        """
+        Create static request specific metadata
+        """
+        self.created = int(time.time())
+        super().__init__(*args, **kwargs)
+
-    def _set_headers(self, status_code=200):
+    def _set_completion_headers(self, status_code: int = 200):
         self.send_response(status_code)
         self.send_header("Content-type", "application/json")
         self.send_header("Access-Control-Allow-Origin", "*")
         self.send_header("Access-Control-Allow-Methods", "*")
         self.send_header("Access-Control-Allow-Headers", "*")
-        self.end_headers()
+
+    def _set_stream_headers(self, status_code: int = 200):
+        self.send_response(status_code)
+        self.send_header("Content-type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
 
     def do_OPTIONS(self):
-        self._set_headers(204)
+        self._set_completion_headers(204)
+        self.end_headers()
 
     def do_POST(self):
-        if self.path == "/v1/chat/completions":
-            content_length = int(self.headers["Content-Length"])
-            post_data = self.rfile.read(content_length)
-            self._set_headers(200)
+        """
+        Respond to a POST request from a client
+        """
+        endpoints = {
+            "/v1/completions": self.handle_text_completions,
+            "/v1/chat/completions": self.handle_chat_completions,
+        }
 
-            response = self.handle_chat_completions(post_data)
-
-            self.wfile.write(json.dumps(response).encode("utf-8"))
-        elif self.path == "/v1/completions":
-            content_length = int(self.headers["Content-Length"])
-            post_data = self.rfile.read(content_length)
-            self._set_headers(200)
-
-            response = self.handle_completions(post_data)
-
-            self.wfile.write(json.dumps(response).encode("utf-8"))
-        else:
-            self._set_headers(404)
+        if self.path not in endpoints:
+            self._set_completion_headers(404)
+            self.end_headers()
             self.wfile.write(b"Not Found")
+            return
+
+        # Fetch and parse request body
+        content_length = int(self.headers["Content-Length"])
+        raw_body = self.rfile.read(content_length)
+        self.body = json.loads(raw_body.decode())
+        assert isinstance(
+            self.body, dict
+        ), f"Request should be dict, but got {type(self.body)}"
+
+        # Extract request parameters from the body
+        self.stream = self.body.get("stream", False)
+        self.requested_model = self.body.get("model", "default_model")
+        self.max_tokens = self.body.get("max_tokens", 100)
+        self.temperature = self.body.get("temperature", 1.0)
+        self.top_p = self.body.get("top_p", 1.0)
+        self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
+        self.repetition_context_size = self.body.get("repetition_context_size", 20)
+
+        # Get stop id sequences, if provided
+        stop_words = self.body.get("stop", [])
+        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
+        stop_id_sequences = [
+            TOKENIZER.encode(stop_word, add_special_tokens=False)
+            for stop_word in stop_words
+        ]
+
+        # Send header type
+        (
+            self._set_stream_headers(200)
+            if self.stream
+            else self._set_completion_headers(200)
+        )
+
+        # Call endpoint specific method
+        prompt = endpoints[self.path]()
+
+        # Call method based on response type
+        method = self.handle_stream if self.stream else self.handle_completion
+        method(prompt, stop_id_sequences)
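To exercise the new do_POST flow end to end, a request can be sent with any HTTP client. A minimal sketch using only the standard library; the address is an assumption, so use whatever --host/--port the server was started with:

    import json
    from urllib import request

    url = "http://localhost:8080/v1/chat/completions"  # assumed address
    payload = {
        "model": "default_model",  # echoed back in the response
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 50,
        "temperature": 0.7,
        "top_p": 0.9,
        "stop": ["</s>"],  # becomes stop_id_sequences on the server
        "stream": False,
    }
    req = request.Request(
        url,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        print(json.loads(resp.read()))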
populate the "usage" field (not used when stream) + completion_token_count (Optional[int]): + The amount of tokens in the response, + used to populate the "usage" field (not used when stream) + + Returns: + dict: A dictionary containing the response, imitating OpenAI's API + """ + + # Static response + response = { + "id": self.request_id, + "system_fingerprint": SYSTEM_FINGERPRINT, + "object": self.object_type, + "model": self.requested_model, + "created": self.created, + "choices": [ + { + "index": 0, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + } + + if not self.stream: + if not ( + isinstance(prompt_token_count, int) + and isinstance(completion_token_count, int) + ): + raise ValueError( + "Response type is complete, but token counts not provided" + ) + + response["usage"] = { + "prompt_tokens": prompt_token_count, + "completion_tokens": completion_token_count, + "total_tokens": prompt_token_count + completion_token_count, + } + + choice = response["choices"][0] + + # Add dynamic response + if self.object_type.startswith("chat.completion"): + key_name = "delta" if self.stream else "message" + choice[key_name] = {"role": "assistant", "content": text} + elif self.object_type == "text_completion": + choice.update(text=text) + else: + ValueError(f"Unsupported response type: {self.object_type}") + + return response + + def handle_completion( self, prompt: mx.array, - response_id: str, - requested_model: str, - stop_id_sequences: List[np.ndarray], - eos_token_id: int, - max_tokens: int, - temperature: float, - top_p: float, - repetition_penalty: Optional[float], - repetition_context_size: Optional[int], - response_creator: Callable[[str, str, mx.array, List[int], str], dict], + stop_id_sequences: List[List[int]], ): + """ + Generate a response to a prompt and send it to the client in a single batch + + Args: + prompt (mx.array): The prompt, in token form inside of a mlx array + stop_id_sequences (List[List[int]]): + A list of stop words passed to the stopping_criteria function + """ tokens = [] for (token, _), _ in zip( generate_step( prompt=prompt, - model=_model, - temp=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - repetition_context_size=repetition_context_size, + model=MODEL, + temp=self.temperature, + top_p=self.top_p, + repetition_penalty=self.repetition_penalty, + repetition_context_size=self.repetition_context_size, ), - range(max_tokens), + range(self.max_tokens), ): token = token.item() tokens.append(token) - stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id) + stop_condition = stopping_criteria( + tokens, stop_id_sequences, TOKENIZER.eos_token_id + ) if stop_condition.stop_met: if stop_condition.trim_length: tokens = tokens[: -stop_condition.trim_length] break - text = _tokenizer.decode(tokens) - return response_creator(response_id, requested_model, prompt, tokens, text) + text = TOKENIZER.decode(tokens) + response = self.generate_response(text, "stop", len(prompt), len(tokens)) + + response_json = json.dumps(response).encode() + + # Send an additional Content-Length header when it is known + self.send_header("Content-Length", str(len(response_json))) + self.end_headers() + + self.wfile.write(response_json) + self.wfile.flush() def handle_stream( self, prompt: mx.array, - response_id: str, - requested_model: str, - stop_id_sequences: List[np.ndarray], - eos_token_id: int, - max_tokens: int, - temperature: float, - top_p: float, - repetition_penalty: Optional[float], - repetition_context_size: Optional[int], - 
 
     def handle_stream(
         self,
         prompt: mx.array,
-        response_id: str,
-        requested_model: str,
-        stop_id_sequences: List[np.ndarray],
-        eos_token_id: int,
-        max_tokens: int,
-        temperature: float,
-        top_p: float,
-        repetition_penalty: Optional[float],
-        repetition_context_size: Optional[int],
-        response_creator: Callable[[str, str, str], dict],
+        stop_id_sequences: List[List[int]],
     ):
-        self.send_response(200)
-        self.send_header("Content-type", "text/event-stream")
-        self.send_header("Cache-Control", "no-cache")
+        """
+        Generate a response to the prompt and forward it to the client using a
+        Server-Sent Events (SSE) stream
+
+        Args:
+            prompt (mx.array): The prompt, in token form inside of an mlx array
+            stop_id_sequences (List[List[int]]):
+                A list of stop word id sequences, passed to the stopping_criteria function
+        """
+        # No additional headers are needed, call end_headers
         self.end_headers()
-        max_stop_id_sequence_len = (
-            max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
-        )
+
         tokens = []
         current_generated_text_index = 0
-        # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
+
+        max_stop_id_sequence_len = max((len(seq) for seq in stop_id_sequences), default=0)
+        # Buffer to store the last `max_stop_id_sequence_len` tokens
+        # to check for stop conditions before writing to the stream.
         stop_sequence_buffer = []
-        REPLACEMENT_CHAR = "\ufffd"
+
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=_model,
-                temp=temperature,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-                repetition_context_size=repetition_context_size,
+                model=MODEL,
+                temp=self.temperature,
+                top_p=self.top_p,
+                repetition_penalty=self.repetition_penalty,
+                repetition_context_size=self.repetition_context_size,
            ),
-            range(max_tokens),
+            range(self.max_tokens),
         ):
             token = token.item()
             tokens.append(token)
             stop_sequence_buffer.append(token)
-            if len(stop_sequence_buffer) > max_stop_id_sequence_len:
-                if REPLACEMENT_CHAR in _tokenizer.decode(token):
-                    continue
-                stop_condition = stopping_criteria(
-                    tokens,
-                    stop_id_sequences,
-                    eos_token_id,
-                )
-                if stop_condition.stop_met:
-                    if stop_condition.trim_length:
-                        tokens = tokens[: -stop_condition.trim_length]
-                    break
-                # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
-                generated_text = _tokenizer.decode(tokens)
-                next_chunk = generated_text[current_generated_text_index:]
-                current_generated_text_index = len(generated_text)
-                response = response_creator(response_id, requested_model, next_chunk)
-                try:
-                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                    self.wfile.flush()
-                    stop_sequence_buffer = []
-                except Exception as e:
-                    print(e)
-                    break
+
+            # Continue generating tokens until buffer is as large as the longest stop_id_sequence
+            if len(stop_sequence_buffer) < max_stop_id_sequence_len:
+                continue
+
+            # "\ufffd" indicates that the token does not decode to a complete unicode
+            # character on its own and must be combined with subsequent tokens
+            if "\ufffd" in TOKENIZER.decode(token):
+                continue
+
+            stop_condition = stopping_criteria(
+                tokens,
+                stop_id_sequences,
+                TOKENIZER.eos_token_id,
+            )
+            if stop_condition.stop_met:
+                if stop_condition.trim_length:
+                    tokens = tokens[: -stop_condition.trim_length]
+                break
+
+            # Workaround for llama tokenizer emitting spaces when decoding token by token.
+            generated_text = TOKENIZER.decode(tokens)
+            new_text = generated_text[current_generated_text_index:]
+            current_generated_text_index = len(generated_text)
+
+            response = self.generate_response(new_text, None)
+            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+            self.wfile.flush()
+            stop_sequence_buffer = []
+
         # check is there any remaining text to send
         if stop_sequence_buffer:
-            generated_text = _tokenizer.decode(tokens)
+            generated_text = TOKENIZER.decode(tokens)
             next_chunk = generated_text[current_generated_text_index:]
-            response = response_creator(response_id, requested_model, next_chunk)
-            try:
-                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                self.wfile.flush()
-            except Exception as e:
-                print(e)
+            response = self.generate_response(next_chunk, "length")
 
-        self.wfile.write(f"data: [DONE]\n\n".encode())
+            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+            self.wfile.flush()
+
+        self.wfile.write("data: [DONE]\n\n".encode())
         self.wfile.flush()
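When "stream": true is requested, handle_stream above emits Server-Sent Events. A minimal sketch of a client reading the stream (standard library only; the address is again an assumption):

    import json
    from urllib import request

    payload = {
        "model": "default_model",
        "messages": [{"role": "user", "content": "Tell me a joke."}],
        "max_tokens": 100,
        "stream": True,
    }
    req = request.Request(
        "http://localhost:8080/v1/chat/completions",  # assumed address
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        for raw_line in resp:  # the response body is iterable line by line
            line = raw_line.decode().strip()
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            chunk = json.loads(data)
            print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)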
 
-    def handle_chat_completions(self, post_data: bytes):
-        body = json.loads(post_data.decode("utf-8"))
-        chat_id = f"chatcmpl-{uuid.uuid4()}"
-        if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
-            prompt = _tokenizer.apply_chat_template(
+    def handle_chat_completions(self) -> mx.array:
+        """
+        Handle a chat completion request
+
+        Returns:
+            mx.array: An mx.array of the tokenized prompt from the request body
+        """
+        body = self.body
+        assert "messages" in body, "Request did not contain messages"
+
+        # Determine response type
+        self.request_id = f"chatcmpl-{uuid.uuid4()}"
+        self.object_type = (
+            "chat.completions.chunk" if self.stream else "chat.completions"
+        )
+
+        if hasattr(TOKENIZER, "apply_chat_template") and TOKENIZER.chat_template:
+            prompt = TOKENIZER.apply_chat_template(
                 body["messages"],
                 tokenize=True,
                 add_generation_prompt=True,
-                return_tensors="np",
             )
         else:
             prompt = convert_chat(body["messages"], body.get("role_mapping"))
-            prompt = _tokenizer.encode(prompt, return_tensors="np")
+            prompt = TOKENIZER.encode(prompt)
 
-        prompt = mx.array(prompt[0])
-        stop_words = body.get("stop", [])
-        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
-        stop_id_sequences = [
-            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
-                0
-            ]
-            for stop_word in stop_words
-        ]
-        eos_token_id = _tokenizer.eos_token_id
-        max_tokens = body.get("max_tokens", 100)
-        stream = body.get("stream", False)
-        requested_model = body.get("model", "default_model")
-        temperature = body.get("temperature", 1.0)
-        top_p = body.get("top_p", 1.0)
-        repetition_penalty = body.get("repetition_penalty", 1.0)
-        repetition_context_size = body.get("repetition_context_size", 20)
-        if not stream:
-            return self.generate_response(
-                prompt,
-                chat_id,
-                requested_model,
-                stop_id_sequences,
-                eos_token_id,
-                max_tokens,
-                temperature,
-                top_p,
-                repetition_penalty,
-                repetition_context_size,
-                create_chat_response,
-            )
-        else:
-            self.handle_stream(
-                prompt,
-                chat_id,
-                requested_model,
-                stop_id_sequences,
-                eos_token_id,
-                max_tokens,
-                temperature,
-                top_p,
-                repetition_penalty,
-                repetition_context_size,
-                create_chat_chunk_response,
-            )
+        return mx.array(prompt)
 
-    def handle_completions(self, post_data: bytes):
-        body = json.loads(post_data.decode("utf-8"))
-        completion_id = f"cmpl-{uuid.uuid4()}"
-        prompt_text = body["prompt"]
-        prompt = _tokenizer.encode(prompt_text, return_tensors="np")
-        prompt = mx.array(prompt[0])
-        stop_words = body.get("stop", [])
-        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
-        stop_id_sequences = [
-            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
-                0
-            ]
-            for stop_word in stop_words
-        ]
-        eos_token_id = _tokenizer.eos_token_id
-        max_tokens = body.get("max_tokens", 100)
-        stream = body.get("stream", False)
-        requested_model = body.get("model", "default_model")
-        temperature = body.get("temperature", 1.0)
-        top_p = body.get("top_p", 1.0)
-        repetition_penalty = body.get("repetition_penalty", 1.0)
-        repetition_context_size = body.get("repetition_context_size", 20)
-        if not stream:
-            return self.generate_response(
-                prompt,
-                completion_id,
-                requested_model,
-                stop_id_sequences,
-                eos_token_id,
-                max_tokens,
-                temperature,
-                top_p,
-                repetition_penalty,
-                repetition_context_size,
-                create_completion_response,
-            )
-        else:
-            self.handle_stream(
-                prompt,
-                completion_id,
-                requested_model,
-                stop_id_sequences,
-                eos_token_id,
-                max_tokens,
-                temperature,
-                top_p,
-                repetition_penalty,
-                repetition_context_size,
-                create_completion_chunk_response,
-            )
+    def handle_text_completions(self) -> mx.array:
+        """
+        Handle a text completion request
+
+        Returns:
+            mx.array: An mx.array of the tokenized prompt from the request body
+        """
+        # Determine response type
+        self.request_id = f"cmpl-{uuid.uuid4()}"
+        self.object_type = "text_completion"
+
+        assert "prompt" in self.body, "Request did not contain a prompt"
+        prompt_text = self.body["prompt"]
+        prompt = TOKENIZER.encode(prompt_text)
+        return mx.array(prompt)
 
 
 def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
@@ -458,6 +434,6 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    load_model(args.model, adapter_file=args.adapter_file)
+    MODEL, TOKENIZER = load(args.model, adapter_file=args.adapter_file)
 
     run(args.host, args.port)
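As a usage note for the utils.py hunks below (top_p documentation, optional adapter_file and formatter), load and generate_step can also be driven directly, mirroring the server's own loop. A sketch; the model path is a placeholder and the import assumes the installed mlx_lm package layout:

    import mlx.core as mx
    from mlx_lm.utils import generate_step, load

    model, tokenizer = load("path/to/mlx_model")  # placeholder path
    prompt = mx.array(tokenizer.encode("Write a haiku about the sea."))

    tokens = []
    for (token, _), _ in zip(
        generate_step(prompt=prompt, model=model, temp=0.7, top_p=0.9),
        range(100),  # max_tokens
    ):
        token = token.item()
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token)

    print(tokenizer.decode(tokens))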
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index ec23fbd6..6adcf924 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -114,7 +114,7 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
-    temp: 0.0,
+    temp: float = 0.0,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
     top_p: float = 1.0,
@@ -128,6 +128,7 @@
         temp (float): The temperature for sampling, if 0 the argmax is used.
         repetition_penalty (float, optional): The penalty factor for repeating tokens.
         repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
+        top_p (float, optional): Nucleus sampling: higher values allow the model to consider more of the less likely words.
 
     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
@@ -205,7 +206,7 @@ def generate(
     temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
-    formatter: Callable = None,
+    formatter: Optional[Callable] = None,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = None,
     top_p: float = 1.0,
@@ -357,14 +358,14 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
-    adapter_file: str = None,
+    adapter_file: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
 
     Args:
-        model_path (Path): The path or the huggingface repository to load the model from.
+        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
         adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.