From 9392bc70f793d6080efd8764931fbf1b8046d586 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 24 Feb 2025 08:51:12 -0800 Subject: [PATCH] cleanup --- llms/mlx_lm/models/plamo2.py | 1585 ++++------------------------------ 1 file changed, 191 insertions(+), 1394 deletions(-) diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py index 7d4a439e..ba4c2bb1 100644 --- a/llms/mlx_lm/models/plamo2.py +++ b/llms/mlx_lm/models/plamo2.py @@ -1,446 +1,38 @@ -import enum +# Copyright © 2025 Apple Inc. + import math -from collections import OrderedDict from dataclasses import dataclass -from typing import Any, Literal, NamedTuple, Optional, Union +from typing import Any, Optional, Union import mlx.core as mx import mlx.nn as nn - from mlx_lm.models.base import BaseModelArgs, create_attention_mask - -def _is_first_token(mask: mx.array) -> mx.array: - assert mask.dtype == mx.bool_ # type: ignore - B, Nh, q_len, kv_len = mask.shape - mask = mask[:, :, :, -q_len:] - cont = q_len != kv_len - v = False if cont else True - out = mx.logical_not(mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_)) # type: ignore - out = mx.concatenate([mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1) # type: ignore - return out - - -def _swiglu(h: mx.array) -> mx.array: - size = h.shape[-1] - chunks = 2 - _current_idx = 0 - split_sizes = [] - for i in range(chunks - 1): - _current_idx += size // chunks + (1 if i < size % chunks else 0) - split_sizes.append(_current_idx) - hs = mx.split(h, split_sizes, axis=-1) - return nn.silu(hs[0]) * hs[1] - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim: int, max_position_embeddings: int = 2048, base: int = 10000) -> None: - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)) - self._inv_freq = inv_freq - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, dtype=mx.float32) - - def _set_cos_sin_cache(self, seq_len: int, dtype: Any) -> None: - self.max_seq_len_cached = seq_len - t = mx.arange(self.max_seq_len_cached, dtype=self._inv_freq.dtype) # type: ignore - - freqs = mx.einsum("i,j->ij", t, self._inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = mx.concatenate([freqs, freqs], axis=-1) - self._cos_cached = emb.cos()[None, None, :, :].astype(mx.float32) - self._sin_cached = emb.sin()[None, None, :, :].astype(mx.float32) - - def __call__(self, x: mx.array, seq_len: int) -> tuple[mx.array, mx.array]: - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, dtype=x.dtype) - - return ( - self._cos_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore - self._sin_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore - ) - - -def _rotate_half(x: mx.array) -> mx.array: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return mx.concatenate([-x2, x1], axis=-1) - - -def _rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array) -> mx.array: - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = mx.expand_dims(cos[position_ids], 1) # [bs, 1, seq_len, dim] - sin = mx.expand_dims(sin[position_ids], 1) # [bs, 1, seq_len, dim] - x_embed = (x * cos) + (_rotate_half(x) * sin) - return x_embed - - -class LinearType(str, enum.Enum): - Normal = "normal" - Fp8 = "fp8" - Fp8Retain = "fp8-retain" +from .cache import KVCache, MambaCache @dataclass -class ModelArgs(BaseModelArgs): # type: ignore +class ModelArgs(BaseModelArgs): model_type: str = "plamo2" - - def __init__( - self, - hidden_size: int = 4096, - num_hidden_layers: int = 32, - rms_norm_eps: float = 1e-6, - tie_word_embeddings: bool = True, - # Attention - num_attention_heads: int = 32, - num_key_value_heads: int = 4, - hidden_size_per_head: int = 128, - max_position_embeddings: int = 2048, - attention_window_size: int = 2048, - full_attention_idx: list[int] | None = None, - # Mamba - mamba_d_state: int = 64, - mamba_d_conv: int = 4, - mamba_num_heads: int = 64, - mamba_step: int = 2, - mamba_chunk_size: int = 256, - mamba_enabled: bool = True, - # MLP - intermediate_size: int = 13312, - # Tokenizer - vocab_size: int = 32000, - tokenizer_class: str = "PlamoTokenizer", - pad_token_id: Optional[int] = None, - bos_token_id: int = 1, - eos_token_id: int = 2, - # Multimodal - image_token_id: Optional[int] = None, - image_feature_size: Optional[int] = None, - image_proj_type: Literal["linear", "mlp"] = "linear", - # FP8 - linear_type: LinearType = LinearType.Normal, - fp8_accum_dtype: Optional[str] = None, - # Evaluation - eval_attention_n_bit: Optional[int] = None, - eval_mlp_n_bit: Optional[int] = None, - use_cache: bool = True, - **kwargs: Any, - ) -> None: - # max_position_embeddings is often used to determine the max length during inference, - # but samba should have extrapolation abilities - self.max_position_embeddings = max(10 * 1024 * 1024, max_position_embeddings) - self.hidden_size = hidden_size - self.rms_norm_eps = rms_norm_eps - - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_size_per_head = hidden_size_per_head - self.num_key_value_heads = num_key_value_heads - self.attention_window_size = attention_window_size - self.full_attention_idx = full_attention_idx if full_attention_idx is not None else [] - - self.mamba_d_state = mamba_d_state - self.mamba_d_conv = mamba_d_conv - self.mamba_num_heads = mamba_num_heads - self.mamba_step = mamba_step - self.mamba_chunk_size = mamba_chunk_size - self.mamba_enabled = mamba_enabled - - self.intermediate_size = intermediate_size - - self.vocab_size = vocab_size - - self.image_token_id = image_token_id - self.image_feature_size = image_feature_size - self.image_proj_type = image_proj_type - - self.linear_type = linear_type - self.fp8_accum_dtype = fp8_accum_dtype - - self.eval_attention_n_bit = eval_attention_n_bit - self.eval_mlp_n_bit = eval_mlp_n_bit - self.use_cache = use_cache - - # fields for vLLM - self.sliding_window = attention_window_size - - self.tokenizer_class = tokenizer_class - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.tie_word_embeddings = tie_word_embeddings - - # From PretrainedConfig of transformers - self.use_return_dict = kwargs.pop("use_return_dict", True) - self.output_attentions = kwargs.pop("output_attentions", False) - self.output_hidden_states = kwargs.pop("output_hidden_states", False) - - -class PlamoAttentionCache(nn.Module): 
- def __init__(self, key: mx.array, value: mx.array) -> None: - super().__init__() - B, nh, L, c = key.shape - assert len(value.shape) == 4 - assert value.shape[0] == B - assert value.shape[2] == L - self.key = key - self.value = value - - @property - def state(self) -> tuple[mx.array, mx.array]: - return self.key, self.value - - @state.setter - def state(self, key: mx.array, value: mx.array) -> None: - self.key = key - self.value = value - - -class PlamoMambaCache(nn.Module): - def __init__(self, conv_state: mx.array, ssm_state: mx.array) -> None: - super().__init__() - # conv_state: [B, C, d_conv] - # ssm_state: [B, nhead, nchanel_per_head, d_state] - assert len(conv_state.shape) == 3 - assert len(ssm_state.shape) == 4 - assert conv_state.shape[0] == ssm_state.shape[0] - self.conv_state = conv_state - self.ssm_state = ssm_state - - @property - def state(self) -> tuple[mx.array, mx.array]: - return self.conv_state, self.ssm_state - - @state.setter - def state(self, conv_satte: mx.array, ssm_state: mx.array) -> None: - self.conv_state = conv_satte - self.ssm_state = ssm_state - - -PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache - - -class PlamoCache(nn.Module): - """ - stores states of the model for fast decoding. - `transformers` uses `transformers.Cache` for this purpose, but the interface and variable names are - deeply dependent on Transformers architecture (e.g., `key_states`) and it is difficult to use - other architectures (e.g., Mamba). - This class provides a similar interface to `transformers.Cache`, but is designed to also handle - the state of Mamba properly. - """ - - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.cache: list[Optional[PlamoLayerCache]] = [None for _ in range(config.num_hidden_layers)] - - def append_kv(self, key: mx.array, value: mx.array, layer_idx: int) -> tuple[mx.array, mx.array]: - c = self.cache[layer_idx] - if c is None: - return key, value - assert isinstance(c, PlamoAttentionCache) - - def _validate(cache: mx.array, new_tensor: mx.array) -> None: - assert len(cache.shape) == 4 - assert len(new_tensor.shape) == 4 - assert cache.shape[0] == new_tensor.shape[0] - assert cache.shape[1] == new_tensor.shape[1] - assert cache.shape[3] == new_tensor.shape[3] - - _validate(c.key, key) - _validate(c.value, value) - assert key.shape[2] == value.shape[2] - return mx.concatenate([c.key, key], axis=2), mx.concatenate([c.value, value], axis=2) - - def update_attention(self, key_states: mx.array, value_states: mx.array, layer_idx: int) -> PlamoAttentionCache: - full_attn = layer_idx in self.config.full_attention_idx - window_size = self.config.attention_window_size - - if self.cache[layer_idx] is None: - if full_attn: - self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states) - else: - self.cache[layer_idx] = PlamoAttentionCache( - key_states[:, :, -window_size:, :], - value_states[:, :, -window_size:, :], - ) - else: - c = self.cache[layer_idx] - assert isinstance(c, PlamoAttentionCache) - k, v = self.append_kv(key_states, value_states, layer_idx) - if full_attn: - c.key = k - c.value = v - else: - c.key = k[:, :, -window_size:, :] - c.value = v[:, :, -window_size:, :] - self.cache[layer_idx] = c - return self.cache[layer_idx] # type: ignore - - def update_mamba(self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int) -> PlamoMambaCache: - if self.cache[layer_idx] is None: - self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state) - else: - c = self.cache[layer_idx] - assert 
isinstance(c, PlamoMambaCache) - assert c.conv_state.shape == conv_state.shape - assert c.ssm_state.shape == ssm_state.shape - c.conv_state = conv_state - c.ssm_state = ssm_state - return self.cache[layer_idx] # type: ignore - - def __getitem__(self, layer_idx: int) -> PlamoLayerCache | None: - assert layer_idx < len(self.cache) - layer_cache = self.cache[layer_idx] - return layer_cache # type: ignore - - def __iter__(self): - self._counter = 0 - return self - - def __next__(self): - if self._counter < len(self.cache): - layer_cache = self.cache[self._counter] - self._counter += 1 - return layer_cache - else: - raise StopIteration - - @property - def state(self): - return self.cache - - @state.setter - def state(self, v): - self.cache = v - - def __len__(self) -> int: - return len(self.cache) - - def get_seq_length(self, layer_idx: Optional[int] = None) -> int: - if layer_idx is not None: - c = self.cache[layer_idx] - assert isinstance(c, PlamoAttentionCache) - return c.key.shape[2] # type: ignore - - sequence_length: int = 0 - for layer_cache in self.cache: - if isinstance(layer_cache, PlamoAttentionCache): - sequence_length = ( - max(layer_cache.key.shape[2], sequence_length) - if sequence_length is not None - else layer_cache.key.shape[2] - ) - return sequence_length - - def get_max_length(self) -> int | None: - return None - - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_length() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length - - def reorder_cache(self, beam_idx: mx.array) -> None: - def _mamba(cache: PlamoMambaCache) -> PlamoMambaCache: - return PlamoMambaCache( - conv_state=mx.take(cache.conv_state, beam_idx, axis=0), - ssm_state=mx.take(cache.ssm_state, beam_idx, axis=0), - ) - - def _attention(cache: PlamoAttentionCache) -> PlamoAttentionCache: - return PlamoAttentionCache( - key=mx.take(cache.key, beam_idx, axis=0), - value=mx.take(cache.value, beam_idx, axis=0), - ) - - for i in range(len(self.cache)): - if self.cache[i] is None: - continue - layer_cache = self.cache[i] - if isinstance(layer_cache, PlamoMambaCache): - self.cache[i] = _mamba(layer_cache) - else: - assert isinstance(layer_cache, PlamoAttentionCache) - self.cache[i] = _attention(layer_cache) - - @property - def seen_tokens(self) -> int | None: - return None - - -class DecoderInput(NamedTuple): - hidden_states: mx.array - attention_mask: Optional[mx.array] = None - past_states: Optional[PlamoCache] = None - output_hidden_states: Optional[bool] = False - output_attentions: Optional[bool] = False - gradient_checkpointing: bool = False - input_ids: Optional[mx.array] = None - - -class DecoderOutput(NamedTuple): - hidden_states: mx.array - all_hidden_states: Optional[tuple[mx.array, ...]] - all_self_attns: Optional[tuple[mx.array, ...]] - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0) -> mx.array: - """ - Make causal mask used for 
bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = mx.full((tgt_len, tgt_len), float("-inf")) - mask_cond = mx.arange(mask.shape[-1]) - mask = mx.where(mask_cond < (mask_cond + 1).reshape((mask.shape[-1], 1)), 0, mask) - mask = mask.astype(dtype) - - if past_key_values_length > 0: - mask = mx.concatenate([mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1) - return mx.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None) -> mx.array: - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.shape - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mx.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)).astype(dtype) - - inverted_mask = 1.0 - expanded_mask - - return mx.where(inverted_mask.astype(mx.bool_), float("-inf"), inverted_mask) # type: ignore - - -def _rms_norm(hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0) -> mx.array: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.astype(mx.float32) - variance = mx.power(hidden_states, 2).mean(-1, keepdims=True) - hidden_states = hidden_states * mx.rsqrt(variance + eps) - hidden_states = hidden_states.astype(input_dtype) - if weight is not None: - hidden_states = (offset + weight) * hidden_states - return hidden_states + hidden_size: int = 4096 + num_hidden_layers: int = 32 + rms_norm_eps: float = 1e-6 + tie_word_embeddings: bool = True + num_attention_heads: int = 32 + num_key_value_heads: int = 4 + hidden_size_per_head: int = 128 + max_position_embeddings: int = 2048 + attention_window_size: int = 2048 + full_attention_idx: Optional[list[int]] = None + mamba_d_state: int = 64 + mamba_d_conv: int = 4 + mamba_num_heads: int = 64 + mamba_step: int = 2 + mamba_chunk_size: int = 256 + mamba_enabled: bool = True + intermediate_size: int = 13312 + vocab_size: int = 32000 + max_position_embeddings: int = 10 * 1024 * 1024 class RMSNorm(nn.Module): @@ -456,13 +48,18 @@ class RMSNorm(nn.Module): self.offset = offset def __call__(self, hidden_states: mx.array) -> mx.array: - return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset) + return mx.fast.rms_norm( + hidden_states, self.weight + self.offset, self.variance_epsilon + ) def get_initial_dt_bias(num_heads: int) -> mx.array: dt_min = 0.001 dt_max = 0.1 - dt = mx.exp(mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) + dt = mx.exp( + mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) dt = mx.clip(dt, a_min=1e-4, a_max=None) inv_dt = dt + mx.log(-mx.expm1(-dt)) return inv_dt @@ -535,8 +132,10 @@ def selective_state_update_ref( mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2), (batch, nheads, dstate), ) # (batch, nheads, dstate) - dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2) # (batch, nheads, dim, dstate) - state = state * dA + dB * mx.expand_dims(x, axis=-1) # (batch, dim, dstate + dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims( + B, axis=-2 + ) # (batch, nheads, dim, dstate) + state = state * dA + dB * mx.expand_dims(x, axis=-1) # (batch, dim, dstate) out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C) if D is not None: out += (x * 
D).astype(out.dtype) @@ -563,9 +162,15 @@ def ssd_update_state( hidden_size_per_head = x.shape[-1] d_state = B.shape[-1] - A = mx.broadcast_to(A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)).astype(mx.float32) - dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)) - dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)) + A = mx.broadcast_to( + A[:, None, None], (A.shape[0], hidden_size_per_head, d_state) + ).astype(mx.float32) + dt = mx.broadcast_to( + dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head) + ) + dt_bias = mx.broadcast_to( + dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head) + ) D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head)) out, ssm_state = selective_state_update_ref( ssm_state, @@ -582,7 +187,7 @@ def ssd_update_state( return out[:, None], ssm_state -def _ssd_chunk_scan_combined_naive( +def ssd_chunk_scan_combined( x: mx.array, dt: mx.array, A: mx.array, @@ -592,19 +197,12 @@ def _ssd_chunk_scan_combined_naive( z: mx.array, dt_bias: mx.array, dt_softplus: bool, - seq_idx: mx.array | None, ssm_state: mx.array, ) -> tuple[mx.array, mx.array]: assert ssm_state.dtype == mx.float32 length = x.shape[1] ys = [] for i in range(length): - if i != 0 and seq_idx is not None: - ssm_state = mx.where( - mx.array(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None, None], - mx.zeros_like(ssm_state), - ssm_state, - ) y, ssm_state = ssd_update_state( ssm_state, x[:, i], @@ -621,227 +219,46 @@ def _ssd_chunk_scan_combined_naive( return mx.concatenate(ys, axis=1), ssm_state -def ssd_chunk_scan_combined( - x: mx.array, - dt: mx.array, - A: mx.array, - B: mx.array, - C: mx.array, - chunk_size: int, - D: mx.array, - z: mx.array, - dt_bias: mx.array, - dt_softplus: bool, - return_final_states: bool, - seq_idx: mx.array | None, - ssm_state: mx.array | None, -) -> tuple[mx.array, mx.array] | mx.array: - if seq_idx is not None: - assert seq_idx.dtype == mx.int32 - assert ssm_state is None - assert not return_final_states - if ssm_state is not None: - assert ssm_state.dtype == mx.float32 - assert seq_idx is None - """ - state will be updates by following: - ``` - dt = softplus(dt) - dA = exp(dt * A) - state_next = state * dA + dB * x - ``` - To avoid updating state, we set dt to -inf and x to 0 - because `softplus(-inf) = 0` and `exp(0) = 1` - """ - if ssm_state is None: - bsize, _, num_heads, channel = x.shape - state = B.shape[-1] - ssm_state = mx.zeros((bsize, num_heads, channel, state), dtype=mx.float32) - tmp, ssm_state = _ssd_chunk_scan_combined_naive( - x, - dt, - A, - B, - C, - D, - z=z, - dt_bias=dt_bias, - dt_softplus=dt_softplus, - seq_idx=seq_idx, - ssm_state=ssm_state, - ) - if return_final_states: - return tmp, ssm_state - else: - return tmp - - -def _causal_conv1d( - conv_state: mx.array | None, weight: mx.array, x: mx.array, seq_idx: mx.array | None -) -> tuple[mx.array, mx.array | None]: - dtype = x.dtype - if conv_state is not None: - dtype = conv_state.dtype - assert seq_idx is None - if seq_idx is not None: - assert seq_idx.dtype == mx.int32 - assert conv_state is None - weight = weight.astype(dtype) - x = x.astype(dtype) - - return_final_states = conv_state is not None - if conv_state is None: - bsize = x.shape[0] - dim = weight.shape[0] - d_conv = weight.shape[-1] - conv_state = mx.zeros((bsize, dim, d_conv - 1), dtype=x.dtype) - length = x.shape[-1] - out = mx.zeros_like(x) - for i in range(length): - if i != 0 and seq_idx is not None: - conv_state = mx.where( - 
seq_idx[:, i - 1][:, None, None] != seq_idx[:, i][:, None, None], - mx.zeros_like(conv_state), - conv_state, - ) - out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1]) - x = out - if return_final_states: - return x, conv_state - else: - return x, None - - -# From: https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206 -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None) -> tuple[mx.array, mx.array]: - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.ndim == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape +def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]: + batch, seqlen, dim = x.shape width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state = x_new[:, :, -state_len:] - assert bias is None - # x_new: (N, C, L) -> (N, L, C) + state_len = conv_state.shape[-2] + x = mx.concatenate([conv_state, x], axis=-2) + conv_state = x[:, -state_len:] out = mx.conv1d( - x_new.transpose(0, 2, 1), - mx.expand_dims(weight, axis=2), + x, + weight, padding=0, groups=dim, - ).transpose(0, 2, 1)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else nn.silu(out)).astype(dtype_in), conv_state - - -def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array) -> tuple[mx.array, mx.array]: - dtype = conv_state.dtype - xBC = xBC.astype(dtype) - weight = weight.astype(dtype) - - x, conv_state = causal_conv1d_update( - x=xBC, - conv_state=conv_state, - weight=weight[:, :, 0], - activation="silu", - ) - return x, conv_state - - -# Based on: https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206 -def causal_conv1d(x, weight, bias=None, activation=None): - """ - MLX implementation of a causal depthwise 1D convolution. - Args: - x (mx.array): Input tensor of shape (batch, channels, seq_len). - weight (mx.array): Convolution filters of shape (channels, kernel_width). - Each channel has its own filter (depthwise conv). - bias (mx.array, optional): Bias for each channel of shape (channels,). - activation (str, optional): Activation to apply ("silu" or "swish" supported). - Returns: - mx.array: Output tensor of shape (batch, channels, seq_len). 
- """ - x = mx.array(x) if not isinstance(x, mx.array) else x - weight = mx.array(weight) if not isinstance(weight, mx.array) else weight - if bias is not None: - bias = mx.array(bias) if not isinstance(bias, mx.array) else bias - - batch, channels, seq_len = x.shape - _, kernel_width = weight.shape # weight shape: (channels, kernel_width) - - # Reshape weight for depthwise conv: (out_channels, in_channels/groups, kernel_width) - # Here out_channels = channels, in_channels/groups = 1 (depthwise conv per channel) - w = weight.reshape((channels, 1, kernel_width)) - - # Pad input on the left with (kernel_width-1) zeros for causal convolution - if kernel_width > 1: - pad_shape = (batch, channels, kernel_width - 1) - pad_zeros = mx.zeros(pad_shape, dtype=x.dtype) - x_padded = mx.concatenate([pad_zeros, x], axis=2) # concat along time axis - else: - x_padded = x - - # Perform depthwise convolution. Padding is already applied manually, so use padding=0 in conv1d. - y = mx.conv1d(x_padded, w, stride=1, padding=0, groups=channels) - # After convolution, y shape = (batch, channels, seq_len) because: - # input length = seq_len + kernel_width - 1, no padding in conv, so output length = seq_len. - - # Add bias if provided (bias shape (channels,) broadcasts to (batch, channels, seq_len)) - if bias is not None: - y = y + bias.reshape((1, channels, 1)) - - # Apply activation if specified - if activation in ("silu", "swish"): - # SiLU (swish) activation: y * sigmoid(y) - y = y * mx.sigmoid(y) - elif activation is not None: - raise ValueError(f"Unsupported activation: {activation}") - - return y + )[:, -seqlen:] + return nn.silu(out), conv_state class Mamba(nn.Module): - def __init__(self, config: ModelArgs, layer_idx: int) -> None: + def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config - self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.d_state = config.mamba_d_state self.d_conv = config.mamba_d_conv self.chunk_size = config.mamba_chunk_size self.num_heads = config.mamba_num_heads - # TODO add mamba_hidden_size_per_head config (?) self.hidden_size_per_head = config.hidden_size_per_head self.intermediate_size = self.num_heads * self.hidden_size_per_head - self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False) + self.in_proj = nn.Linear( + self.hidden_size, 2 * self.intermediate_size, bias=False + ) self.conv1d = nn.Conv1d( in_channels=self.intermediate_size, out_channels=self.intermediate_size, - bias=False, # TODO the original implementation uses bias + bias=False, kernel_size=self.d_conv, groups=self.intermediate_size, padding=0, ) self.dt_dim = max(64, self.hidden_size // 16) - # Notes: - # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper), - # but it may degrade the ability of content-length extrapolation. 
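# A minimal, self-contained sketch of the depthwise causal-conv update implemented
# by the new causal_conv1d_update helper above (illustrative only; all sizes below
# are made up): a rolling buffer keeps the last d_conv - 1 inputs per channel, the
# buffer is prepended to the incoming tokens, and a grouped mx.conv1d then yields
# one causal output per input step before the SiLU gate.
import mlx.core as mx
import mlx.nn as nn

batch, seqlen, dim, width = 2, 3, 8, 4               # hypothetical sizes
x = mx.random.normal((batch, seqlen, dim))           # new tokens, channels last
conv_state = mx.zeros((batch, width - 1, dim))       # history of past inputs
weight = mx.random.normal((dim, width, 1))           # one depthwise filter per channel

xs = mx.concatenate([conv_state, x], axis=-2)        # prepend history -> causal context
conv_state = xs[:, -(width - 1):]                    # roll the buffer forward
y = nn.silu(mx.conv1d(xs, weight, groups=dim)[:, -seqlen:])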
self.bcdt_proj = nn.Linear( self.intermediate_size, self.dt_dim + 2 * self.d_state, @@ -853,55 +270,24 @@ class Mamba(nn.Module): self.A_log = get_initial_A(self.num_heads) self.D = mx.ones(self.num_heads, dtype=mx.float32) - # TODO norm weight before gating like Mamba2 self.dt_norm_weight = mx.ones(self.dt_dim) self.B_norm_weight = mx.ones(self.d_state) self.C_norm_weight = mx.ones(self.d_state) self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - def _no_weight_decay_param_names(self) -> set[str]: - return set(["D", "dt_bias", "A_log"]) - def __call__( self, hidden_states: mx.array, - attention_mask: Optional[mx.array] = None, - past_states: Optional[PlamoCache] = None, - ) -> tuple[mx.array, Optional[PlamoCache]]: + mask: Optional[mx.array] = None, + cache=None, + ): bsize, length, _ = hidden_states.shape - is_update = length == 1 and past_states is not None + is_update = length == 1 and cache[0] is not None - bool_mask: mx.array | None = None - seq_idx: mx.array | None = None - if attention_mask is not None: - if len(attention_mask.shape) == 2: - attention_mask = mx.broadcast_to( - attention_mask[None, None], - (bsize, 1, attention_mask.shape[0], attention_mask.shape[1]), - ) - assert len(attention_mask.shape) == 4 - - if past_states is None: - # TODO: support seq_idx with cache - bool_mask_4d = mx.array(attention_mask == 0, dtype=mx.bool_) # type: ignore - is_first_token = _is_first_token(bool_mask_4d)[:, 0, :] - seq_idx = mx.cumsum(is_first_token, axis=-1) - 1 - seq_idx = seq_idx.astype(mx.int32) - - # `generate` function creates attention mask that contains past tokens, - # but mamba does not use them - attention_mask = attention_mask[:, 0, -length:, -length:] - bool_mask = mx.array(mx.diagonal(attention_mask, axis1=-2, axis2=-1) == 0) - - conv_state: mx.array | None - ssm_state: mx.array | None - if past_states is None: - conv_state = None - ssm_state = None - elif past_states[self.layer_idx] is None: + if not is_update: conv_state = mx.zeros( - (bsize, self.intermediate_size, self.d_conv - 1), + (bsize, self.d_conv - 1, self.intermediate_size), dtype=hidden_states.dtype, ) ssm_state = mx.zeros( @@ -909,10 +295,8 @@ class Mamba(nn.Module): dtype=mx.float32, ) else: - c = past_states[self.layer_idx] - assert isinstance(c, PlamoMambaCache) - conv_state = c.conv_state - ssm_state = c.ssm_state + conv_state = cache[0] + ssm_state = cache[1] zx = self.in_proj(hidden_states) zx = zx.reshape(bsize, length, self.num_heads, -1) @@ -927,21 +311,8 @@ class Mamba(nn.Module): ) # conv - x = x.reshape(bsize, length, -1).transpose(0, 2, 1) # (bsize, intermediate_size, length) - if bool_mask is not None: - x = mx.where(bool_mask[:, None, :], x, 0.0) - if is_update: - assert conv_state is not None - x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x) - else: - x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx) - x = x.astype(hidden_states.dtype) - x = x.transpose(0, 2, 1) # (bsize, length, intermediate_size) - x = x.reshape(bsize, length, -1) - # x: (bsize, length, num_heads, hidden_size_per_head) - # B: (bsize, length, 1, d_state) - # C: (bsize, length, 1, d_state) - # dt: (bsize, length, dt_dim) + x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head) + x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight) BCdt = self.bcdt_proj(x) x = x.reshape(bsize, length, self.num_heads, -1) B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1) @@ -949,9 +320,9 @@ class 
Mamba(nn.Module): C = C[:, :, None, :] A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,) - dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :] - B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :] - C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :] + dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps) + B = mx.fast.rms_norm(B, self.B_norm_weight, self.config.rms_norm_eps) + C = mx.fast.rms_norm(C, self.C_norm_weight, self.config.rms_norm_eps) # (bsize, length, num_heads, 1) dt = self.dt_proj(dt)[..., None] @@ -960,81 +331,30 @@ class Mamba(nn.Module): B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3])) C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3])) - if bool_mask is not None: - """ - state will be updates by following: - ``` - dt = softplus(dt) - dA = exp(dt * A) - state_next = state * dA + dB * x - ``` - To avoid updating state, we set dt to -inf and x to 0 - because `softplus(-inf) = 0` and `exp(0) = 1` - """ - dt = mx.where(bool_mask[:, :, None, None], dt, float("-inf")) - x = mx.where(bool_mask[:, :, None, None], x, 0.0) - - # ssm - if is_update: - assert ssm_state is not None - out, ssm_state = ssd_update_state( - ssm_state, - x[:, 0], - dt[:, 0].reshape(bsize, -1), - A, - B[:, 0], - C[:, 0], - D=self.D, - z=z[:, 0], - dt_bias=self.dt_bias, - dt_softplus=True, - ) - else: - tmp = ssd_chunk_scan_combined( - x, - dt.reshape(bsize, length, -1), - A, - B, - C, - self.chunk_size, - D=self.D, - z=z, - dt_bias=self.dt_bias, - dt_softplus=True, - return_final_states=past_states is not None, - seq_idx=seq_idx, - ssm_state=ssm_state, - ) - if past_states is not None: - out, ssm_state = tmp - else: - assert isinstance(tmp, mx.array) - out = tmp + out, ssm_state = ssd_chunk_scan_combined( + x, + dt.reshape(bsize, length, -1), + A, + B, + C, + D=self.D, + z=z, + dt_bias=self.dt_bias, + dt_softplus=True, + ssm_state=ssm_state, + ) + cache[0] = conv_state + cache[1] = ssm_state y = self.out_proj(out.reshape(bsize, length, -1)) - if past_states is not None: - assert ssm_state is not None - assert conv_state is not None - past_states.update_mamba(conv_state, ssm_state, self.layer_idx) - - return y, past_states - - -def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array: - max_len = max(q_len, kv_len) - mask = mx.tril( - mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size), # type: ignore - k=window_size, - ) - return mask[-q_len:, -kv_len:] + return y class Attention(nn.Module): - def __init__(self, config: ModelArgs, layer_idx: int) -> None: + def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config - self.layer_idx = layer_idx self.hidden_size = config.hidden_size head_dim = config.hidden_size_per_head self.max_position_embeddings = config.max_position_embeddings @@ -1054,113 +374,53 @@ class Attention(nn.Module): self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, bias=False, ) - self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False) + self.o_proj = nn.Linear( + self.q_num_heads * self.v_dim, self.hidden_size, bias=False + ) self.q_weight = mx.ones((self.q_num_heads, self.qk_dim)) self.k_weight = mx.ones((self.k_num_heads, self.qk_dim)) - self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size) + self.rope = nn.RoPE(self.qk_dim) def __call__( self, hidden_states: 
mx.array, - attention_mask: Optional[mx.array] = None, - past_states: Optional[PlamoCache] = None, - output_attentions: bool = False, - ) -> tuple[mx.array, Optional[mx.array], Optional[PlamoCache]]: - bsz, q_len, _ = hidden_states.shape + mask: Optional[mx.array] = None, + cache=None, + ): + B, T, _ = hidden_states.shape qkv = self.qkv_proj(hidden_states) - query_states, key_states, value_states = mx.split( + q, k, v = mx.split( qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1 ) - query_states = query_states.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3) - key_states = key_states.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3) - value_states = value_states.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3) + q = q.reshape(B, T, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3) + k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3) + v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3) - attn_dtype = query_states.dtype + q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None] + k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None] - query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None] - key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None] - - if past_states is not None: - # reuse k, v, self_attention - key_states_new = key_states - value_states_new = value_states - key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) # type: ignore - past_states.update_attention(key_states_new, value_states_new, self.layer_idx) - - kv_seq_len = key_states.shape[-2] - position_ids = mx.arange(kv_seq_len, dtype=mx.int64)[None] - q_position_ids = position_ids[:, -query_states.shape[2] :] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states = _rotary_pos_emb(query_states, cos, sin, q_position_ids) - key_states = _rotary_pos_emb(key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] - - # expand shared kv - assert self.k_num_heads == self.v_num_heads - key_states = mx.repeat(key_states, self.n_group, 1) - value_states = mx.repeat(value_states, self.n_group, 1) - - full_attn = self.layer_idx in self.config.full_attention_idx - - query_states = query_states.astype(attn_dtype) - key_states = key_states.astype(attn_dtype) - value_states = value_states.astype(attn_dtype) - if attention_mask is not None and attention_mask.dtype != bool: - attention_mask = attention_mask.astype(attn_dtype) - if attention_mask is None: - if not full_attn: - assert key_states.shape[2] <= self.config.attention_window_size + 1 - mask = create_attention_mask(hidden_states) - attn_output = mx.fast.scaled_dot_product_attention( - query_states, - key_states, - value_states, - scale=self.scale, - mask=mask, - ) + if cache is not None: + q = self.rope(q, offset=cache.offset) + k = self.rope(k, offset=cache.offset) + k, v = cache.update_and_fetch(k, v) else: - if attention_mask.dtype == bool: - attention_mask = mx.where(attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf")) - if len(attention_mask.shape) == 2: - attention_mask = attention_mask[None, None] - assert len(attention_mask.shape) == 4 + q = self.rope(q) + k = self.rope(k) - if not full_attn: - m_swa = swa_mask( - query_states.shape[2], - key_states.shape[2], - self.config.attention_window_size, - ) - # `generate` function creates attention mask that does not consider sliding window - m_swa = m_swa[None, 
None] - attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :] - attention_mask = mx.where(m_swa, attention_mask, float("-inf")) - - # like AttentionMaskConverter._unmask_unattended in huggingface.transfoermers, - # we need to attend to all tokens in masked rows for `scaled_dot_product_attention` - bool_mask = mx.logical_not(mx.isneginf(attention_mask)) - valid_tokens = mx.sum(bool_mask, axis=-1).astype(mx.bool_) # type: ignore # (..., q_len) - attention_mask = mx.where(valid_tokens[..., None], attention_mask, float(0.0)) - attn_output = mx.fast.scaled_dot_product_attention( - query_states, - key_states, - value_states, - scale=self.scale, - mask=attention_mask, - ) - - attn_output = attn_output.transpose(0, 2, 1, 3) - - attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_states + output = mx.fast.scaled_dot_product_attention( + q, + k, + v, + scale=self.scale, + mask=mask, + ) + output = output.transpose(0, 2, 1, 3).reshape( + B, T, self.q_num_heads * self.v_dim + ) + return self.o_proj(output) class MLP(nn.Module): @@ -1169,61 +429,56 @@ class MLP(nn.Module): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.gate_up_proj = nn.Linear( + self.hidden_size, self.intermediate_size * 2, bias=False + ) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) def __call__(self, x: mx.array) -> mx.array: h = self.gate_up_proj(x) - h = _swiglu(h) - return self.down_proj(h) # type: ignore + hs = mx.split(h, 2, axis=-1) + return self.down_proj(nn.silu(hs[0]) * hs[1]) class PlamoDecoderLayer(nn.Module): - def __init__(self, config: ModelArgs, is_mamba: bool, layer_idx: int) -> None: + def __init__(self, config: ModelArgs, is_mamba: bool) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size self.is_mamba = is_mamba self.mixer: nn.Module if is_mamba: - self.mixer = Mamba(config, layer_idx) + self.mixer = Mamba(config) else: - self.mixer = Attention(config, layer_idx) + self.mixer = Attention(config) self.mlp = MLP(config) - """ - Notes: The model performance was degraded when setting all offsets to 1. 
- """ - self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0) - self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5) - self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0) - self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)) + self.pre_mixer_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, offset=1.0 + ) + self.post_mixer_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5 + ) + self.pre_mlp_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, offset=1.0 + ) + self.post_mlp_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5) + ) def __call__( self, hidden_states: mx.array, - attention_mask: Optional[mx.array] = None, - past_state: Optional[PlamoCache] = None, - output_attentions: Optional[bool] = False, - ) -> tuple[Any, ...]: - # from LlamaDecoder + mask: Optional[mx.array] = None, + cache=None, + ): residual = hidden_states hidden_states = self.pre_mixer_norm(hidden_states) - # Self Attention - if self.is_mamba: - hidden_states_sa, present_key_value = self.mixer( - hidden_states=hidden_states, - attention_mask=attention_mask, - past_states=past_state, - ) - self_attn_weights = None - else: - hidden_states_sa, self_attn_weights, present_key_value = self.mixer( - hidden_states=hidden_states, - attention_mask=attention_mask, - past_states=past_state, - output_attentions=output_attentions, - ) + hidden_states_sa = self.mixer( + hidden_states=hidden_states, + mask=mask, + cache=cache, + ) hidden_states_sa = self.post_mixer_norm(hidden_states_sa) hidden_states = residual + hidden_states_sa @@ -1236,14 +491,7 @@ class PlamoDecoderLayer(nn.Module): # Residual hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp) - hidden_states = residual + hidden_states_mlp - - outputs: Any = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs # type: ignore + return residual + hidden_states_mlp def is_mamba(config: ModelArgs, i: int) -> bool: @@ -1263,384 +511,71 @@ class PlamoDecoder(nn.Module): super().__init__() self.layers = [ - PlamoDecoderLayer(config, is_mamba=is_mamba(config, i), layer_idx=i) + PlamoDecoderLayer(config, is_mamba=is_mamba(config, i)) for i in range(config.num_hidden_layers) ] - self.gradient_checkpointing = False - def __call__(self, x: DecoderInput) -> DecoderOutput: - all_hidden_states: Optional[tuple[mx.array, ...]] = () if x.output_hidden_states else None - all_self_attns: Optional[tuple[mx.array, ...]] = () if x.output_attentions else None - hidden_states = x.hidden_states - - for decoder_layer in self.layers: - if x.output_hidden_states: - assert all_hidden_states is not None - all_hidden_states += (hidden_states,) - - if self.training and x.gradient_checkpointing: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - x.attention_mask, - x.past_states, - x.output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=x.attention_mask, - past_state=x.past_states, - output_attentions=x.output_attentions, - ) - - hidden_states = layer_outputs[0] - - if x.output_attentions: - assert layer_outputs[1] is not None - assert all_self_attns is not None - all_self_attns += (layer_outputs[1],) - return DecoderOutput(hidden_states, all_hidden_states, all_self_attns) + def __call__(self, x: mx.array, mask: mx.array, cache): + for i, 
decoder_layer in enumerate(self.layers): + x = decoder_layer( + x, + mask=mask, + cache=cache[i], + ) + return x -class ModelOutput(OrderedDict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def __getitem__(self, k): - if isinstance(k, str): - inner_dict = dict(self.items()) - return inner_dict[k] - else: - return self.to_tuple()[k] - - def to_tuple(self) -> tuple[Any]: - """ - Convert self to a tuple containing all the attributes/keys that are not `None`. - """ - return tuple(self[k] for k in self.keys()) - - -class BaseModelOutputWithPast(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - last_hidden_state (:obj:`mx.array` of shape :obj:`(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape - :obj:`(batch_size, 1, hidden_size)` is output. - past_key_values (:obj:`list[mx.array]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): - list of :obj:`mx.array` of length :obj:`config.n_layers`, with each tensor of shape - :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - ``past_key_values`` input) to speed up sequential decoding. - hidden_states (:obj:`tuple(mx.array)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`mx.array` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(mx.array)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`mx.array` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.last_hidden_state: mx.array = kwargs.pop("last_hidden_state") - self.past_key_values: Optional[tuple[tuple[mx.array]]] = kwargs.pop("past_key_values", None) - self.hidden_states: Optional[tuple[mx.array, ...]] = kwargs.pop("hidden_states", None) - self.attentions: Optional[tuple[mx.array, ...]] = kwargs.pop("attentions", None) - - -class CausalLMOutputWithPast(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`mx.array` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`mx.array` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(mx.array))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(mx.array)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(mx.array)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `mx.array` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(mx.array)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `mx.array` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - self.loss: Optional[mx.array] = kwargs.pop("loss", None) - self.logits: mx.array | None = kwargs.pop("logits", None) - self.past_key_values: Optional[tuple[tuple[mx.array]]] = kwargs.pop("past_key_values", None) - self.hidden_states: Optional[tuple[mx.array, ...]] = kwargs.pop("hidden_states", None) - self.attentions: Optional[tuple[mx.array, ...]] = kwargs.pop("attentions", None) - - -class PlamoPreTrainedModel(nn.Module): # type: ignore - config_class = ModelArgs - _no_split_modules: list[str] - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["PlamoDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - +class PlamoModel(nn.Module): def __init__(self, config: ModelArgs): super().__init__() + self.config = config - - def _init_weights(self, module: nn.Module) -> None: - std = 0.02 - if isinstance(module, nn.Linear): - module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape) - if module.bias is not None: - module.bias = mx.zeros_like(module.bias) - elif isinstance(module, nn.Embedding): - module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape) - if module.padding_idx is not None: - module.weight[module.padding_idx] = mx.zeros_like(module.weight[module.padding_idx]) - - -class PlamoModel(PlamoPreTrainedModel): - def __init__(self, config: ModelArgs): - super().__init__(config) - assert config.eval_attention_n_bit is None - assert config.eval_mlp_n_bit is None - - self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = PlamoDecoder(config) # type: ignore self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False - # Initialize weights and apply final processing - # self.post_init() - - def get_input_embeddings(self) -> nn.Embedding: - return self.embed_tokens - - def set_input_embeddings(self, value: nn.Embedding) -> None: - self.embed_tokens = value - - # Copied from 
transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, - attention_mask: mx.array, - input_shape: tuple[int, int], - inputs_embeds: Optional[mx.array], - past_key_values_length: int, - ) -> Optional[mx.array]: - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask: Optional[mx.array] = None - if input_shape[-1] > 1: - assert inputs_embeds is not None - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - past_key_values_length=past_key_values_length, - ) - input_shape = (input_shape[0], combined_attention_mask.shape[2]) - - if attention_mask is not None: - if attention_mask.ndim == 4: - # Custom 4D attention mask - expanded_attn_mask = attention_mask - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - assert inputs_embeds is not None - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def __call__( self, - input_ids: Optional[mx.array] = None, - attention_mask: Optional[mx.array] = None, - position_ids: Optional[mx.array] = None, - past_key_values: Optional[PlamoCache] = None, - inputs_embeds: Optional[mx.array] = None, - image_features: Optional[mx.array] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutputWithPast]: - assert input_ids is not None - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache + inputs: mx.array, + mask: Optional[mx.array] = None, + cache=None, + ): + batch_size, seq_length = inputs.shape - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + h = self.embed_tokens(inputs) - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if mask is None: + mask = create_attention_mask(h, [cache[1]]) - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() - seq_length_with_past = seq_length_with_past + past_key_values_length - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if image_features is not None: - assert self.config.image_token_id is not None - image_embeds = self.image_proj(image_features) - assert image_embeds.shape == inputs_embeds.shape, ( - image_embeds.shape, - inputs_embeds.shape, - ) - mask = input_ids == self.config.image_token_id - inputs_embeds[mask] = image_embeds[mask] - - # embed positions - require_attn_mask = False - if not self.training or past_key_values is not None: - require_attn_mask = True - if seq_length_with_past >= self.config.attention_window_size: - 
require_attn_mask = True - if require_attn_mask and attention_mask is None: - attention_mask = mx.ones( - (batch_size, seq_length_with_past), - dtype=mx.bool_, # type: ignore - ) - if attention_mask is not None: - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - use_cache = False - - if use_cache and past_key_values is None: - past_key_values = PlamoCache(self.config) + if cache is None: + cache = [None] * len(self.layers) # decoder layers out = self.layers( - DecoderInput( - hidden_states, - attention_mask, - past_key_values, - output_hidden_states, - output_attentions, - self.gradient_checkpointing, - ) + h, + mask, + cache, ) - assert isinstance(out, DecoderOutput) - hidden_states = out.hidden_states - all_hidden_states = out.all_hidden_states - all_self_attns = out.all_self_attns - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states += (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - past_key_values, - all_hidden_states, - all_self_attns, - ] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) + return self.norm(out) -class Model(PlamoPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - # Without this, the model cannot be loaded into a meta device. - # Relevant code: - # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L4376-L4381 - # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L356 - # https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068 - _supports_param_buffer_assignment = False +class Model(nn.Module): def __init__(self, config: ModelArgs) -> None: - super().__init__(config) + super().__init__() self.config = config self.model_type = config.model_type self.model = PlamoModel(config) self.vocab_size = config.vocab_size - vocab_size = ((self.vocab_size + 15) // 16) * 16 if not config.tie_word_embeddings: - self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False) - - # Initialize weights and apply final processing - # self.post_init() - - def get_input_embeddings(self) -> nn.Embedding: - return self.model.embed_tokens - - def set_input_embeddings(self, value: nn.Embedding) -> None: - self.model.embed_tokens = value - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings: nn.Module) -> None: - self.lm_head = new_embeddings - - def set_decoder(self, decoder: PlamoModel) -> None: - self.model = decoder - - def get_decoder(self) -> PlamoModel: - return self.model + self.lm_head: nn.Module = nn.Linear( + config.hidden_size, vocab_size, bias=False + ) def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]: for k, v in weights.items(): @@ -1648,163 +583,25 @@ class Model(PlamoPreTrainedModel): weights[k] = v.moveaxis(2, 1) return weights - def make_cache(self) -> PlamoCache: - return PlamoCache(self.config) + def make_cache(self): + # TODO use RotatingKVCache is not full_attn + # full_attn = self.layer_idx in self.config.full_attention_idx + return 
[MambaCache() if l.is_mamba else KVCache() for l in self.layers] - def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array: - model_inputs = self.prepare_inputs_for_generation( - input_ids=inputs, - past_key_values=cache, - use_cache=self.config.use_cache, - ) - model_inputs["input_ids"] = inputs - output = self.forward(**model_inputs) - if not isinstance(output, CausalLMOutputWithPast): - raise ValueError( - f"Unexpected output type for causal language model: {type(output)} != CausalLMOutputWithPast" - ) - if output.logits is not None: - return output.logits - else: - raise ValueError("The model did not return any logits.") - - def forward( - self, - input_ids: Optional[mx.array] = None, - attention_mask: Optional[mx.array] = None, - position_ids: Optional[mx.array] = None, - past_key_values: Optional[PlamoCache] = None, - inputs_embeds: Optional[mx.array] = None, - image_features: Optional[mx.array] = None, - labels: Optional[mx.array] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple[Any, ...], CausalLMOutputWithPast]: - r""" - Args: - labels (`mx.array` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: - Example: - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 
- ```""" - assert input_ids is not None - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + def __call__( + self, inputs: mx.array, mask: Optional[mx.array] = None, cache=None + ) -> mx.array: outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - image_features=image_features, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + inputs=inputs, + mask=None, + cache=cache, ) - if isinstance(outputs, tuple): - hidden_states = outputs[0] - elif isinstance(outputs, BaseModelOutputWithPast): - hidden_states = outputs.last_hidden_state - if self.config.tie_word_embeddings: - logits = self.model.embed_tokens.as_linear(hidden_states) + logits = self.model.embed_tokens.as_linear(outputs) else: - logits = self.lm_head(hidden_states) + logits = self.lm_head(outputs) - logits = logits[..., : self.vocab_size] - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - # Flatten the tokens - loss_fct = nn.losses.cross_entropy - shift_logits = shift_logits.reshape((-1, self.config.vocab_size)) - shift_labels = shift_labels.reshape((-1,)) - # Enable model parallelism - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - if not isinstance(outputs, BaseModelOutputWithPast): - raise ValueError( - f"Unexpected output type for causal language model: {type(outputs)} != BaseModelOutputWithPast" - ) - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids: mx.array, - past_key_values: Optional[PlamoCache] = None, - attention_mask: Optional[mx.array] = None, - inputs_embeds: Optional[mx.array] = None, - image_features: Optional[mx.array] = None, - **kwargs: Any, - ) -> dict[str, Any]: - if past_key_values: - input_ids = input_ids[:, -1:] - if image_features is not None: - image_features = image_features[:, -1:, :] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.astype(mx.int64).cumsum(-1) - 1 - position_ids = mx.where(attention_mask == 0, 1, position_ids) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs: dict[str, Any] = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "image_features": image_features, - } - ) - return model_inputs - - 
@staticmethod - def _reorder_cache(past_key_values: PlamoCache, beam_idx: mx.array) -> PlamoCache: - past_key_values.reorder_cache(beam_idx) - return past_key_values + return logits @property def layers(self):