Refactoring of mlx_lm example (#501)

* Use NamedTuple from typing for type hints

* Add type hints

* Simplify expression

* Type hint fix

* Improved do_POST logic

Use a map of endpoints to methods to reduce code redundancy (sketched below)

* Fix format

* Reduce redundancy

Call method dynamically instead of writing out all arguments twice
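
A minimal sketch of the resulting dispatch (handler names follow the rename made later in this PR; the 404 handling is illustrative):

    from http.server import BaseHTTPRequestHandler

    class APIHandler(BaseHTTPRequestHandler):
        def do_POST(self):
            # Map each endpoint to its handler method once instead of
            # duplicating the argument list at every call site.
            endpoints = {
                "/v1/completions": self.handle_text_completions,
                "/v1/chat/completions": self.handle_chat_completions,
            }
            if self.path not in endpoints:
                self.send_response(404)
                self.end_headers()
                return
            # Shared work (parsing the body, etc.) happens here, then
            # the matching method is called dynamically.
            endpoints[self.path]()

        def handle_text_completions(self): ...
        def handle_chat_completions(self): ...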

* Send response instead of returning

* Fix typo

* Revert change

* Mark adapter_file as Optional

* Mark formatter as optional

* format

* Create message generator

Store response data that stays static for the duration of the response inside the object:

system_fingerprint
request_id
object_type
requested_model

Created a message generator that dynamically creates messages from the metadata stored inside the object and the data from the model pipeline
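
A hedged sketch of how this metadata might sit on the object and feed the generator (the fingerprint and id formats, and the created timestamp, are illustrative):

    import time
    import uuid

    class ResponseState:
        """Sketch: metadata that stays static for one response."""

        def __init__(self, requested_model: str, object_type: str):
            self.system_fingerprint = f"fp_{uuid.uuid4()}"
            self.request_id = f"cmpl-{uuid.uuid4()}"
            self.object_type = object_type  # e.g. "text_completion"
            self.requested_model = requested_model

        def generate_response(self, text: str, finish_reason=None) -> dict:
            # Each message combines the stored metadata with fresh data
            # from the model pipeline (the generated text here).
            return {
                "id": self.request_id,
                "system_fingerprint": self.system_fingerprint,
                "object": self.object_type,
                "model": self.requested_model,
                "created": int(time.time()),
                "choices": [
                    {"index": 0, "text": text, "finish_reason": finish_reason}
                ],
            }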

* Remove leftover

* Update parameters to reflect new object structure

No longer pass all arguments between functions; use the values stored inside the object instead

* Parse body before calling request specific methods

* Call super init

* Update server.py

* Fixed outdated documentation parameter name

* Add documentation

* Fix sending headers twice

During testing I found that when using the streaming option, headers were always sent twice. This should fix that

* Simplify streaming code by using guard clauses

Don't wrap wfile writes in try blocks; the server class has its own try block to prevent crashing
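
Sketched with guard clauses (chunks and is_stop are hypothetical stand-ins for the generation pipeline):

    def stream(wfile, chunks, is_stop):
        # Guard clauses keep the loop flat instead of nesting conditionals.
        for chunk in chunks:
            if not chunk:       # nothing ready to send this iteration
                continue
            if is_stop(chunk):  # stop sequence reached, end the stream
                break
            # Write directly: the server class wraps each request in its
            # own try block, so a local try/except would only hide errors.
            wfile.write(chunk)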

* Bug fix

* Use Content-Length header

Let the completion-type-specific methods finish sending the headers. This allows us to send the Content-Length header once the model has returned a completion.
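
For a non-streaming response this looks roughly like the following (standard http.server calls; the helper name is hypothetical):

    import json

    def send_json_response(handler, payload: dict):
        # Generate the full completion first, then send headers: only at
        # this point is the exact Content-Length known.
        body = json.dumps(payload).encode()
        handler.send_response(200)
        handler.send_header("Content-Type", "application/json")
        handler.send_header("Content-Length", str(len(body)))
        handler.end_headers()
        handler.wfile.write(body)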

* Update utils.py

* Add top_p documentation

* Type hint model and tokenizer as required

* Use static system fingerprint

System fingerprint now stays the same across requests
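
For example (the fp_ prefix and uuid choice are illustrative):

    import uuid

    # Generated once at import time, so every request reports the same
    # fingerprint instead of a fresh value per request.
    SYSTEM_FINGERPRINT = f"fp_{uuid.uuid4()}"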

* Make type hint more specific

* Bug Fix

Supplying fewer than 2 models to merge would raise a ValueError whose message calls len on the unbound name "models"; it should be "model_paths" instead.

Mark upload_repo as optional

* Move more of the shared code into do_POST

Processing stop_id_sequences is done regardless of the request endpoint or type, so move it into the shared section. The handle_ methods now just return the prompt as an mx.array.
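
A sketch of the shared step (an HF-style tokenizer.encode and an OpenAI-style "stop" field are assumptions):

    def parse_stop_sequences(body: dict, tokenizer) -> list:
        # Runs once in do_POST for every endpoint, so the handle_
        # methods no longer repeat it.
        stop_words = body.get("stop") or []
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        return [tokenizer.encode(w, add_special_tokens=False) for w in stop_words]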

* Store stop_id_sequences as lists instead of np

During testing I found that letting the tokenizer return values as Python lists and converting them to mlx arrays was around 20% faster than having the tokenizer convert them to np and then converting from np to mlx. This also means numpy no longer needs to be imported.
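
The encoded lists can then feed mx.array directly (an HF-style encode returning a plain Python list is assumed):

    import mlx.core as mx

    def encode_stop_word(tokenizer, word: str) -> mx.array:
        # tokenizer.encode returns a plain Python list, which mx.array
        # consumes directly: no list -> np -> mlx round trip, and no
        # numpy import needed.
        return mx.array(tokenizer.encode(word, add_special_tokens=False))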

* Update stop_id_sequences docs

* Make the if check non-inclusive

Only continue if the buffer is smaller

* Documentation fix

* Clearer method names

Instead of handle_stream and generate_completion, we should name it handle_completion.

Instead of handle_completions and handle_chat_completions, we should name it handle_text_completions: since both are completions, calling it text completions makes it more descriptive

* Make comment clearer

* fix format

* format

Author: Y4hL
Date: 2024-03-06 16:24:31 +02:00
Committed by: GitHub
Parent: 710c552731
Commit: b8e5eda4fd

4 changed files with 280 additions and 303 deletions

@@ -5,6 +5,7 @@ import glob
 import json
 import shutil
 from pathlib import Path
+from typing import Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -109,7 +110,7 @@ def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
 def merge(
     config: str,
     mlx_path: str = "mlx_model",
-    upload_repo: str = None,
+    upload_repo: Optional[str] = None,
 ):
     with open(config, "r") as fid:
         merge_conf = yaml.safe_load(fid)
@@ -117,7 +118,7 @@ def merge(
 
     model_paths = merge_conf.get("models", [])
     if len(model_paths) < 2:
-        raise ValueError(f"Expected at least 2 models, got {len(models)}.")
+        raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
 
     # Load all models
     base_hf_path = model_paths[0]
@@ -125,9 +126,9 @@ def merge(
     base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
     models = []
     for mp in model_paths[1:]:
-        model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
+        model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
         base_type = base_config["model_type"]
-        model_type = config["model_type"]
+        model_type = model_config["model_type"]
         if base_type != model_type:
             raise ValueError(
                 f"Can only merge models of the same type,"