mirror of https://github.com/ml-explore/mlx-examples.git
Modified mlx_lm/server.py to support prompt caching.
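The patch threads a per-process cache through the server: each call to generate_step now receives the cache_history stored on the ModelProvider, and every step output the generator yields is written back to that slot, so a later request that repeats the same prompt prefix can start from the cached state instead of re-processing it (save_cache and load_cache are imported for persisting that state, though their use is not shown in these hunks). Below is a minimal, dependency-free sketch of that round-trip pattern; ToyStepOutput, toy_generate_step, and Provider are invented stand-ins for illustration, not the real mlx_lm objects the diff modifies.

# Sketch of the cache round-trip this diff wires into the server.
# All names below are invented stand-ins; the real change uses mlx_lm's
# generate_step and the ModelProvider defined in server.py.
from typing import Iterator, List, NamedTuple, Optional


class ToyStepOutput(NamedTuple):
    token: int
    logprobs: List[float]   # stand-in for the per-step logprobs array
    cache: List[int]        # stand-in for the accumulated KV-cache state


def toy_generate_step(
    prompt: List[int], cache_history: Optional[ToyStepOutput] = None
) -> Iterator[ToyStepOutput]:
    # Start from whatever state the previous request left behind, if any.
    cached = list(cache_history.cache) if cache_history else []
    # Keep only the cached prefix that matches this prompt.
    common = 0
    while common < min(len(cached), len(prompt)) and cached[common] == prompt[common]:
        common += 1
    cache = cached[:common]
    # "Process" only the tokens the cache does not already cover.
    for tok in prompt[common:]:
        cache = cache + [tok]
        yield ToyStepOutput(token=tok, logprobs=[0.0], cache=cache)


class Provider:
    """Plays the role of ModelProvider.cache_history in the diff."""

    def __init__(self) -> None:
        self.cache_history: Optional[ToyStepOutput] = None


provider = Provider()
for request in ([1, 2, 3], [1, 2, 3, 4]):  # second request shares the [1, 2, 3] prefix
    processed = 0
    for step_output in toy_generate_step(request, provider.cache_history):
        provider.cache_history = step_output  # same round-trip the diff adds to the server
        processed += 1
    print(f"processed {processed} new token(s); cache now covers "
          f"{len(provider.cache_history.cache)} token(s)")

Note that, as in the diff, there is a single cache slot per provider, so alternating requests with unrelated prompts will simply overwrite each other's cached state.
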
@@ -5,6 +5,7 @@ import json
 import logging
 import time
 import uuid
+import os
 import warnings
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
@@ -12,7 +13,7 @@ from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
 
 import mlx.core as mx
 
-from .utils import generate_step, load
+from .utils import generate_step, load, save_cache, load_cache
 
 
 class StopCondition(NamedTuple):
@@ -100,6 +101,7 @@ class ModelProvider:
         self.model_key = None
         self.model = None
         self.tokenizer = None
+        self.cache_history = None
 
         # Preload the default model if it is provided
         if self.cli_args.model is not None:
@@ -121,6 +123,7 @@ class ModelProvider:
         self.model = None
         self.tokenizer = None
         self.model_key = None
+        self.cache_history = None
 
         # Building tokenizer_config
         tokenizer_config = {
@@ -408,7 +411,7 @@ class APIHandler(BaseHTTPRequestHandler):
         logging.debug(f"Starting completion:")
         token_logprobs = []
         top_tokens = []
-        for (token, logprobs), _ in zip(
+        for step_output, _ in zip(
             generate_step(
                 prompt=prompt,
                 model=self.model,
@@ -417,21 +420,24 @@ class APIHandler(BaseHTTPRequestHandler):
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
                 logit_bias=self.logit_bias,
+                verbose=True,
+                cache_history=self.model_provider.cache_history
             ),
             range(self.max_tokens),
         ):
-            detokenizer.add_token(token)
+            self.model_provider.cache_history = step_output
+            detokenizer.add_token(step_output.token)
             logging.debug(detokenizer.text)
-            tokens.append(token)
+            tokens.append(step_output.token)
 
             if self.logprobs > 0:
-                sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
+                sorted_indices = mx.argpartition(-step_output.logprobs, kth=self.logprobs - 1)
                 top_indices = sorted_indices[: self.logprobs]
-                top_logprobs = logprobs[top_indices]
+                top_logprobs = step_output.logprobs[top_indices]
                 top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
                 top_tokens.append(dict(top_token_info))
 
-            token_logprobs.append(logprobs[token].item())
+            token_logprobs.append(step_output.logprobs[step_output.token].item())
 
             stop_condition = stopping_criteria(
                 tokens, stop_id_sequences, self.tokenizer.eos_token_id
@@ -495,7 +501,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_sequence_suffix = None
         logging.debug(f"Starting stream:")
 
-        for (token, _), _ in zip(
+        for step_output, _ in zip(
             generate_step(
                 prompt=prompt,
                 model=self.model,
@@ -503,12 +509,15 @@ class APIHandler(BaseHTTPRequestHandler):
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
+                verbose=True,
+                cache_history=self.model_provider.cache_history
             ),
             range(self.max_tokens),
         ):
-            detokenizer.add_token(token)
+            self.model_provider.cache_history = step_output
+            detokenizer.add_token(step_output.token)
             logging.debug(detokenizer.text)
-            tokens.append(token)
+            tokens.append(step_output.token)
 
             stop_condition = stopping_criteria(
                 tokens,
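From the client's side nothing changes; the cache only pays off when consecutive requests repeat a long prompt prefix, as a growing chat history does. The following is a hypothetical client-side check of that behaviour, assuming the patched server was started with something like python -m mlx_lm.server --model <path-or-hf-repo>, is listening on its default 127.0.0.1:8080, and exposes the usual OpenAI-compatible /v1/chat/completions route; field names and response shape may differ slightly by version, and the timings are only indicative.

# Hypothetical client check: two chat requests where the second repeats the
# first's prompt prefix, so the server-side cache_history can be reused.
import json
import time
import urllib.request

URL = "http://127.0.0.1:8080/v1/chat/completions"  # assumed default address


def ask(messages):
    # Assumes an OpenAI-style request/response shape served by mlx_lm.server.
    body = json.dumps({"messages": messages, "max_tokens": 64}).encode()
    req = urllib.request.Request(
        URL, data=body, headers={"Content-Type": "application/json"}
    )
    start = time.time()
    with urllib.request.urlopen(req) as resp:
        reply = json.loads(resp.read())
    print(f"round trip: {time.time() - start:.2f}s")
    return reply["choices"][0]["message"]["content"]


history = [{"role": "user", "content": "Summarize the plot of Hamlet."}]
answer = ask(history)  # first request populates the server's cache_history
history += [
    {"role": "assistant", "content": answer},
    {"role": "user", "content": "Now in one sentence."},
]
ask(history)  # shares the earlier prefix, so prompt processing should be cheaper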