Add Cohere2

This commit is contained in:
N8 2024-12-14 15:17:12 -05:00
parent 9f2ea5892e
commit 75fbb7ed34
4 changed files with 411 additions and 8 deletions

View File

@ -34,18 +34,22 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_cache_idx: Optional[int] = None) -> mx.array:
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
if reference_cache_idx is not None:
c = cache[reference_cache_idx]
else:
c = cache[0]
if hasattr(c, "max_size"):
offset = min(c.max_size, c.offset)
window_size = c.max_size
else:
offset = c.offset
mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:

View File

@ -6,7 +6,6 @@ import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
@ -33,7 +32,7 @@ def make_prompt_cache(
]
else:
return [KVCache() for _ in range(num_layers)]
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
"""
@ -264,6 +263,13 @@ class KVCache(_BaseCache):
n = min(self.offset, n)
self.offset -= n
return n
def trim_from_behind(self, n):
old_size = self.keys.shape[2]
self.keys = self.keys[..., -n:, :]
self.values = self.values[..., -n:, :]
new_size = self.keys.shape[2]
trimmed = old_size - new_size
self.offset -= trimmed
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
@ -416,7 +422,8 @@ class RotatingKVCache(_BaseCache):
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
raise NotImplementedError("RotatingKVCache Quantization NYI")
return self
#raise NotImplementedError("RotatingKVCache Quantization NYI")
class MambaCache(_BaseCache):

View File

@ -0,0 +1,165 @@
from dataclasses import dataclass
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
from .cache import KVCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int
rope_theta: float
vocab_size: int
layer_norm_eps: float
logit_scale: float
attention_bias: bool
# Additional Cohere2-specific arguments:
# rope_type and max_position_embeddings might influence the rope setup
rope_type: str = "default"
max_position_embeddings: int = 2048
sliding_window: Optional[int] = None,
sliding_window_pattern: Optional[int] = None,
order_of_interleaved_layers: Optional[int] = None,
use_cache: bool = True
class Cohere2Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = args.num_attention_heads
self.n_kv_heads = args.num_key_value_heads
head_dim = dim // self.n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, self.n_heads * head_dim, bias=args.attention_bias)
self.k_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=args.attention_bias)
self.v_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=args.attention_bias)
self.o_proj = nn.Linear(self.n_heads * head_dim, dim, bias=args.attention_bias)
self.sliding_window = args.sliding_window # Not yet implemented :(
self.use_qk_norm = False # Assuming QK norm not used by Cohere2 (adjust if needed)
# Initialize RoPE for Cohere2
self.rope = initialize_rope(
dims=head_dim,
base=args.rope_theta,
traditional=True,
max_position_embeddings=args.max_position_embeddings,
)
def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, rope = True) -> mx.array:
B, L, D = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
# Apply RoPE
# In Cohere2, the original code applies RoPE before caching updates. We replicate that:
if cache is not None:
if rope:
q = self.rope(q, offset=cache.offset)
k = self.rope(k, offset=cache.offset)
k, v = cache.update_and_fetch(k, v)
if rope:
k = k[:, :, -self.sliding_window:, :]
v = v[:, :, -self.sliding_window:, :]
elif rope:
q = self.rope(q)
k = self.rope(k)
# Compute attention
out = scaled_dot_product_attention(
q, k, v, cache=cache, scale=self.scale, mask=mask
)
out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
return self.o_proj(out)
class Cohere2MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hdim = args.intermediate_size
self.gate_proj = nn.Linear(dim, hdim, bias=False)
self.up_proj = nn.Linear(dim, hdim, bias=False)
self.down_proj = nn.Linear(hdim, dim, bias=False)
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class Cohere2TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.self_attn = Cohere2Attention(args)
self.mlp = Cohere2MLP(args)
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.layer_norm_eps, affine=True, bias=False)
def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, rope = True) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache, rope=rope)
ff_h = self.mlp(h)
return x + attn_h + ff_h
class Cohere2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [Cohere2TransformerBlock(args) for _ in range(args.num_hidden_layers)]
self.norm = nn.LayerNorm(args.hidden_size, eps=args.layer_norm_eps, affine=True, bias=False)
self.sliding_window = args.sliding_window
self.sliding_window_pattern = args.sliding_window_pattern
def __call__(self, inputs: mx.array, cache: Optional[Any] = None) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache, reference_cache_idx=self.sliding_window_pattern - 1)
sliding_window_mask = mask[:, -self.sliding_window:] if mask is not None else None
if cache is None:
cache = [None] * len(self.layers)
for i, (layer, c) in enumerate(zip(self.layers, cache)):
if self.sliding_window is not None:
index = i % self.sliding_window_pattern
if index < self.sliding_window_pattern - 1:
h = layer(h, mask=sliding_window_mask, cache=c)
else:
h = layer(h, mask=mask, cache=c, rope=False)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = Cohere2Model(args)
self.args = args
def __call__(self, inputs: mx.array, cache=None):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out) * self.args.logit_scale
return out
@property
def layers(self):
return self.model.layers
def make_cache(self):
caches = []
for i in range(self.args.num_hidden_layers):
if i % self.args.sliding_window_pattern == self.args.sliding_window_pattern - 1:
caches.append(KVCache())
else:
caches.append(RotatingKVCache(max_size=self.args.sliding_window, keep=0))
return caches

View File

@ -187,9 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
and prompt_cache[0].offset > quantized_kv_start
):
for i in range(len(prompt_cache)):
prompt_cache[i] = prompt_cache[i].to_quantized(
group_size=kv_group_size, bits=kv_bits
)
if isinstance(prompt_cache[i], cache.KVCache):
prompt_cache[i] = prompt_cache[i].to_quantized(
group_size=kv_group_size, bits=kv_bits
)
def generate_step(
@ -403,6 +404,7 @@ def generate(
prompt: str,
verbose: bool = False,
formatter: Optional[Callable] = None,
stop_strings: Optional[List[str]] = None,
**kwargs,
) -> str:
"""
@ -431,6 +433,8 @@ def generate(
if verbose:
print(response.text, end="", flush=True)
text += response.text
if stop_strings is not None and any(s in text for s in stop_strings):
break
if verbose:
print()
@ -865,3 +869,226 @@ def convert(
if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, hf_path)
from tqdm import tqdm
def generate_batched_response(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
batch_size: int,
max_tokens: int = 256,
sampler: Optional[Callable[[mx.array], mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
max_kv_size: Optional[int] = None,
prompt_cache: Optional[List[Any]] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
verbose: bool = False,
) -> List[str]:
"""
Generate multiple responses to the same prompt in parallel and return only the generated
sequences (excluding the prompt), stopping at the first EOS token.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer or TokenizerWrapper): The tokenizer.
prompt (Union[str, mx.array, List[int]]): The input prompt.
batch_size (int): Number of responses to generate in parallel.
max_tokens (int): Maximum number of generated tokens per sequence.
sampler (Callable): Sampler function.
logits_processors (List[Callable]): List of logits processors.
max_kv_size (int): Maximum KV cache size.
prompt_cache (List[Any]): Precomputed prompt cache.
prefill_step_size (int): Step size for prompt processing.
kv_bits (int): Bits for KV cache quantization.
kv_group_size (int): Group size for KV quantization.
quantized_kv_start (int): Step to begin quantizing KV.
prompt_progress_callback (Callable): Callback for prompt progress.
temp (float): Temperature for sampling (deprecated, pass to sampler).
repetition_penalty (float): Repetition penalty (deprecated, use logits_processors).
repetition_context_size (int): Context size for repetition.
top_p (float): Top-p sampling (deprecated, pass to sampler).
min_p (float): Minimum p sampling (deprecated, pass to sampler).
min_tokens_to_keep (int): Minimum number of tokens to keep.
verbose (bool): If True, show a progress bar for token generation.
Returns:
List[str]: A list of decoded response strings for each batch element, excluding the prompt
and stopping at the first EOS token.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
# Convert prompt to tokens if necessary
if not isinstance(prompt, mx.array):
prompt = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
# Expand prompt to batch
prompt_length = prompt.size
prompt = mx.expand_dims(prompt, 0) # (1, prompt_length)
prompt = mx.repeat(prompt, batch_size, axis=0) # (B, prompt_length)
B = batch_size
if prompt_progress_callback is None:
prompt_progress_callback = lambda *_: None
if temp is not None or top_p is not None or min_tokens_to_keep is not None:
print(
"[Warning] Specifying sampling arguments directly is deprecated. "
"Pass in a `sampler` if needed."
)
if repetition_penalty is not None:
print(
"[Warning] Specifying `repetition_penalty` is deprecated. "
"Use `logits_processors` instead."
)
sampler = sampler or make_sampler(
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
)
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
# Create or verify prompt cache
if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(model, max_kv_size)
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
# Process the prompt to fill the cache in increments
total_prompt_tokens = prompt_length
prompt_processed_tokens = 0
remaining_prompt = prompt
tic = time.perf_counter()
with mx.stream(generation_stream):
while remaining_prompt.shape[1] > prefill_step_size:
model(remaining_prompt[:, :prefill_step_size], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
prompt_processed_tokens += prefill_step_size
remaining_prompt = remaining_prompt[:, prefill_step_size:]
mx.metal.clear_cache()
# Process any remaining prompt tokens
if remaining_prompt.shape[1] > 0:
model(remaining_prompt, cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
prompt_time = time.perf_counter() - tic
prompt_tps = (total_prompt_tokens * B) / prompt_time
# Initialization for generation
tokens = prompt
finished = mx.zeros((B,), dtype=tokens.dtype)
generation_count = 0
eos_ids = tokenizer.eos_token_ids
# Setup progress bar if verbose
pbar = None
if verbose:
if max_tokens >= 0:
pbar = tqdm(total=max_tokens, desc="Generating tokens", ncols=80)
else:
# If we don't have a max_tokens limit, no total is known.
# We'll just display a progress bar that counts up.
pbar = tqdm(desc="Generating tokens", ncols=80)
tic = time.perf_counter()
while True:
if (max_tokens >= 0) and (generation_count >= max_tokens):
break
# If all sequences finished, break
sum_finished = mx.sum(finished)
mx.eval(sum_finished)
if sum_finished.item() == B:
break
# Prepare last token
next_input = tokens[:, -1:] # (B,1)
with mx.stream(generation_stream):
logits = model(next_input, cache=prompt_cache)
# logits: (B, 1, vocab)
logits = logits[:, -1, :] # (B, vocab)
# Apply logits processors
if logits_processors:
for processor in logits_processors:
logits = processor(tokens, logits)
maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits)
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) # (B,vocab)
sampled_tokens = sampler(logprobs) # (B,)
mx.async_eval(sampled_tokens, logprobs)
# Check EOS
is_eos = mx.zeros_like(sampled_tokens).astype(tokens.dtype)
for eid in eos_ids:
diff = sampled_tokens - eid
sq = diff * diff
val = 1.0 / (sq + 1.0)
mask = val.astype(tokens.dtype)
is_eos = is_eos + mask
ones = mx.ones_like(is_eos)
is_eos = mx.minimum(is_eos, ones)
finished = mx.maximum(finished, is_eos)
sampled_tokens = sampled_tokens[:, None] # (B,1)
tokens = mx.concatenate([tokens, sampled_tokens], axis=1)
generation_count += 1
if pbar is not None:
pbar.update(1)
if (generation_count % 256) == 0:
mx.metal.clear_cache()
if pbar is not None:
pbar.close()
generation_time = time.perf_counter() - tic
generation_tps = (generation_count * B) / generation_time if generation_count > 0 else 0.0
peak_memory = mx.metal.get_peak_memory() / 1e9
results = []
for i in range(B):
seq = tokens[i][prompt_length:].tolist() # Exclude the prompt
# Find the first EOS token
eos_pos = None
for idx, t in enumerate(seq):
if t in eos_ids:
eos_pos = idx
break
# Slice up to EOS if found
if eos_pos is not None:
seq = seq[:eos_pos]
text = tokenizer.decode(seq)
results.append(text)
if verbose:
print("=" * 10)
print(f"Prompt: {total_prompt_tokens} tokens * {B} sequences, {prompt_tps:.3f} tps")
print(
f"Generation: {generation_count} tokens * {B} sequences, "
f"{generation_tps:.3f} tps"
)
print(f"Peak memory: {peak_memory:.3f} GB")
return results