diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index de02704d..1a39ab8b 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -452,17 +452,24 @@ class APIHandler(BaseHTTPRequestHandler):
 
     def get_prompt_cache(self, prompt):
         cache_len = len(self.prompt_cache.tokens)
+        # Check if the cache is valid for the current prompt
         if (
             self.prompt_cache.model_key != self.model_provider.model_key
             or cache_len >= len(prompt)
             or self.prompt_cache.tokens != prompt[:cache_len]
         ):
+            # Reinitialize the cache entirely
            self.prompt_cache.model_key = self.model_provider.model_key
            self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
+            # Reset the cache tokens to be empty because the cache was re-created
+            self.prompt_cache.tokens = []
+            new_prompt = prompt
         else:
-            prompt = prompt[cache_len:]
-        self.prompt_cache.tokens.extend(prompt)
-        return prompt
+            # Use the already cached tokens; only process the tail of the prompt
+            new_prompt = prompt[cache_len:]
+        # Update the cache tokens with the new tokens being processed
+        self.prompt_cache.tokens.extend(new_prompt)
+        return new_prompt
 
     def handle_completion(
         self,
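For reviewers, here is a minimal standalone sketch of the patched control flow. The handler state and the real make_prompt_cache factory are replaced by stand-ins: the PromptCache dataclass, the stub factory, and the free-function signature below are illustrative assumptions, not the server's actual API. It exercises the three cases the patch distinguishes: a cold cache, a prompt that extends the cached tokens, and a prompt that diverges from them.

```python
# Illustrative sketch (not part of the patch): a standalone reproduction of the
# patched get_prompt_cache logic, with the model/KV-cache details stubbed out.
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class PromptCache:
    # Mirrors the fields the handler uses in server.py
    cache: List[Any] = field(default_factory=list)
    model_key: Any = None
    tokens: List[int] = field(default_factory=list)


def make_prompt_cache(model):
    # Stub standing in for the real prompt-cache factory
    return []


def get_prompt_cache(pc: PromptCache, model_key, model, prompt: List[int]) -> List[int]:
    cache_len = len(pc.tokens)
    # Reuse the cache only if it belongs to the same model and its tokens
    # form a strict prefix of the incoming prompt
    if (
        pc.model_key != model_key
        or cache_len >= len(prompt)
        or pc.tokens != prompt[:cache_len]
    ):
        pc.model_key = model_key
        pc.cache = make_prompt_cache(model)
        pc.tokens = []  # the old tokens no longer describe the rebuilt cache
        new_prompt = prompt
    else:
        new_prompt = prompt[cache_len:]  # only the uncached tail needs processing
    pc.tokens.extend(new_prompt)
    return new_prompt


if __name__ == "__main__":
    pc = PromptCache()
    # Cold cache: the whole prompt is returned and recorded
    assert get_prompt_cache(pc, "model-A", None, [1, 2, 3]) == [1, 2, 3]
    # Cache hit: only the new tail [4, 5] needs to be processed
    assert get_prompt_cache(pc, "model-A", None, [1, 2, 3, 4, 5]) == [4, 5]
    # Divergent prompt: the cache is rebuilt and the full prompt returned
    assert get_prompt_cache(pc, "model-A", None, [9, 9, 9]) == [9, 9, 9]
```

The behavioral change mirrored here is that the token list is cleared whenever the KV cache is rebuilt, and the sliced prompt is kept in a separate new_prompt variable instead of shadowing prompt, so prompt_cache.tokens always matches what the cache has actually seen.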