Modified mlx_lm/server.py to support prompt caching.
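
This change keeps a per-server KV cache across requests: ModelProvider gains a
cache_history attribute, both the completion and streaming handlers pass it
into generate_step and store each step's output back into it, and
save_cache/load_cache are imported from .utils (this diff does not show them
being called).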

Author: nath1295
Date:   2024-09-27 17:49:51 +01:00
Parent: 7e98499ee3
Commit: 97257511c4

@@ -5,6 +5,7 @@ import json
 import logging
 import time
 import uuid
+import os
 import warnings
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
@@ -12,7 +13,7 @@ from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
 import mlx.core as mx

-from .utils import generate_step, load
+from .utils import generate_step, load, save_cache, load_cache


 class StopCondition(NamedTuple):
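
Note: save_cache and load_cache come from the companion changes to .utils and
are not defined in this diff. As a rough idea of what persisting a KV cache to
disk could look like (purely a sketch — the helper names match the import, but
the signatures, cache layout, and use of safetensors below are assumptions):

    import mlx.core as mx

    def save_cache(cache, file_path):
        # Assumed layout: one (keys, values) pair of mx.arrays per layer.
        arrays = {}
        for i, (keys, values) in enumerate(cache):
            arrays[f"layer_{i}_keys"] = keys
            arrays[f"layer_{i}_values"] = values
        mx.save_safetensors(file_path, arrays)

    def load_cache(file_path):
        # Rebuild the per-layer list from the flat archive.
        arrays = mx.load(file_path)
        num_layers = len(arrays) // 2
        return [
            (arrays[f"layer_{i}_keys"], arrays[f"layer_{i}_values"])
            for i in range(num_layers)
        ]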
@@ -100,6 +101,7 @@ class ModelProvider:
         self.model_key = None
         self.model = None
         self.tokenizer = None
+        self.cache_history = None

         # Preload the default model if it is provided
         if self.cli_args.model is not None:
@@ -121,6 +123,7 @@ class ModelProvider:
             self.model = None
             self.tokenizer = None
             self.model_key = None
+            self.cache_history = None

         # Building tokenizer_config
         tokenizer_config = {
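
Note that cache_history is cleared in the same places the model itself is
cleared: a KV cache is only valid for the model and tokenizer that produced
it, so swapping models in ModelProvider.load must also drop the cached prompt
state.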
@@ -408,7 +411,7 @@ class APIHandler(BaseHTTPRequestHandler):
         logging.debug(f"Starting completion:")
         token_logprobs = []
         top_tokens = []
-        for (token, logprobs), _ in zip(
+        for step_output, _ in zip(
             generate_step(
                 prompt=prompt,
                 model=self.model,
@@ -417,21 +420,24 @@
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
                 logit_bias=self.logit_bias,
+                verbose=True,
+                cache_history=self.model_provider.cache_history
             ),
             range(self.max_tokens),
         ):
-            detokenizer.add_token(token)
+            self.model_provider.cache_history = step_output
+            detokenizer.add_token(step_output.token)
             logging.debug(detokenizer.text)
-            tokens.append(token)
+            tokens.append(step_output.token)

             if self.logprobs > 0:
-                sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
+                sorted_indices = mx.argpartition(-step_output.logprobs, kth=self.logprobs - 1)
                 top_indices = sorted_indices[: self.logprobs]
-                top_logprobs = logprobs[top_indices]
+                top_logprobs = step_output.logprobs[top_indices]
                 top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
                 top_tokens.append(dict(top_token_info))

-            token_logprobs.append(logprobs[token].item())
+            token_logprobs.append(step_output.logprobs[step_output.token].item())

             stop_condition = stopping_criteria(
                 tokens, stop_id_sequences, self.tokenizer.eos_token_id
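
The loop now treats whatever generate_step yields as a single object carrying
the sampled token, the logprobs, and the resumable cache state (the object is
stored wholesale into cache_history). generate_step's new signature lives in
the .utils half of this change; inferred purely from the attribute accesses
above, each yielded value looks roughly like:

    from typing import NamedTuple
    import mlx.core as mx

    class StepOutput(NamedTuple):  # hypothetical name; not shown in this diff
        token: int          # read as step_output.token
        logprobs: mx.array  # read as step_output.logprobs
        # ...plus whatever state generate_step needs to resume from when
        # this object is passed back in as cache_history.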
@@ -495,7 +501,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_sequence_suffix = None
         logging.debug(f"Starting stream:")
-        for (token, _), _ in zip(
+        for step_output, _ in zip(
             generate_step(
                 prompt=prompt,
                 model=self.model,
@@ -503,12 +509,15 @@
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
+                verbose=True,
+                cache_history=self.model_provider.cache_history
             ),
             range(self.max_tokens),
         ):
-            detokenizer.add_token(token)
+            self.model_provider.cache_history = step_output
+            detokenizer.add_token(step_output.token)
             logging.debug(detokenizer.text)
-            tokens.append(token)
+            tokens.append(step_output.token)

             stop_condition = stopping_criteria(
                 tokens,
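
The streaming handler mirrors the completion handler: the same verbose and
cache_history arguments, and the same write-back of each step_output. The
payoff is that a follow-up request sharing a prompt prefix can resume from the
cached state rather than re-processing it, assuming the patched generate_step
matches the incoming prompt against cache_history. A quick way to exercise
that against the server's OpenAI-compatible endpoint (host and port here are
launch-time assumptions):

    import requests

    url = "http://localhost:8080/v1/chat/completions"

    # First request primes the server-side cache_history.
    history = [{"role": "user", "content": "Summarize the project README."}]
    first = requests.post(url, json={"messages": history}).json()

    # Second request extends the same conversation; its prompt shares the
    # first request's prefix, so the cached state can be reused.
    history.append(first["choices"][0]["message"])
    history.append({"role": "user", "content": "Now cut it to one sentence."})
    second = requests.post(url, json={"messages": history}).json()
    print(second["choices"][0]["message"]["content"])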