mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:17:07 +08:00
Add Cohere2
This commit is contained in:
parent
9f2ea5892e
commit
75fbb7ed34
@ -34,18 +34,22 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
|
|||||||
return mask * -1e9
|
return mask * -1e9
|
||||||
|
|
||||||
|
|
||||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_cache_idx: Optional[int] = None) -> mx.array:
|
||||||
T = h.shape[1]
|
T = h.shape[1]
|
||||||
if T > 1:
|
if T > 1:
|
||||||
window_size = None
|
window_size = None
|
||||||
offset = 0
|
offset = 0
|
||||||
if cache is not None and cache[0] is not None:
|
if cache is not None and cache[0] is not None:
|
||||||
c = cache[0]
|
if reference_cache_idx is not None:
|
||||||
|
c = cache[reference_cache_idx]
|
||||||
|
else:
|
||||||
|
c = cache[0]
|
||||||
if hasattr(c, "max_size"):
|
if hasattr(c, "max_size"):
|
||||||
offset = min(c.max_size, c.offset)
|
offset = min(c.max_size, c.offset)
|
||||||
window_size = c.max_size
|
window_size = c.max_size
|
||||||
else:
|
else:
|
||||||
offset = c.offset
|
offset = c.offset
|
||||||
|
|
||||||
mask = create_causal_mask(T, offset, window_size=window_size)
|
mask = create_causal_mask(T, offset, window_size=window_size)
|
||||||
mask = mask.astype(h.dtype)
|
mask = mask.astype(h.dtype)
|
||||||
else:
|
else:
|
||||||
|
@ -6,7 +6,6 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||||
|
|
||||||
|
|
||||||
def make_prompt_cache(
|
def make_prompt_cache(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
max_kv_size: Optional[int] = None,
|
max_kv_size: Optional[int] = None,
|
||||||
@ -33,7 +32,7 @@ def make_prompt_cache(
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
return [KVCache() for _ in range(num_layers)]
|
return [KVCache() for _ in range(num_layers)]
|
||||||
|
|
||||||
|
|
||||||
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
|
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
|
||||||
"""
|
"""
|
||||||
@ -264,6 +263,13 @@ class KVCache(_BaseCache):
|
|||||||
n = min(self.offset, n)
|
n = min(self.offset, n)
|
||||||
self.offset -= n
|
self.offset -= n
|
||||||
return n
|
return n
|
||||||
|
def trim_from_behind(self, n):
    """Keep only the last ``n`` cached positions and drop the older ones.

    ``self.keys`` and ``self.values`` are sliced along axis 2 and
    ``self.offset`` is lowered by the number of positions removed.

    Args:
        n: Number of trailing positions to keep. ``n == 0`` leaves the
            cache untouched, because ``-0:`` selects the whole axis.

    Returns:
        The number of positions removed, mirroring ``trim`` which also
        reports how much it trimmed. 0 when no keys are stored.
    """
    # No arrays stored yet: nothing to slice, and ``None.shape`` would raise.
    if self.keys is None:
        return 0

    # NOTE(review): the trimmed count is measured on the stored arrays. This
    # assumes axis 2 of ``keys`` holds exactly ``offset`` valid positions --
    # confirm against how the cache allocates its buffers.
    old_size = self.keys.shape[2]
    self.keys = self.keys[..., -n:, :]
    self.values = self.values[..., -n:, :]
    trimmed = old_size - self.keys.shape[2]
    self.offset -= trimmed
    return trimmed
|
||||||
|
|
||||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||||
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
||||||
@ -416,7 +422,8 @@ class RotatingKVCache(_BaseCache):
|
|||||||
return n
|
return n
|
||||||
|
|
||||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> "RotatingKVCache":
    """Return this cache unchanged; no quantized form is produced here.

    ``group_size`` and ``bits`` are accepted so the signature matches
    ``KVCache.to_quantized``, but they are ignored.

    Returns:
        ``self``, so callers that replace a cache with the result of
        ``to_quantized`` keep the same rotating cache.
    """
    return self
|
||||||
|
|
||||||
|
|
||||||
class MambaCache(_BaseCache):
|
class MambaCache(_BaseCache):
|
||||||
|
165
llms/mlx_lm/models/cohere2.py
Normal file
165
llms/mlx_lm/models/cohere2.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
from .rope_utils import initialize_rope
|
||||||
|
from .cache import KVCache, RotatingKVCache
|
||||||
|
|
||||||
|
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration values consumed by the Cohere2 model classes."""

    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    num_key_value_heads: int
    rope_theta: float
    vocab_size: int
    layer_norm_eps: float
    # Multiplier applied to the output logits.
    logit_scale: float
    # Whether the q/k/v/o attention projections carry a bias term.
    attention_bias: bool
    # Additional Cohere2-specific arguments.
    # NOTE(review): `rope_type` is not read by the model code in this file.
    rope_type: str = "default"
    max_position_embeddings: int = 2048
    # Window length: used as the rotating cache size and to truncate
    # keys/values and the mask on sliding-window layers.
    # The defaults below must be plain `None`: a trailing comma after the
    # value would make the default the tuple `(None,)`, which is never
    # `is None` and breaks every arithmetic use of the field.
    sliding_window: Optional[int] = None
    # Layers whose index satisfies `i % pattern == pattern - 1` get a plain
    # KVCache and are called without RoPE; the others use the sliding window.
    sliding_window_pattern: Optional[int] = None
    # NOTE(review): not read by the model code in this file.
    order_of_interleaved_layers: Optional[int] = None
    # NOTE(review): not read by the model code in this file.
    use_cache: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Cohere2Attention(nn.Module):
    """Attention layer with optional RoPE and sliding-window truncation.

    The ``rope`` flag of ``__call__`` selects between the two layer flavours
    the model uses: with ``rope=True`` positions are rotary-encoded and, when
    a cache is present, keys/values are cut to the last ``sliding_window``
    positions; with ``rope=False`` neither happens.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        head_dim = dim // self.n_heads
        self.scale = head_dim**-0.5

        bias = args.attention_bias
        self.q_proj = nn.Linear(dim, self.n_heads * head_dim, bias=bias)
        self.k_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=bias)
        self.v_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=bias)
        self.o_proj = nn.Linear(self.n_heads * head_dim, dim, bias=bias)

        # May be None, in which case no key/value truncation is applied.
        self.sliding_window = args.sliding_window
        # Kept for attribute compatibility; nothing in this class reads it.
        self.use_qk_norm = False

        self.rope = initialize_rope(
            dims=head_dim,
            base=args.rope_theta,
            traditional=True,
            max_position_embeddings=args.max_position_embeddings,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
        rope=True,
    ) -> mx.array:
        """Run attention over ``x``.

        Args:
            x: Input unpacked as ``B, L, D = x.shape``.
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.
            cache: Optional cache exposing ``offset`` and ``update_and_fetch``.
            rope: Apply RoPE (and sliding-window truncation when cached).

        Returns:
            The output projection of the attention result, reshaped to
            ``(B, L, D)`` before ``o_proj``.
        """
        B, L, D = x.shape
        q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        # RoPE is applied before the cache update so that cached keys are
        # stored already rotated.
        if cache is not None:
            if rope:
                q = self.rope(q, offset=cache.offset)
                k = self.rope(k, offset=cache.offset)
            k, v = cache.update_and_fetch(k, v)
            # Only truncate when a window is configured: negating None would
            # raise a TypeError.
            if rope and self.sliding_window is not None:
                k = k[:, :, -self.sliding_window:, :]
                v = v[:, :, -self.sliding_window:, :]
        elif rope:
            q = self.rope(q)
            k = self.rope(k)

        out = scaled_dot_product_attention(
            q, k, v, cache=cache, scale=self.scale, mask=mask
        )

        out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
        return self.o_proj(out)
|
||||||
|
|
||||||
|
|
||||||
|
class Cohere2MLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        hidden, inner = args.hidden_size, args.intermediate_size
        # All three projections are bias-free.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the gated projection to ``x`` and project back down."""
        gated = nn.silu(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
|
||||||
|
|
||||||
|
|
||||||
|
class Cohere2TransformerBlock(nn.Module):
    """One decoder layer: a shared layer norm feeding attention and MLP in parallel."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.self_attn = Cohere2Attention(args)
        self.mlp = Cohere2MLP(args)
        self.input_layernorm = nn.LayerNorm(
            args.hidden_size,
            eps=args.layer_norm_eps,
            affine=True,
            bias=False,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
        rope=True,
    ) -> mx.array:
        """Return ``x`` plus the attention and MLP outputs.

        Both branches read the same normalised input; ``mask``, ``cache`` and
        ``rope`` are forwarded to the attention layer only.
        """
        normed = self.input_layernorm(x)
        attention_out = self.self_attn(normed, mask, cache, rope=rope)
        mlp_out = self.mlp(normed)
        return x + attention_out + mlp_out
|
||||||
|
|
||||||
|
|
||||||
|
class Cohere2Model(nn.Module):
    """Embedding, stack of Cohere2 blocks and final layer norm.

    Layers are interleaved by ``sliding_window_pattern``: within each group
    of ``pattern`` layers, all but the last use the sliding-window mask with
    RoPE, and the last uses the full mask with ``rope=False``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            Cohere2TransformerBlock(args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_eps, affine=True, bias=False
        )
        self.sliding_window = args.sliding_window
        self.sliding_window_pattern = args.sliding_window_pattern

    def __call__(self, inputs: mx.array, cache: Optional[Any] = None) -> mx.array:
        """Embed ``inputs``, run every layer and return the normalised states.

        Args:
            inputs: Token ids passed to the embedding.
            cache: Optional per-layer cache list, zipped with the layers.

        Raises:
            ValueError: If ``sliding_window`` or ``sliding_window_pattern``
                is unset. Without them the interleaving below is undefined;
                failing here avoids silently skipping every layer.
        """
        if self.sliding_window is None or self.sliding_window_pattern is None:
            raise ValueError(
                "Cohere2 requires both `sliding_window` and "
                "`sliding_window_pattern` to be set."
            )

        h = self.embed_tokens(inputs)

        # Index of the full-attention layer inside each group; its cache is
        # the reference used to build the mask.
        global_index = self.sliding_window_pattern - 1
        mask = create_attention_mask(h, cache, reference_cache_idx=global_index)
        sliding_window_mask = (
            mask[:, -self.sliding_window:] if mask is not None else None
        )

        if cache is None:
            cache = [None] * len(self.layers)

        for i, (layer, c) in enumerate(zip(self.layers, cache)):
            if i % self.sliding_window_pattern < global_index:
                h = layer(h, mask=sliding_window_mask, cache=c)
            else:
                h = layer(h, mask=mask, cache=c, rope=False)

        return self.norm(h)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
    """Top-level Cohere2 wrapper: backbone plus tied, scaled output logits."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = Cohere2Model(args)
        self.args = args

    def __call__(self, inputs: mx.array, cache=None):
        """Return logits: embedding-as-linear of the backbone output times ``logit_scale``."""
        hidden = self.model(inputs, cache)
        logits = self.model.embed_tokens.as_linear(hidden)
        return logits * self.args.logit_scale

    @property
    def layers(self):
        """The backbone's list of transformer blocks."""
        return self.model.layers

    def make_cache(self):
        """Build one cache per layer.

        Layers with ``i % pattern == pattern - 1`` get a ``KVCache``; every
        other layer gets a ``RotatingKVCache`` sized to ``sliding_window``.
        """
        return [
            KVCache()
            if i % self.args.sliding_window_pattern
            == self.args.sliding_window_pattern - 1
            else RotatingKVCache(max_size=self.args.sliding_window, keep=0)
            for i in range(self.args.num_hidden_layers)
        ]
|
@ -187,9 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
|
|||||||
and prompt_cache[0].offset > quantized_kv_start
|
and prompt_cache[0].offset > quantized_kv_start
|
||||||
):
|
):
|
||||||
for i in range(len(prompt_cache)):
|
for i in range(len(prompt_cache)):
|
||||||
prompt_cache[i] = prompt_cache[i].to_quantized(
|
if isinstance(prompt_cache[i], cache.KVCache):
|
||||||
group_size=kv_group_size, bits=kv_bits
|
prompt_cache[i] = prompt_cache[i].to_quantized(
|
||||||
)
|
group_size=kv_group_size, bits=kv_bits
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_step(
|
def generate_step(
|
||||||
@ -403,6 +404,7 @@ def generate(
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
formatter: Optional[Callable] = None,
|
formatter: Optional[Callable] = None,
|
||||||
|
stop_strings: Optional[List[str]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@ -431,6 +433,8 @@ def generate(
|
|||||||
if verbose:
|
if verbose:
|
||||||
print(response.text, end="", flush=True)
|
print(response.text, end="", flush=True)
|
||||||
text += response.text
|
text += response.text
|
||||||
|
if stop_strings is not None and any(s in text for s in stop_strings):
|
||||||
|
break
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print()
|
print()
|
||||||
@ -865,3 +869,226 @@ def convert(
|
|||||||
|
|
||||||
if upload_repo is not None:
|
if upload_repo is not None:
|
||||||
upload_to_hub(mlx_path, upload_repo, hf_path)
|
upload_to_hub(mlx_path, upload_repo, hf_path)
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
def generate_batched_response(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, mx.array, List[int]],
    batch_size: int,
    max_tokens: int = 256,
    sampler: Optional[Callable[[mx.array], mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[List[Any]] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
    temp: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    top_p: Optional[float] = None,
    min_p: Optional[float] = None,
    min_tokens_to_keep: Optional[int] = None,
    verbose: bool = False,
) -> List[str]:
    """
    Generate multiple responses to the same prompt in parallel and return only the generated
    sequences (excluding the prompt), stopping at the first EOS token.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer or TokenizerWrapper): The tokenizer.
        prompt (Union[str, mx.array, List[int]]): The input prompt.
        batch_size (int): Number of responses to generate in parallel.
        max_tokens (int): Maximum number of generated tokens per sequence.
          A negative value disables the limit.
        sampler (Callable): Sampler function.
        logits_processors (List[Callable]): List of logits processors.
        max_kv_size (int): Maximum KV cache size.
        prompt_cache (List[Any]): Precomputed prompt cache.
        prefill_step_size (int): Step size for prompt processing.
        kv_bits (int): Bits for KV cache quantization.
        kv_group_size (int): Group size for KV quantization.
        quantized_kv_start (int): Step to begin quantizing KV.
        prompt_progress_callback (Callable): Callback for prompt progress.
        temp (float): Temperature for sampling (deprecated, pass to sampler).
        repetition_penalty (float): Repetition penalty (deprecated, use logits_processors).
        repetition_context_size (int): Context size for repetition.
        top_p (float): Top-p sampling (deprecated, pass to sampler).
        min_p (float): Minimum p sampling (deprecated, pass to sampler).
        min_tokens_to_keep (int): Minimum number of tokens to keep.
        verbose (bool): If True, show a progress bar for token generation.

    Returns:
        List[str]: A list of decoded response strings for each batch element, excluding the prompt
        and stopping at the first EOS token.

    Raises:
        ValueError: If the prompt is empty, or if ``prompt_cache`` does not
          have one entry per model layer.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    # Convert prompt to tokens if necessary
    if not isinstance(prompt, mx.array):
        prompt = mx.array(
            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
        )

    prompt_length = prompt.size
    if prompt_length == 0:
        # The decode loop needs a last prompt token to start from.
        raise ValueError("The prompt must contain at least one token.")

    # Expand prompt to batch: (prompt_length,) -> (B, prompt_length)
    B = batch_size
    prompt = mx.repeat(mx.expand_dims(prompt, 0), B, axis=0)

    if prompt_progress_callback is None:
        prompt_progress_callback = lambda *_: None

    if (
        temp is not None
        or top_p is not None
        or min_p is not None
        or min_tokens_to_keep is not None
    ):
        print(
            "[Warning] Specifying sampling arguments directly is deprecated. "
            "Pass in a `sampler` if needed."
        )
    if repetition_penalty is not None:
        print(
            "[Warning] Specifying `repetition_penalty` is deprecated. "
            "Use `logits_processors` instead."
        )

    sampler = sampler or make_sampler(
        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
    )
    logits_processors = logits_processors or make_logits_processors(
        None, repetition_penalty, repetition_context_size or 20
    )

    # Create or verify prompt cache
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    # Prefill the cache in increments. The final prompt token is held back:
    # the decode loop feeds `tokens[:, -1:]` as its first input, so
    # prefilling it here as well would put it in the cache twice.
    total_prompt_tokens = prompt_length
    prompt_processed_tokens = 0
    remaining_prompt = prompt[:, :-1]
    tic = time.perf_counter()
    with mx.stream(generation_stream):
        while remaining_prompt.shape[1] > prefill_step_size:
            model(remaining_prompt[:, :prefill_step_size], cache=prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
            prompt_processed_tokens += prefill_step_size
            remaining_prompt = remaining_prompt[:, prefill_step_size:]
            mx.metal.clear_cache()

        # Process any remaining prompt tokens
        if remaining_prompt.shape[1] > 0:
            model(remaining_prompt, cache=prompt_cache)
            mx.eval([c.state for c in prompt_cache])
        prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)

    prompt_time = time.perf_counter() - tic
    prompt_tps = (total_prompt_tokens * B) / prompt_time if prompt_time > 0 else 0.0

    # Initialization for generation
    tokens = prompt
    finished = mx.zeros((B,), dtype=tokens.dtype)
    generation_count = 0
    eos_ids = tokenizer.eos_token_ids

    # Setup progress bar if verbose
    pbar = None
    if verbose:
        if max_tokens >= 0:
            pbar = tqdm(total=max_tokens, desc="Generating tokens", ncols=80)
        else:
            # No max_tokens limit, so no total is known; the bar just counts up.
            pbar = tqdm(desc="Generating tokens", ncols=80)

    tic = time.perf_counter()

    while True:
        if (max_tokens >= 0) and (generation_count >= max_tokens):
            break

        # Stop once every sequence has produced an EOS token.
        if mx.sum(finished).item() == B:
            break

        next_input = tokens[:, -1:]  # (B, 1)
        with mx.stream(generation_stream):
            logits = model(next_input, cache=prompt_cache)
            logits = logits[:, -1, :]  # (B, vocab)

            if logits_processors:
                for processor in logits_processors:
                    logits = processor(tokens, logits)

            maybe_quantize_kv_cache(
                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
            )

            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
            sampled_tokens = sampler(logprobs)  # (B,)

            mx.async_eval(sampled_tokens, logprobs)

            # Compare ids directly. An arithmetic test such as 1/(d*d+1)
            # can report a false match if the integer square wraps around.
            hit_eos = mx.zeros(sampled_tokens.shape, dtype=mx.bool_)
            for eid in eos_ids:
                hit_eos = mx.logical_or(hit_eos, sampled_tokens == eid)
            finished = mx.maximum(finished, hit_eos.astype(finished.dtype))

            tokens = mx.concatenate([tokens, sampled_tokens[:, None]], axis=1)

        generation_count += 1
        if pbar is not None:
            pbar.update(1)

        if (generation_count % 256) == 0:
            mx.metal.clear_cache()

    if pbar is not None:
        pbar.close()

    generation_time = time.perf_counter() - tic
    generation_tps = (
        (generation_count * B) / generation_time if generation_count > 0 else 0.0
    )
    peak_memory = mx.metal.get_peak_memory() / 1e9

    results = []
    for i in range(B):
        seq = tokens[i][prompt_length:].tolist()  # Exclude the prompt
        # Cut at the first EOS token, if any; the EOS itself is dropped.
        eos_pos = next((idx for idx, t in enumerate(seq) if t in eos_ids), None)
        if eos_pos is not None:
            seq = seq[:eos_pos]
        results.append(tokenizer.decode(seq))

    if verbose:
        print("=" * 10)
        print(f"Prompt: {total_prompt_tokens} tokens * {B} sequences, {prompt_tps:.3f} tps")
        print(
            f"Generation: {generation_count} tokens * {B} sequences, "
            f"{generation_tps:.3f} tps"
        )
        print(f"Peak memory: {peak_memory:.3f} GB")

    return results
|
||||||
|
Loading…
Reference in New Issue
Block a user