mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Refactoring of mlx_lm example (#501)
* Use named tuple from typing for typehints * Add type hints * Simplify expression * Type hint fix * Improved do_POST logic Use a map of endpoints to methods to reduce redundancy in code * Fix format * Improve redundancy Call method dynamically instead of writing out all arguments twice * Send response instead of returning * Fix typo * Revert change * Make adapter_file as Optional * Mark formatter as optional * format * Create message generator Store response data that stays static for the duration of the response inside of the object: system_fingerprint request_id object_type requested_model Created a message generator, that dynamically creates messages from the metadata stored inside of the object, and the data from the model pipeline * Remove leftover * Update parameters to reflect new object structure No longer pass all arguments between functions, but use the stores values inside of the object * Parse body before calling request specific methods * Call super init * Update server.py * Fixed outdated documentation parameter name * Add documentation * Fix sending headers twice During testing I found that when using the streaming option, headers have always been sent twice. This should fix that * Simplify streaming code by using guard clauses Don't wrap wfile writes in try blocks, the server class has its own try block to prevent crashing * Bug fix * Use Content-Length header Let the completion type specific methods finish sending the headers. This allows us to send the Content-Length header as the model returns a completion. * Update utils.py * Add top_p documentation * Type hint model and tokenizer as required * Use static system fingerprint System fingerprint now stays the same across requests * Make type hint more specific * Bug Fix Supplying less than 2 models to merge would raise ValueError and calls len on unbound "models". Should be "model_paths" instead. Mark upload_repo as optional * Move more of the shared code into do_POST Processing stop_id_sequences is done no matter the request endpoint or type, move it into the shared section. handle_ methods now just return the prompt in mx.array form. * Store stop_id_sequences as lists instead of np During testing I found that letting the tokenizer return values as python lists and converting them to mlx arrays was around 20% faster than having the tokenizer convert them to np, and from np to mlx. This allows makes it so numpy no longer needs to be imported. * Update stop_id_sequences docs * Turn if check to non-inclusive Only continue if buffer is smaller * Documentation fix * Cleared method names Instead of handle_stream and generate_competion, we should name it handle_completion. Instead of handle_completions and handle_chat_completions, we should name it handle_text_completions, since both are completions, calling it text completions should make it more descriptive * Make comment clearer * fix format * format
This commit is contained in:
@@ -5,6 +5,7 @@ import glob
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -109,7 +110,7 @@ def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
|
||||
def merge(
|
||||
config: str,
|
||||
mlx_path: str = "mlx_model",
|
||||
upload_repo: str = None,
|
||||
upload_repo: Optional[str] = None,
|
||||
):
|
||||
with open(config, "r") as fid:
|
||||
merge_conf = yaml.safe_load(fid)
|
||||
@@ -117,7 +118,7 @@ def merge(
|
||||
|
||||
model_paths = merge_conf.get("models", [])
|
||||
if len(model_paths) < 2:
|
||||
raise ValueError(f"Expected at least 2 models, got {len(models)}.")
|
||||
raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
|
||||
|
||||
# Load all models
|
||||
base_hf_path = model_paths[0]
|
||||
@@ -125,9 +126,9 @@ def merge(
|
||||
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
|
||||
models = []
|
||||
for mp in model_paths[1:]:
|
||||
model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
|
||||
model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
|
||||
base_type = base_config["model_type"]
|
||||
model_type = config["model_type"]
|
||||
model_type = model_config["model_type"]
|
||||
if base_type != model_type:
|
||||
raise ValueError(
|
||||
f"Can only merge models of the same type,"
|
||||
|
Reference in New Issue
Block a user