diff --git a/llava/generate.py b/llava/generate.py index 8067839e..64313858 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -79,10 +79,10 @@ def load_image(image_source): def prepare_inputs(processor, image, prompt): if isinstance(image, str): image = load_image(image) - inputs = processor(prompt, image, return_tensors="np") + inputs = processor(image, prompt, return_tensors="np") pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) - return input_ids, pixel_values + return pixel_values, input_ids def load_model(model_path, tokenizer_config={}): @@ -126,8 +126,7 @@ def main(): processor, model = load_model(args.model, tokenizer_config) prompt = codecs.decode(args.prompt, "unicode_escape") - - input_ids, pixel_values = prepare_inputs(processor, args.image, prompt) + pixel_values, input_ids = prepare_inputs(processor, args.image, prompt) print(prompt) generated_text = generate_text( diff --git a/llava/llava.py b/llava/llava.py index 9e6b7511..c5f190f8 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -104,31 +104,21 @@ class LlavaModel(nn.Module): self, image_features, inputs_embeds, input_ids ): image_token_index = self.config.image_token_index - num_images, num_image_patches, embed_dim = image_features.shape + batch_size, num_image_patches, embed_dim = image_features.shape # Positions of tokens in input_ids, assuming batch size is 1 - image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + image_positions = mx.array( + np.where(input_ids[0] == image_token_index)[0], mx.uint32 + ) - if len(image_positions) != num_images: + if len(image_positions) != num_image_patches: raise ValueError( f"The number of image tokens ({len(image_positions)}) does not " - f" match the number of image inputs ({num_images})." + f" match the number of image patches ({num_image_patches})." ) - text_segments = [] - start_idx = 0 - - for position in image_positions: - text_segments.append(inputs_embeds[:, start_idx:position]) - start_idx = position + 1 - - image_embeddings = mx.split(image_features, image_features.shape[0]) - final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] - final_embeddings += [inputs_embeds[:, start_idx:]] - - # Create a final embedding of shape - # (1, num_image_patches*num_images + sequence_len, embed_dim) - return mx.concatenate(final_embeddings, axis=1) + inputs_embeds[0, image_positions] = image_features + return inputs_embeds def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): input_embddings = self.get_input_embeddings(input_ids, pixel_values) diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 0f885fba..3af2d5fd 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.20.2" +__version__ = "0.20.4" diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index 423d5823..c4b15748 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -32,7 +32,7 @@ def _len_longest_common_prefix(a, b): def _rstrip_until(s, untils): - """Limit a string to the first occurence of any substring in untils.""" + """Limit a string to the first occurrence of any substring in untils.""" l = len(s) f = [s.find(u) for u in untils] f = [l if x < 0 else x for x in f] diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0c1b4acd..84dc63ca 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -1,6 +1,7 @@ # Copyright © 2023-2024 Apple Inc. import argparse +import codecs import json import sys @@ -188,6 +189,8 @@ def main(): elif using_cache: tokenizer.chat_template = metadata["chat_template"] + prompt = codecs.decode(args.prompt, "unicode_escape") + if not args.ignore_chat_template and ( hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None @@ -199,7 +202,7 @@ def main(): messages.append( { "role": "user", - "content": sys.stdin.read() if args.prompt == "-" else args.prompt, + "content": sys.stdin.read() if prompt == "-" else prompt, } ) prompt = tokenizer.apply_chat_template( @@ -216,8 +219,6 @@ def main(): add_generation_prompt=True, ) prompt = prompt[test_prompt.index("") :] - else: - prompt = args.prompt sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py new file mode 100644 index 00000000..fcb4061b --- /dev/null +++ b/llms/mlx_lm/models/cohere2.py @@ -0,0 +1,207 @@ +# Copyright © 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention +from .cache import KVCache, RotatingKVCache + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int = 4096 + head_dim: int = 128 + num_hidden_layers: int = 32 + intermediate_size: int = 14336 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + rope_theta: float = 50000.0 + vocab_size: int = 256000 + layer_norm_eps: float = 1e-05 + logit_scale: float = 0.0625 + attention_bias: bool = False + layer_norm_bias: bool = False + sliding_window: int = 4096 + sliding_window_pattern: int = 4 + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.args = args + self.layer_idx = layer_idx + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.head_dim = head_dim = args.head_dim + if (head_dim * n_heads) != dim: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}" + f" and `num_heads`: {n_heads})." + ) + self.scale = head_dim**-0.5 + + attetion_bias = args.attention_bias + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias) + + self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) + + self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0 + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + # Apply RoPE only if sliding window is enabled + if self.use_sliding_window: + if cache is None: + queries = self.rope(queries) + keys = self.rope(keys) + else: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + if self.use_sliding_window and mask is not None: + key_len = keys.shape[-2] + if mask.shape[-1] != key_len: + mask = mask[..., -key_len:] + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + def __call__(self, x): + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.hidden_size = args.hidden_size + self.n_heads = args.num_attention_heads + + self.self_attn = Attention(args, layer_idx) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.LayerNorm( + args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + h = self.input_layernorm(x) + attn_h = self.self_attn(h, mask, cache) + ff_h = self.mlp(h) + return attn_h + ff_h + x + + +class CohereModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.LayerNorm( + args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias + ) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + T = h.shape[1] + if T > 1: + offset = cache[0].offset if cache else 0 + mask = create_causal_mask(T, offset).astype(h.dtype) + else: + mask = None + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model_type = args.model_type + self.model = CohereModel(args) + self.args = args + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + out = self.model.embed_tokens.as_linear(out) + out = out * self.model.args.logit_scale + return out + + def make_cache(self): + caches = [] + for i in range(self.args.num_hidden_layers): + if ( + i % self.args.sliding_window_pattern + == self.args.sliding_window_pattern - 1 + ): + caches.append(KVCache()) + else: + caches.append( + RotatingKVCache(max_size=self.args.sliding_window, keep=0) + ) + return caches + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index f9868422..c77f056a 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -190,7 +190,7 @@ def make_repetition_penalty(penalty: float, context_size: int = 20): Callable[[mx.array, List[int]], mx.array]: The repetition penalty processor. """ - if penalty < 0 or not isinstance(penalty, float): + if penalty < 0 or not isinstance(penalty, (int, float)): raise ValueError(f"penalty must be a non-negative float, got {penalty}") def repetition_penalty_processor(tokens, logits): diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index ce09cf45..c12513ff 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -465,7 +465,7 @@ class APIHandler(BaseHTTPRequestHandler): text = "" tic = time.perf_counter() - sampler = make_sampler(self.temperature) + sampler = make_sampler(self.temperature, top_p=self.top_p) logits_processors = make_logits_processors( self.logit_bias, self.repetition_penalty, self.repetition_context_size ) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 10a257f6..8251e62f 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -3,8 +3,6 @@ from functools import partial from transformers import AutoTokenizer -REPLACEMENT_CHAR = "\ufffd" - class StreamingDetokenizer: """The streaming detokenizer interface so that we can detokenize one token at a time. @@ -51,11 +49,9 @@ class StreamingDetokenizer: def last_segment(self): """Return the last segment of readable text since last time this property was accessed.""" text = self.text - if text and text[-1] != REPLACEMENT_CHAR: - segment = text[self.offset :] - self.offset = len(text) - return segment - return "" + segment = text[self.offset :] + self.offset = len(text) + return segment class NaiveStreamingDetokenizer(StreamingDetokenizer): @@ -132,7 +128,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer): self.tokens = [] def _flush(self): - text = self._unflushed.replace(self._sep, b" ").decode("utf-8") + text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace") if not self.text and self.trim_space and text and text[0] == " ": text = text[1:] self.text += text @@ -199,22 +195,21 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): self.tokens.append(token) v = self.tokenmap[token] is_added = token in self._added_ids - if is_added or self._byte_decoder[v[0]] == 32: - current_text = bytearray( - self._byte_decoder[c] for c in self._unflushed - ).decode("utf-8") - self.text += self._maybe_trim_space(current_text) - if is_added: - self.text += v - self._unflushed = "" - else: - self._unflushed = v - else: + if not is_added: self._unflushed += v + text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( + "utf-8", "replace" + ) + if is_added: + text += v + if not text.endswith("\ufffd"): + self.text += self._maybe_trim_space(text) + self._unflushed = "" def finalize(self): current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( - "utf-8" + "utf-8", + "replace", ) self.text += self._maybe_trim_space(current_text) self._unflushed = "" diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 213bcad7..e5d0b975 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -96,6 +96,7 @@ def linear_to_lora_layers( "gemma2", "starcoder2", "cohere", + "cohere2", "minicpm", "deepseek", "olmo2", diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index d81bb66a..4d69115e 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download -from mlx.utils import tree_flatten, tree_map, tree_reduce +from mlx.utils import tree_flatten, tree_reduce from transformers import PreTrainedTokenizer # Local imports @@ -59,6 +59,7 @@ class GenerationResponse: generation_tokens (int): The number of generated tokens. generation_tps (float): The tokens-per-second for generation. peak_memory (float): The peak memory used so far in GB. + finish_reason (str): The reason the response is being sent: "length", "stop" or `None` """ text: str @@ -69,6 +70,7 @@ class GenerationResponse: generation_tokens: int generation_tps: float peak_memory: float + finish_reason: Optional[str] = None @contextlib.contextmanager @@ -185,9 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ and prompt_cache[0].offset > quantized_kv_start ): for i in range(len(prompt_cache)): - prompt_cache[i] = prompt_cache[i].to_quantized( - group_size=kv_group_size, bits=kv_bits - ) + if isinstance(prompt_cache[i], cache.KVCache): + prompt_cache[i] = prompt_cache[i].to_quantized( + group_size=kv_group_size, bits=kv_bits + ) def generate_step( @@ -297,6 +300,9 @@ def generate_step( prompt_processed_tokens = 0 while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=prompt_cache) + maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits + ) mx.eval([c.state for c in prompt_cache]) prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) prompt_processed_tokens += prefill_step_size @@ -375,6 +381,7 @@ def stream_generate( generation_tokens=n + 1, generation_tps=(n + 1) / (time.perf_counter() - tic), peak_memory=mx.metal.get_peak_memory() / 1e9, + finish_reason=None, ) detokenizer.finalize() @@ -387,6 +394,7 @@ def stream_generate( generation_tokens=n + 1, generation_tps=(n + 1) / (time.perf_counter() - tic), peak_memory=mx.metal.get_peak_memory() / 1e9, + finish_reason="stop" if token in tokenizer.eos_token_ids else "length", ) diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 374a5113..3097c522 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -851,6 +851,22 @@ class TestModels(unittest.TestCase): model = exaone.Model(args) self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers) + def test_cohere2(self): + from mlx_lm.models import cohere2 + + args = cohere2.ModelArgs( + model_type="cohere2", + hidden_size=4096, + head_dim=128, + num_hidden_layers=40, + sliding_window=4096, + sliding_window_pattern=4, + ) + model = cohere2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main()