diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 139e0b18..f32d62dd 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -263,13 +263,6 @@ class KVCache(_BaseCache):
         n = min(self.offset, n)
         self.offset -= n
         return n
-    def trim_from_behind(self, n):
-        old_size = self.keys.shape[2]
-        self.keys = self.keys[..., -n:, :]
-        self.values = self.values[..., -n:, :]
-        new_size = self.keys.shape[2]
-        trimmed = old_size - new_size
-        self.offset -= trimmed
 
     def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
         quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index d56855ca..4d69115e 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -404,7 +404,6 @@ def generate(
     prompt: str,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
-    stop_strings: Optional[List[str]] = None,
     **kwargs,
 ) -> str:
     """
@@ -433,8 +432,6 @@
         if verbose:
             print(response.text, end="", flush=True)
         text += response.text
-        if stop_strings is not None and any(s in text for s in stop_strings):
-            break
 
     if verbose:
         print()
@@ -869,226 +866,3 @@ def convert(
 
     if upload_repo is not None:
         upload_to_hub(mlx_path, upload_repo, hf_path)
-from tqdm import tqdm
-
-def generate_batched_response(
-    model: nn.Module,
-    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: Union[str, mx.array, List[int]],
-    batch_size: int,
-    max_tokens: int = 256,
-    sampler: Optional[Callable[[mx.array], mx.array]] = None,
-    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
-    max_kv_size: Optional[int] = None,
-    prompt_cache: Optional[List[Any]] = None,
-    prefill_step_size: int = 512,
-    kv_bits: Optional[int] = None,
-    kv_group_size: int = 64,
-    quantized_kv_start: int = 0,
-    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
-    temp: Optional[float] = None,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = None,
-    top_p: Optional[float] = None,
-    min_p: Optional[float] = None,
-    min_tokens_to_keep: Optional[int] = None,
-    verbose: bool = False,
-) -> List[str]:
-    """
-    Generate multiple responses to the same prompt in parallel and return only the generated
-    sequences (excluding the prompt), stopping at the first EOS token.
-
-    Args:
-        model (nn.Module): The language model.
-        tokenizer (PreTrainedTokenizer or TokenizerWrapper): The tokenizer.
-        prompt (Union[str, mx.array, List[int]]): The input prompt.
-        batch_size (int): Number of responses to generate in parallel.
-        max_tokens (int): Maximum number of generated tokens per sequence.
-        sampler (Callable): Sampler function.
-        logits_processors (List[Callable]): List of logits processors.
-        max_kv_size (int): Maximum KV cache size.
-        prompt_cache (List[Any]): Precomputed prompt cache.
-        prefill_step_size (int): Step size for prompt processing.
-        kv_bits (int): Bits for KV cache quantization.
-        kv_group_size (int): Group size for KV quantization.
-        quantized_kv_start (int): Step to begin quantizing KV.
-        prompt_progress_callback (Callable): Callback for prompt progress.
-        temp (float): Temperature for sampling (deprecated, pass to sampler).
-        repetition_penalty (float): Repetition penalty (deprecated, use logits_processors).
-        repetition_context_size (int): Context size for repetition.
-        top_p (float): Top-p sampling (deprecated, pass to sampler).
-        min_p (float): Minimum p sampling (deprecated, pass to sampler).
-        min_tokens_to_keep (int): Minimum number of tokens to keep.
-        verbose (bool): If True, show a progress bar for token generation.
-
-    Returns:
-        List[str]: A list of decoded response strings for each batch element, excluding the prompt
-        and stopping at the first EOS token.
-    """
-    if not isinstance(tokenizer, TokenizerWrapper):
-        tokenizer = TokenizerWrapper(tokenizer)
-
-    # Convert prompt to tokens if necessary
-    if not isinstance(prompt, mx.array):
-        prompt = mx.array(
-            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
-        )
-
-    # Expand prompt to batch
-    prompt_length = prompt.size
-    prompt = mx.expand_dims(prompt, 0)  # (1, prompt_length)
-    prompt = mx.repeat(prompt, batch_size, axis=0)  # (B, prompt_length)
-    B = batch_size
-
-    if prompt_progress_callback is None:
-        prompt_progress_callback = lambda *_: None
-
-    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
-        print(
-            "[Warning] Specifying sampling arguments directly is deprecated. "
-            "Pass in a `sampler` if needed."
-        )
-    if repetition_penalty is not None:
-        print(
-            "[Warning] Specifying `repetition_penalty` is deprecated. "
-            "Use `logits_processors` instead."
-        )
-
-    sampler = sampler or make_sampler(
-        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
-    )
-    logits_processors = logits_processors or make_logits_processors(
-        None, repetition_penalty, repetition_context_size or 20
-    )
-
-    # Create or verify prompt cache
-    if prompt_cache is None:
-        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
-    elif len(prompt_cache) != len(model.layers):
-        raise ValueError("Wrong number of layers in the prompt cache.")
-
-    # Process the prompt to fill the cache in increments
-    total_prompt_tokens = prompt_length
-    prompt_processed_tokens = 0
-    remaining_prompt = prompt
-    tic = time.perf_counter()
-    with mx.stream(generation_stream):
-        while remaining_prompt.shape[1] > prefill_step_size:
-            model(remaining_prompt[:, :prefill_step_size], cache=prompt_cache)
-            mx.eval([c.state for c in prompt_cache])
-            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
-            prompt_processed_tokens += prefill_step_size
-            remaining_prompt = remaining_prompt[:, prefill_step_size:]
-            mx.metal.clear_cache()
-
-        # Process any remaining prompt tokens
-        if remaining_prompt.shape[1] > 0:
-            model(remaining_prompt, cache=prompt_cache)
-            mx.eval([c.state for c in prompt_cache])
-            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
-
-    prompt_time = time.perf_counter() - tic
-    prompt_tps = (total_prompt_tokens * B) / prompt_time
-
-    # Initialization for generation
-    tokens = prompt
-    finished = mx.zeros((B,), dtype=tokens.dtype)
-    generation_count = 0
-    eos_ids = tokenizer.eos_token_ids
-
-    # Setup progress bar if verbose
-    pbar = None
-    if verbose:
-        if max_tokens >= 0:
-            pbar = tqdm(total=max_tokens, desc="Generating tokens", ncols=80)
-        else:
-            # If we don't have a max_tokens limit, no total is known.
-            # We'll just display a progress bar that counts up.
-            pbar = tqdm(desc="Generating tokens", ncols=80)
-
-    tic = time.perf_counter()
-
-    while True:
-        if (max_tokens >= 0) and (generation_count >= max_tokens):
-            break
-
-        # If all sequences finished, break
-        sum_finished = mx.sum(finished)
-        mx.eval(sum_finished)
-        if sum_finished.item() == B:
-            break
-
-        # Prepare last token
-        next_input = tokens[:, -1:]  # (B,1)
-        with mx.stream(generation_stream):
-            logits = model(next_input, cache=prompt_cache)
-            # logits: (B, 1, vocab)
-            logits = logits[:, -1, :]  # (B, vocab)
-
-            # Apply logits processors
-            if logits_processors:
-                for processor in logits_processors:
-                    logits = processor(tokens, logits)
-
-            maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits)
-
-            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)  # (B,vocab)
-            sampled_tokens = sampler(logprobs)  # (B,)
-
-            mx.async_eval(sampled_tokens, logprobs)
-
-        # Check EOS
-        is_eos = mx.zeros_like(sampled_tokens).astype(tokens.dtype)
-        for eid in eos_ids:
-            diff = sampled_tokens - eid
-            sq = diff * diff
-            val = 1.0 / (sq + 1.0)
-            mask = val.astype(tokens.dtype)
-            is_eos = is_eos + mask
-
-        ones = mx.ones_like(is_eos)
-        is_eos = mx.minimum(is_eos, ones)
-        finished = mx.maximum(finished, is_eos)
-
-        sampled_tokens = sampled_tokens[:, None]  # (B,1)
-        tokens = mx.concatenate([tokens, sampled_tokens], axis=1)
-
-        generation_count += 1
-        if pbar is not None:
-            pbar.update(1)
-
-        if (generation_count % 256) == 0:
-            mx.metal.clear_cache()
-
-    if pbar is not None:
-        pbar.close()
-
-    generation_time = time.perf_counter() - tic
-    generation_tps = (generation_count * B) / generation_time if generation_count > 0 else 0.0
-    peak_memory = mx.metal.get_peak_memory() / 1e9
-
-    results = []
-    for i in range(B):
-        seq = tokens[i][prompt_length:].tolist()  # Exclude the prompt
-        # Find the first EOS token
-        eos_pos = None
-        for idx, t in enumerate(seq):
-            if t in eos_ids:
-                eos_pos = idx
-                break
-        # Slice up to EOS if found
-        if eos_pos is not None:
-            seq = seq[:eos_pos]
-        text = tokenizer.decode(seq)
-        results.append(text)
-
-    if verbose:
-        print("=" * 10)
-        print(f"Prompt: {total_prompt_tokens} tokens * {B} sequences, {prompt_tps:.3f} tps")
-        print(
-            f"Generation: {generation_count} tokens * {B} sequences, "
-            f"{generation_tps:.3f} tps"
-        )
-        print(f"Peak memory: {peak_memory:.3f} GB")
-
-    return results