diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index f02f49b1..fb600fcd 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -34,18 +34,22 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
     return mask * -1e9
 
 
-def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_cache_idx: Optional[int] = None) -> mx.array:
     T = h.shape[1]
     if T > 1:
         window_size = None
         offset = 0
         if cache is not None and cache[0] is not None:
-            c = cache[0]
+            if reference_cache_idx is not None:
+                c = cache[reference_cache_idx]
+            else:
+                c = cache[0]
             if hasattr(c, "max_size"):
                 offset = min(c.max_size, c.offset)
                 window_size = c.max_size
             else:
                 offset = c.offset
+
         mask = create_causal_mask(T, offset, window_size=window_size)
         mask = mask.astype(h.dtype)
     else:
diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 14026f0c..139e0b18 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -6,7 +6,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
-
 def make_prompt_cache(
     model: nn.Module,
     max_kv_size: Optional[int] = None,
@@ -33,7 +32,7 @@ def make_prompt_cache(
         ]
     else:
        return [KVCache() for _ in range(num_layers)]
-
+
 
 def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
     """
@@ -264,6 +263,13 @@ class KVCache(_BaseCache):
         n = min(self.offset, n)
         self.offset -= n
         return n
+    def trim_from_behind(self, n):
+        old_size = self.keys.shape[2]
+        self.keys = self.keys[..., -n:, :]
+        self.values = self.values[..., -n:, :]
+        new_size = self.keys.shape[2]
+        trimmed = old_size - new_size
+        self.offset -= trimmed
 
     def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
         quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
@@ -416,7 +422,8 @@ class RotatingKVCache(_BaseCache):
         return n
 
     def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
-        raise NotImplementedError("RotatingKVCache Quantization NYI")
+        return self
+        # raise NotImplementedError("RotatingKVCache Quantization NYI")
 
 
 class MambaCache(_BaseCache):
diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py
new file mode 100644
index 00000000..5796e49c
--- /dev/null
+++ b/llms/mlx_lm/models/cohere2.py
@@ -0,0 +1,165 @@
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
+from .cache import KVCache, RotatingKVCache
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    num_key_value_heads: int
+    rope_theta: float
+    vocab_size: int
+    layer_norm_eps: float
+    logit_scale: float
+    attention_bias: bool
+    # Additional Cohere2-specific arguments:
+    # rope_type and max_position_embeddings might influence the rope setup
+    rope_type: str = "default"
+    max_position_embeddings: int = 2048
+    sliding_window: Optional[int] = None
+    sliding_window_pattern: Optional[int] = None
+    order_of_interleaved_layers: Optional[int] = None
+    use_cache: bool = True
+
+
+class Cohere2Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        dim = args.hidden_size
+        self.n_heads = args.num_attention_heads
+        self.n_kv_heads = args.num_key_value_heads
+        head_dim = dim // self.n_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, self.n_heads * head_dim, bias=args.attention_bias)
+        self.k_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=args.attention_bias)
+        self.v_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=args.attention_bias)
+        self.o_proj = nn.Linear(self.n_heads * head_dim, dim, bias=args.attention_bias)
+
+        self.sliding_window = args.sliding_window  # Not yet implemented :(
+        self.use_qk_norm = False  # Assuming QK norm not used by Cohere2 (adjust if needed)
+
+        # Initialize RoPE for Cohere2
+        self.rope = initialize_rope(
+            dims=head_dim,
+            base=args.rope_theta,
+            traditional=True,
+            max_position_embeddings=args.max_position_embeddings,
+        )
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, rope=True) -> mx.array:
+        B, L, D = x.shape
+        q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        # Apply RoPE
+        # In Cohere2, the original code applies RoPE before caching updates. We replicate that:
+        if cache is not None:
+            if rope:
+                q = self.rope(q, offset=cache.offset)
+                k = self.rope(k, offset=cache.offset)
+            k, v = cache.update_and_fetch(k, v)
+            if rope:
+                k = k[:, :, -self.sliding_window:, :]
+                v = v[:, :, -self.sliding_window:, :]
+        elif rope:
+            q = self.rope(q)
+            k = self.rope(k)
+
+        # Compute attention
+        out = scaled_dot_product_attention(
+            q, k, v, cache=cache, scale=self.scale, mask=mask
+        )
+
+        out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
+        return self.o_proj(out)
+
+
+class Cohere2MLP(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        dim = args.hidden_size
+        hdim = args.intermediate_size
+        self.gate_proj = nn.Linear(dim, hdim, bias=False)
+        self.up_proj = nn.Linear(dim, hdim, bias=False)
+        self.down_proj = nn.Linear(hdim, dim, bias=False)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Cohere2TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.self_attn = Cohere2Attention(args)
+        self.mlp = Cohere2MLP(args)
+        self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.layer_norm_eps, affine=True, bias=False)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, rope=True) -> mx.array:
+        h = self.input_layernorm(x)
+        attn_h = self.self_attn(h, mask, cache, rope=rope)
+        ff_h = self.mlp(h)
+        return x + attn_h + ff_h
+
+
+class Cohere2Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [Cohere2TransformerBlock(args) for _ in range(args.num_hidden_layers)]
+        self.norm = nn.LayerNorm(args.hidden_size, eps=args.layer_norm_eps, affine=True, bias=False)
+        self.sliding_window = args.sliding_window
+        self.sliding_window_pattern = args.sliding_window_pattern
+
+    def __call__(self, inputs: mx.array, cache: Optional[Any] = None) -> mx.array:
+        h = self.embed_tokens(inputs)
+        mask = create_attention_mask(h, cache, reference_cache_idx=self.sliding_window_pattern - 1)
+        sliding_window_mask = mask[:, -self.sliding_window:] if mask is not None else None
+        if cache is None:
+            cache = [None] * len(self.layers)
+        for i, (layer, c) in enumerate(zip(self.layers, cache)):
+            if self.sliding_window is not None:
+                index = i % self.sliding_window_pattern
+                if index < self.sliding_window_pattern - 1:
+                    h = layer(h, mask=sliding_window_mask, cache=c)
+                else:
+                    h = layer(h, mask=mask, cache=c, rope=False)
+
+        return self.norm(h)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.model_type = args.model_type
+
+        self.model = Cohere2Model(args)
+        self.args = args
+
+    def __call__(self, inputs: mx.array, cache=None):
+        out = self.model(inputs, cache)
+        out = self.model.embed_tokens.as_linear(out) * self.args.logit_scale
+        return out
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    def make_cache(self):
+        caches = []
+        for i in range(self.args.num_hidden_layers):
+            if i % self.args.sliding_window_pattern == self.args.sliding_window_pattern - 1:
+                caches.append(KVCache())
+            else:
+                caches.append(RotatingKVCache(max_size=self.args.sliding_window, keep=0))
+        return caches
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b87f5a24..d56855ca 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -187,9 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
         and prompt_cache[0].offset > quantized_kv_start
     ):
         for i in range(len(prompt_cache)):
-            prompt_cache[i] = prompt_cache[i].to_quantized(
-                group_size=kv_group_size, bits=kv_bits
-            )
+            if isinstance(prompt_cache[i], cache.KVCache):
+                prompt_cache[i] = prompt_cache[i].to_quantized(
+                    group_size=kv_group_size, bits=kv_bits
+                )
 
 
 def generate_step(
@@ -403,6 +404,7 @@ def generate(
     prompt: str,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
+    stop_strings: Optional[List[str]] = None,
     **kwargs,
 ) -> str:
     """
@@ -431,6 +433,8 @@ def generate(
         if verbose:
             print(response.text, end="", flush=True)
         text += response.text
+        if stop_strings is not None and any(s in text for s in stop_strings):
+            break
 
     if verbose:
         print()
@@ -865,3 +869,226 @@ def convert(
 
     if upload_repo is not None:
         upload_to_hub(mlx_path, upload_repo, hf_path)
+
+from tqdm import tqdm
+
+
+def generate_batched_response(
+    model: nn.Module,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
+    prompt: Union[str, mx.array, List[int]],
+    batch_size: int,
+    max_tokens: int = 256,
+    sampler: Optional[Callable[[mx.array], mx.array]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    max_kv_size: Optional[int] = None,
+    prompt_cache: Optional[List[Any]] = None,
+    prefill_step_size: int = 512,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
+    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
+    temp: Optional[float] = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = None,
+    top_p: Optional[float] = None,
+    min_p: Optional[float] = None,
+    min_tokens_to_keep: Optional[int] = None,
+    verbose: bool = False,
+) -> List[str]:
+    """
+    Generate multiple responses to the same prompt in parallel and return only the
+    generated sequences (excluding the prompt), stopping at the first EOS token.
+
+    Args:
+        model (nn.Module): The language model.
+        tokenizer (PreTrainedTokenizer or TokenizerWrapper): The tokenizer.
+        prompt (Union[str, mx.array, List[int]]): The input prompt.
+        batch_size (int): Number of responses to generate in parallel.
+        max_tokens (int): Maximum number of generated tokens per sequence.
+        sampler (Callable): Sampler function.
+        logits_processors (List[Callable]): List of logits processors.
+        max_kv_size (int): Maximum KV cache size.
+        prompt_cache (List[Any]): Precomputed prompt cache.
+        prefill_step_size (int): Step size for prompt processing.
+        kv_bits (int): Bits for KV cache quantization.
+        kv_group_size (int): Group size for KV quantization.
+        quantized_kv_start (int): Step to begin quantizing KV.
+        prompt_progress_callback (Callable): Callback for prompt progress.
+        temp (float): Temperature for sampling (deprecated, pass to sampler).
+        repetition_penalty (float): Repetition penalty (deprecated, use logits_processors).
+        repetition_context_size (int): Context size for repetition.
+        top_p (float): Top-p sampling (deprecated, pass to sampler).
+        min_p (float): Minimum p sampling (deprecated, pass to sampler).
+        min_tokens_to_keep (int): Minimum number of tokens to keep.
+        verbose (bool): If True, show a progress bar for token generation.
+
+    Returns:
+        List[str]: A list of decoded response strings for each batch element,
+        excluding the prompt and stopping at the first EOS token.
+    """
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
+
+    # Convert prompt to tokens if necessary
+    if not isinstance(prompt, mx.array):
+        prompt = mx.array(
+            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
+        )
+
+    # Expand prompt to batch
+    prompt_length = prompt.size
+    prompt = mx.expand_dims(prompt, 0)  # (1, prompt_length)
+    prompt = mx.repeat(prompt, batch_size, axis=0)  # (B, prompt_length)
+    B = batch_size
+
+    if prompt_progress_callback is None:
+        prompt_progress_callback = lambda *_: None
+
+    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
+        print(
+            "[Warning] Specifying sampling arguments directly is deprecated. "
+            "Pass in a `sampler` if needed."
+        )
+    if repetition_penalty is not None:
+        print(
+            "[Warning] Specifying `repetition_penalty` is deprecated. "
+            "Use `logits_processors` instead."
+        )
+
+    sampler = sampler or make_sampler(
+        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
+    )
+    logits_processors = logits_processors or make_logits_processors(
+        None, repetition_penalty, repetition_context_size or 20
+    )
+
+    # Create or verify prompt cache
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
+    elif len(prompt_cache) != len(model.layers):
+        raise ValueError("Wrong number of layers in the prompt cache.")
+
+    # Process the prompt to fill the cache in increments
+    total_prompt_tokens = prompt_length
+    prompt_processed_tokens = 0
+    remaining_prompt = prompt
+    tic = time.perf_counter()
+    with mx.stream(generation_stream):
+        while remaining_prompt.shape[1] > prefill_step_size:
+            model(remaining_prompt[:, :prefill_step_size], cache=prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
+            prompt_processed_tokens += prefill_step_size
+            remaining_prompt = remaining_prompt[:, prefill_step_size:]
+            mx.metal.clear_cache()
+
+        # Process any remaining prompt tokens
+        if remaining_prompt.shape[1] > 0:
+            model(remaining_prompt, cache=prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
+
+    prompt_time = time.perf_counter() - tic
+    prompt_tps = (total_prompt_tokens * B) / prompt_time
+
+    # Initialization for generation
+    tokens = prompt
+    finished = mx.zeros((B,), dtype=tokens.dtype)
+    generation_count = 0
+    eos_ids = tokenizer.eos_token_ids
+
+    # Setup progress bar if verbose
+    pbar = None
+    if verbose:
+        if max_tokens >= 0:
+            pbar = tqdm(total=max_tokens, desc="Generating tokens", ncols=80)
+        else:
+            # If we don't have a max_tokens limit, no total is known.
+            # We'll just display a progress bar that counts up.
+            pbar = tqdm(desc="Generating tokens", ncols=80)
+
+    tic = time.perf_counter()
+
+    while True:
+        if (max_tokens >= 0) and (generation_count >= max_tokens):
+            break
+
+        # If all sequences finished, break
+        sum_finished = mx.sum(finished)
+        mx.eval(sum_finished)
+        if sum_finished.item() == B:
+            break
+
+        # Prepare last token
+        next_input = tokens[:, -1:]  # (B, 1)
+        with mx.stream(generation_stream):
+            logits = model(next_input, cache=prompt_cache)
+            # logits: (B, 1, vocab)
+            logits = logits[:, -1, :]  # (B, vocab)
+
+            # Apply logits processors
+            if logits_processors:
+                for processor in logits_processors:
+                    logits = processor(tokens, logits)
+
+            maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits)
+
+            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)  # (B, vocab)
+            sampled_tokens = sampler(logprobs)  # (B,)
+
+            mx.async_eval(sampled_tokens, logprobs)
+
+        # Check EOS
+        is_eos = mx.zeros_like(sampled_tokens).astype(tokens.dtype)
+        for eid in eos_ids:
+            diff = sampled_tokens - eid
+            sq = diff * diff
+            val = 1.0 / (sq + 1.0)
+            mask = val.astype(tokens.dtype)
+            is_eos = is_eos + mask
+
+        ones = mx.ones_like(is_eos)
+        is_eos = mx.minimum(is_eos, ones)
+        finished = mx.maximum(finished, is_eos)
+
+        sampled_tokens = sampled_tokens[:, None]  # (B, 1)
+        tokens = mx.concatenate([tokens, sampled_tokens], axis=1)
+
+        generation_count += 1
+        if pbar is not None:
+            pbar.update(1)
+
+        if (generation_count % 256) == 0:
+            mx.metal.clear_cache()
+
+    if pbar is not None:
+        pbar.close()
+
+    generation_time = time.perf_counter() - tic
+    generation_tps = (generation_count * B) / generation_time if generation_count > 0 else 0.0
+    peak_memory = mx.metal.get_peak_memory() / 1e9
+
+    results = []
+    for i in range(B):
+        seq = tokens[i][prompt_length:].tolist()  # Exclude the prompt
+        # Find the first EOS token
+        eos_pos = None
+        for idx, t in enumerate(seq):
+            if t in eos_ids:
+                eos_pos = idx
+                break
+        # Slice up to EOS if found
+        if eos_pos is not None:
+            seq = seq[:eos_pos]
+        text = tokenizer.decode(seq)
+        results.append(text)
+
+    if verbose:
+        print("=" * 10)
+        print(f"Prompt: {total_prompt_tokens} tokens * {B} sequences, {prompt_tps:.3f} tps")
+        print(
+            f"Generation: {generation_count} tokens * {B} sequences, "
+            f"{generation_tps:.3f} tps"
+        )
+        print(f"Peak memory: {peak_memory:.3f} GB")
+
+    return results
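
Usage sketch (not part of the diff): a minimal example of how the new `stop_strings` argument to `generate` and the new `generate_batched_response` helper might be called once this change lands. The checkpoint name and prompt are illustrative placeholders; `make_sampler` is the existing helper in `mlx_lm.sample_utils`, used here so the batched samples are not all identical under the default greedy sampler.

    from mlx_lm import load, generate
    from mlx_lm.sample_utils import make_sampler
    from mlx_lm.utils import generate_batched_response

    # Hypothetical converted checkpoint; substitute any local or hub model path.
    model, tokenizer = load("mlx-community/c4ai-command-r7b-12-2024-4bit")

    # Single response that stops early once any stop string appears in the output.
    text = generate(
        model,
        tokenizer,
        prompt="Write a haiku about the sea.",
        max_tokens=128,
        stop_strings=["\n\n"],
    )

    # Four samples of the same prompt generated in parallel; a nonzero temperature
    # keeps the batch elements from being identical.
    responses = generate_batched_response(
        model,
        tokenizer,
        prompt="Write a haiku about the sea.",
        batch_size=4,
        max_tokens=128,
        sampler=make_sampler(temp=0.7),
        verbose=True,
    )
    for r in responses:
        print(r)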