Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 18:11:17 +08:00)

Commit 68533e2a8f: Merge branch 'ml-explore:main' into adding-support-for-mamba2
@@ -79,10 +79,10 @@ def load_image(image_source):
 def prepare_inputs(processor, image, prompt):
     if isinstance(image, str):
         image = load_image(image)
-    inputs = processor(prompt, image, return_tensors="np")
+    inputs = processor(image, prompt, return_tensors="np")
     pixel_values = mx.array(inputs["pixel_values"])
     input_ids = mx.array(inputs["input_ids"])
-    return input_ids, pixel_values
+    return pixel_values, input_ids
 
 
 def load_model(model_path, tokenizer_config={}):
@@ -126,8 +126,7 @@ def main():
     processor, model = load_model(args.model, tokenizer_config)
 
     prompt = codecs.decode(args.prompt, "unicode_escape")
 
-    input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
+    pixel_values, input_ids = prepare_inputs(processor, args.image, prompt)
 
     print(prompt)
     generated_text = generate_text(
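
As a quick orientation (not part of the diff), the reordered helper is consumed like this; the model id and image path are placeholders:

# Hypothetical call site for the updated helper: the processor is now called as
# processor(image, prompt, ...) and prepare_inputs returns (pixel_values, input_ids).
processor, model = load_model("llava-hf/llava-1.5-7b-hf")      # placeholder model path
pixel_values, input_ids = prepare_inputs(
    processor, "cat.png", "USER: <image>\nWhat is this? ASSISTANT:"
)
print(pixel_values.shape, input_ids.shape)
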
@@ -104,31 +104,21 @@ class LlavaModel(nn.Module):
         self, image_features, inputs_embeds, input_ids
     ):
         image_token_index = self.config.image_token_index
-        num_images, num_image_patches, embed_dim = image_features.shape
+        batch_size, num_image_patches, embed_dim = image_features.shape
 
         # Positions of <image> tokens in input_ids, assuming batch size is 1
-        image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+        image_positions = mx.array(
+            np.where(input_ids[0] == image_token_index)[0], mx.uint32
+        )
 
-        if len(image_positions) != num_images:
+        if len(image_positions) != num_image_patches:
             raise ValueError(
                 f"The number of image tokens ({len(image_positions)}) does not "
-                f" match the number of image inputs ({num_images})."
+                f" match the number of image patches ({num_image_patches})."
             )
 
-        text_segments = []
-        start_idx = 0
-
-        for position in image_positions:
-            text_segments.append(inputs_embeds[:, start_idx:position])
-            start_idx = position + 1
-
-        image_embeddings = mx.split(image_features, image_features.shape[0])
-        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
-        final_embeddings += [inputs_embeds[:, start_idx:]]
-
-        # Create a final embedding of shape
-        # (1, num_image_patches*num_images + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=1)
+        inputs_embeds[0, image_positions] = image_features
+        return inputs_embeds
 
     def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
         input_embddings = self.get_input_embeddings(input_ids, pixel_values)
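
The rewritten merge step swaps slice-and-concatenate for a single indexed assignment; a standalone sketch of the same idea with made-up shapes and token id:

# Standalone illustration (made-up shapes/ids): place image features directly at the
# positions of <image> tokens via an indexed assignment, assuming batch size 1.
import numpy as np
import mlx.core as mx

image_token_index = 32000
input_ids = mx.array([[1, 32000, 32000, 15043]])     # (1, seq_len) with two <image> slots
inputs_embeds = mx.zeros((1, 4, 8))                  # (1, seq_len, embed_dim)
image_features = mx.ones((2, 8))                     # (num_image_patches, embed_dim)

positions = mx.array(np.where(np.array(input_ids[0]) == image_token_index)[0], mx.uint32)
inputs_embeds[0, positions] = image_features         # scatter instead of split + concatenate
print(inputs_embeds)
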
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.20.2"
+__version__ = "0.20.4"
@@ -32,7 +32,7 @@ def _len_longest_common_prefix(a, b):
 
 
 def _rstrip_until(s, untils):
-    """Limit a string <s> to the first occurence of any substring in untils."""
+    """Limit a string <s> to the first occurrence of any substring in untils."""
     l = len(s)
     f = [s.find(u) for u in untils]
     f = [l if x < 0 else x for x in f]
@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import argparse
+import codecs
 import json
 import sys
 
@@ -188,6 +189,8 @@ def main():
     elif using_cache:
         tokenizer.chat_template = metadata["chat_template"]
 
+    prompt = codecs.decode(args.prompt, "unicode_escape")
+
     if not args.ignore_chat_template and (
         hasattr(tokenizer, "apply_chat_template")
         and tokenizer.chat_template is not None
@@ -199,7 +202,7 @@ def main():
         messages.append(
             {
                 "role": "user",
-                "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
+                "content": sys.stdin.read() if prompt == "-" else prompt,
             }
         )
         prompt = tokenizer.apply_chat_template(
@@ -216,8 +219,6 @@ def main():
             add_generation_prompt=True,
         )
         prompt = prompt[test_prompt.index("<query>") :]
-    else:
-        prompt = args.prompt
 
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
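
The prompt now goes through codecs.decode(..., "unicode_escape") before any chat-template handling, so backslash escapes given on the command line become real characters:

# Minimal illustration of the unicode_escape decoding applied to --prompt.
import codecs

raw = "line one\\nline two"                   # what the shell delivers for --prompt "line one\nline two"
print(codecs.decode(raw, "unicode_escape"))   # the literal \n becomes an actual newline
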
llms/mlx_lm/models/cohere2.py (new file, 207 lines)
@@ -0,0 +1,207 @@
# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int = 4096
    head_dim: int = 128
    num_hidden_layers: int = 32
    intermediate_size: int = 14336
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    rope_theta: float = 50000.0
    vocab_size: int = 256000
    layer_norm_eps: float = 1e-05
    logit_scale: float = 0.0625
    attention_bias: bool = False
    layer_norm_bias: bool = False
    sliding_window: int = 4096
    sliding_window_pattern: int = 4


class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.args = args
        self.layer_idx = layer_idx

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.head_dim = head_dim = args.head_dim
        if (head_dim * n_heads) != dim:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
                f" and `num_heads`: {n_heads})."
            )
        self.scale = head_dim**-0.5

        attention_bias = args.attention_bias

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)

        self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        # Apply RoPE only if sliding window is enabled
        if self.use_sliding_window:
            if cache is None:
                queries = self.rope(queries)
                keys = self.rope(keys)
            else:
                queries = self.rope(queries, offset=cache.offset)
                keys = self.rope(keys, offset=cache.offset)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        if self.use_sliding_window and mask is not None:
            key_len = keys.shape[-2]
            if mask.shape[-1] != key_len:
                mask = mask[..., -key_len:]

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x):
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.n_heads = args.num_attention_heads

        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        h = self.input_layernorm(x)
        attn_h = self.self_attn(h, mask, cache)
        ff_h = self.mlp(h)
        return attn_h + ff_h + x


class CohereModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args, layer_idx=i)
            for i in range(args.num_hidden_layers)
        ]
        self.norm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        T = h.shape[1]
        if T > 1:
            offset = cache[0].offset if cache else 0
            mask = create_causal_mask(T, offset).astype(h.dtype)
        else:
            mask = None

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = CohereModel(args)
        self.args = args

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model.embed_tokens.as_linear(out)
        out = out * self.model.args.logit_scale
        return out

    def make_cache(self):
        caches = []
        for i in range(self.args.num_hidden_layers):
            if (
                i % self.args.sliding_window_pattern
                == self.args.sliding_window_pattern - 1
            ):
                caches.append(KVCache())
            else:
                caches.append(
                    RotatingKVCache(max_size=self.args.sliding_window, keep=0)
                )
        return caches

    @property
    def layers(self):
        return self.model.layers
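
For orientation, a small sketch (not part of the diff) of the cache layout make_cache() produces with the default sliding_window_pattern of 4; the tiny dimensions are placeholders chosen so the model builds quickly, with head_dim * num_attention_heads equal to hidden_size:

# Illustrative only: inspect the interleaved cache layout from Model.make_cache().
from mlx_lm.models.cache import RotatingKVCache
from mlx_lm.models.cohere2 import Model, ModelArgs

args = ModelArgs(
    model_type="cohere2",
    hidden_size=64,
    head_dim=16,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=128,
    vocab_size=100,
    num_hidden_layers=8,
)
model = Model(args)
for i, c in enumerate(model.make_cache()):
    print(i, "sliding RotatingKVCache" if isinstance(c, RotatingKVCache) else "global KVCache")
# Layers 3 and 7 (every sliding_window_pattern-th layer) get a full KVCache and skip RoPE;
# the others attend over a RotatingKVCache bounded by sliding_window.
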
@@ -190,7 +190,7 @@ def make_repetition_penalty(penalty: float, context_size: int = 20):
         Callable[[mx.array, List[int]], mx.array]:
             The repetition penalty processor.
     """
-    if penalty < 0 or not isinstance(penalty, float):
+    if penalty < 0 or not isinstance(penalty, (int, float)):
         raise ValueError(f"penalty must be a non-negative float, got {penalty}")
 
     def repetition_penalty_processor(tokens, logits):
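
The widened type check only affects input validation: integer penalties such as 1 now pass instead of being rejected as non-floats. A two-line sketch, assuming the function is exposed as mlx_lm.sample_utils.make_repetition_penalty:

# Assumed import path; the point is that an int penalty now passes the isinstance guard.
from mlx_lm.sample_utils import make_repetition_penalty

processor = make_repetition_penalty(1)   # previously raised ValueError, now returns a processor
print(callable(processor))               # True
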
@@ -465,7 +465,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
         text = ""
         tic = time.perf_counter()
-        sampler = make_sampler(self.temperature)
+        sampler = make_sampler(self.temperature, top_p=self.top_p)
         logits_processors = make_logits_processors(
             self.logit_bias, self.repetition_penalty, self.repetition_context_size
         )
@@ -3,8 +3,6 @@ from functools import partial
 
 from transformers import AutoTokenizer
 
-REPLACEMENT_CHAR = "\ufffd"
-
 
 class StreamingDetokenizer:
     """The streaming detokenizer interface so that we can detokenize one token at a time.
@@ -51,11 +49,9 @@ class StreamingDetokenizer:
     def last_segment(self):
         """Return the last segment of readable text since last time this property was accessed."""
         text = self.text
-        if text and text[-1] != REPLACEMENT_CHAR:
-            segment = text[self.offset :]
-            self.offset = len(text)
-            return segment
-        return ""
+        segment = text[self.offset :]
+        self.offset = len(text)
+        return segment
 
 
 class NaiveStreamingDetokenizer(StreamingDetokenizer):
@@ -132,7 +128,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
         self.tokens = []
 
     def _flush(self):
-        text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
+        text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace")
         if not self.text and self.trim_space and text and text[0] == " ":
             text = text[1:]
         self.text += text
@@ -199,22 +195,21 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         self.tokens.append(token)
         v = self.tokenmap[token]
         is_added = token in self._added_ids
-        if is_added or self._byte_decoder[v[0]] == 32:
-            current_text = bytearray(
-                self._byte_decoder[c] for c in self._unflushed
-            ).decode("utf-8")
-            self.text += self._maybe_trim_space(current_text)
-            if is_added:
-                self.text += v
-                self._unflushed = ""
-            else:
-                self._unflushed = v
-        else:
+        if not is_added:
             self._unflushed += v
+        text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
+            "utf-8", "replace"
+        )
+        if is_added:
+            text += v
+        if not text.endswith("\ufffd"):
+            self.text += self._maybe_trim_space(text)
+            self._unflushed = ""
 
     def finalize(self):
         current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
-            "utf-8"
+            "utf-8",
+            "replace",
         )
         self.text += self._maybe_trim_space(current_text)
         self._unflushed = ""
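
The move to errors="replace" plus the trailing-"\ufffd" check handles multi-byte characters whose bytes are split across BPE tokens; a standalone illustration of that failure mode:

# Why streaming detokenization decodes with errors="replace": a multi-byte character
# may arrive split across tokens, and the first half alone is not valid UTF-8.
data = "café".encode("utf-8")            # b'caf\xc3\xa9'
head, tail = data[:4], data[4:]

print(head.decode("utf-8", "replace"))            # 'caf\ufffd' -> ends in U+FFFD, so the text is held back
print((head + tail).decode("utf-8", "replace"))   # 'café' once the remaining byte arrives
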
@@ -96,6 +96,7 @@ def linear_to_lora_layers(
             "gemma2",
             "starcoder2",
             "cohere",
+            "cohere2",
             "minicpm",
             "deepseek",
             "olmo2",
@@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten, tree_map, tree_reduce
+from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
@@ -59,6 +59,7 @@ class GenerationResponse:
         generation_tokens (int): The number of generated tokens.
         generation_tps (float): The tokens-per-second for generation.
+        peak_memory (float): The peak memory used so far in GB.
         finish_reason (str): The reason the response is being sent: "length", "stop" or `None`
     """
 
     text: str
@@ -69,6 +70,7 @@ class GenerationResponse:
     generation_tokens: int
     generation_tps: float
+    peak_memory: float
     finish_reason: Optional[str] = None
 
 
 @contextlib.contextmanager
@@ -185,9 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
         and prompt_cache[0].offset > quantized_kv_start
     ):
         for i in range(len(prompt_cache)):
-            prompt_cache[i] = prompt_cache[i].to_quantized(
-                group_size=kv_group_size, bits=kv_bits
-            )
+            if isinstance(prompt_cache[i], cache.KVCache):
+                prompt_cache[i] = prompt_cache[i].to_quantized(
+                    group_size=kv_group_size, bits=kv_bits
+                )
 
 
 def generate_step(
@@ -297,6 +300,9 @@ def generate_step(
         prompt_processed_tokens = 0
         while y.size > prefill_step_size:
             model(y[:prefill_step_size][None], cache=prompt_cache)
+            maybe_quantize_kv_cache(
+                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+            )
             mx.eval([c.state for c in prompt_cache])
             prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
             prompt_processed_tokens += prefill_step_size
@@ -375,6 +381,7 @@ def stream_generate(
             generation_tokens=n + 1,
             generation_tps=(n + 1) / (time.perf_counter() - tic),
+            peak_memory=mx.metal.get_peak_memory() / 1e9,
             finish_reason=None,
         )
 
     detokenizer.finalize()
@@ -387,6 +394,7 @@ def stream_generate(
         generation_tokens=n + 1,
         generation_tps=(n + 1) / (time.perf_counter() - tic),
+        peak_memory=mx.metal.get_peak_memory() / 1e9,
         finish_reason="stop" if token in tokenizer.eos_token_ids else "length",
     )
 
 
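
A small consumer-side sketch of the new peak_memory field; the repo id is a placeholder and the call pattern assumes mlx_lm's public load/stream_generate API:

# Hypothetical usage of the new GenerationResponse.peak_memory field (value in GB).
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/SomeModel-4bit")  # placeholder repo id
for response in stream_generate(model, tokenizer, "Hello,", max_tokens=16):
    print(response.text, end="", flush=True)
print(f"\nPeak memory: {response.peak_memory:.3f} GB")
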
@@ -851,6 +851,22 @@ class TestModels(unittest.TestCase):
         model = exaone.Model(args)
         self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers)
 
+    def test_cohere2(self):
+        from mlx_lm.models import cohere2
+
+        args = cohere2.ModelArgs(
+            model_type="cohere2",
+            hidden_size=4096,
+            head_dim=128,
+            num_hidden_layers=40,
+            sliding_window=4096,
+            sliding_window_pattern=4,
+        )
+        model = cohere2.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
 
 if __name__ == "__main__":
     unittest.main()