From 00d13ebd40e82c40a05853df0c43cf10547dd618 Mon Sep 17 00:00:00 2001
From: Shunta Saito
Date: Fri, 14 Feb 2025 01:51:06 +0900
Subject: [PATCH] Fix some parts

---
 llms/mlx_lm/models/plamo2.py | 833 +++++++++++------------
 1 file changed, 253 insertions(+), 580 deletions(-)

diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
index e7ffe488..b78c6530 100644
--- a/llms/mlx_lm/models/plamo2.py
+++ b/llms/mlx_lm/models/plamo2.py
@@ -7,21 +7,17 @@ from typing import Any, Literal, NamedTuple, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 def _is_first_token(mask: mx.array) -> mx.array:
-    assert mask.dtype == mx.bool_
+    assert mask.dtype == mx.bool_  # type: ignore
     B, Nh, q_len, kv_len = mask.shape
     mask = mask[:, :, :, -q_len:]
     cont = q_len != kv_len
     v = False if cont else True
-    out = mx.logical_not(
-        mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_)
-    )
-    out = mx.concatenate(
-        [mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1
-    )
+    out = mx.logical_not(mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_))  # type: ignore
+    out = mx.concatenate([mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1)  # type: ignore
     return out
 
 
@@ -38,17 +34,13 @@ def _swiglu(h: mx.array) -> mx.array:
 
 
 class RotaryEmbedding(nn.Module):
-    def __init__(
-        self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
-    ) -> None:
+    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: int = 10000) -> None:
         super().__init__()
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        inv_freq = 1.0 / (
-            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
-        )
+        inv_freq = 1.0 / (self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim))
         self._inv_freq = inv_freq
 
         # Build here to make `torch.jit.trace` work.
@@ -82,9 +74,7 @@ def _rotate_half(x: mx.array) -> mx.array:
     return mx.concatenate([-x2, x1], axis=-1)
 
 
-def _rotary_pos_emb(
-    x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array
-) -> mx.array:
+def _rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array) -> mx.array:
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
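     # A sketch of the rotation applied below, assuming the standard RoPE
     # formulation (cos/sin gathered at position_ids, _rotate_half as defined above):
     #   x_out = (x * cos[position_ids]) + (_rotate_half(x) * sin[position_ids])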
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] @@ -156,9 +146,7 @@ class ModelArgs(BaseModelArgs): # type: ignore self.hidden_size_per_head = hidden_size_per_head self.num_key_value_heads = num_key_value_heads self.attention_window_size = attention_window_size - self.full_attention_idx = ( - full_attention_idx if full_attention_idx is not None else [] - ) + self.full_attention_idx = full_attention_idx if full_attention_idx is not None else [] self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv @@ -236,13 +224,9 @@ class PlamoCache(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config - self.cache: list[Optional[PlamoLayerCache]] = [ - None for _ in range(config.num_hidden_layers) - ] + self.cache: list[Optional[PlamoLayerCache]] = [None for _ in range(config.num_hidden_layers)] - def append_kv( - self, key: mx.array, value: mx.array, layer_idx: int - ) -> tuple[mx.array, mx.array]: + def append_kv(self, key: mx.array, value: mx.array, layer_idx: int) -> tuple[mx.array, mx.array]: c = self.cache[layer_idx] if c is None: return key, value @@ -258,13 +242,9 @@ class PlamoCache(nn.Module): _validate(c.key, key) _validate(c.value, value) assert key.shape[2] == value.shape[2] - return mx.concatenate([c.key, key], axis=2), mx.concatenate( - [c.value, value], axis=2 - ) + return mx.concatenate([c.key, key], axis=2), mx.concatenate([c.value, value], axis=2) - def update_attention( - self, key_states: mx.array, value_states: mx.array, layer_idx: int - ) -> PlamoAttentionCache: + def update_attention(self, key_states: mx.array, value_states: mx.array, layer_idx: int) -> PlamoAttentionCache: full_attn = layer_idx in self.config.full_attention_idx window_size = self.config.attention_window_size @@ -286,11 +266,10 @@ class PlamoCache(nn.Module): else: c.key = k[:, :, -window_size:, :] c.value = v[:, :, -window_size:, :] + self.cache[layer_idx] = c return self.cache[layer_idx] # type: ignore - def update_mamba( - self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int - ) -> PlamoMambaCache: + def update_mamba(self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int) -> PlamoMambaCache: if self.cache[layer_idx] is None: self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state) else: @@ -330,9 +309,7 @@ class PlamoCache(nn.Module): def get_max_length(self) -> int | None: return None - def get_usable_length( - self, new_seq_length: int, layer_idx: Optional[int] = 0 - ) -> int: + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: """Given the sequence length of the new inputs, returns the usable length of the cache.""" # Cache without size limit -> all cache is usable # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache @@ -388,9 +365,7 @@ class DecoderOutput(NamedTuple): # Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0 -) -> mx.array: +def _make_causal_mask(input_ids_shape: tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0) -> mx.array: """ Make causal mask used for bi-directional self-attention. 
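     For example (a sketch; ``m`` stands for the large negative fill value of
     ``dtype``), ``input_ids_shape=(1, 3)`` with ``past_key_values_length=1``
     yields a mask of shape ``(1, 1, 3, 4)`` whose rows are
 
         [[0, 0, m, m],
          [0, 0, 0, m],
          [0, 0, 0, 0]]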
""" @@ -401,36 +376,26 @@ def _make_causal_mask( mask = mask.astype(dtype) if past_key_values_length > 0: - mask = mx.concatenate( - [mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1 - ) - return mx.broadcast_to( - mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length) - ) + mask = mx.concatenate([mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1) + return mx.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)) # Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask( - mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None -) -> mx.array: +def _expand_mask(mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None) -> mx.array: """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.shape tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = mx.broadcast_to( - mask[:, None, None, :], (bsz, 1, tgt_len, src_len) - ).astype(dtype) + expanded_mask = mx.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)).astype(dtype) inverted_mask = 1.0 - expanded_mask return mx.where(inverted_mask.astype(mx.bool_), float("-inf"), inverted_mask) # type: ignore -def _rms_norm( - hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0 -) -> mx.array: +def _rms_norm(hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0) -> mx.array: input_dtype = hidden_states.dtype hidden_states = hidden_states.astype(mx.float32) variance = mx.power(hidden_states, 2).mean(-1, keepdims=True) @@ -454,18 +419,13 @@ class RMSNorm(nn.Module): self.offset = offset def __call__(self, hidden_states: mx.array) -> mx.array: - return _rms_norm( - hidden_states, self.weight, self.variance_epsilon, offset=self.offset - ) + return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset) def get_initial_dt_bias(num_heads: int) -> mx.array: dt_min = 0.001 dt_max = 0.1 - dt = mx.exp( - mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ) + dt = mx.exp(mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) dt = mx.clip(dt, a_min=1e-4, a_max=None) inv_dt = dt + mx.log(-mx.expm1(-dt)) return inv_dt @@ -476,6 +436,78 @@ def get_initial_A(num_heads: int) -> mx.array: return mx.log(A) +def selective_state_update_ref( + state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False +) -> tuple[mx.array, mx.array]: + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.ndim > 3 + if state.ndim == 3: + state = mx.expand_dims(state, 1) + if x.ndim == 2: + x = mx.expand_dims(x, 1) + if dt.ndim == 2: + dt = mx.expand_dims(dt, 1) + if A.ndim == 2: + A = mx.expand_dims(A, 0) + if B.ndim == 2: + B = mx.expand_dims(B, 1) + if C.ndim == 2: + C = mx.expand_dims(C, 1) + if D is not None and D.ndim == 1: + D = mx.expand_dims(D, 0) + if z is not None and z.ndim == 2: + z = mx.expand_dims(z, 1) + if dt_bias is not None and 
dt_bias.ndim == 1:
+        dt_bias = mx.expand_dims(dt_bias, 0)
+    batch, nheads, dim, dstate = state.shape
+    assert x.shape == (batch, nheads, dim)
+    assert dt.shape == x.shape
+    assert A.shape == (nheads, dim, dstate)
+    ngroups = B.shape[1]
+    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
+    assert B.shape == (batch, ngroups, dstate)
+    assert C.shape == B.shape
+    if D is not None:
+        assert D.shape == (nheads, dim)
+    if z is not None:
+        assert z.shape == x.shape
+    if dt_bias is not None:
+        assert dt_bias.shape == (nheads, dim)
+        dt = dt + dt_bias
+    dt = nn.softplus(dt) if dt_softplus else dt
+    dA = mx.exp(mx.expand_dims(dt, axis=-1) * A)  # (batch, nheads, dim, dstate)
+    B = mx.reshape(
+        mx.tile(mx.expand_dims(B, axis=2), (1, 1, nheads // ngroups, 1)),
+        (batch, nheads, dstate),
+    )  # (batch, nheads, dstate)
+    C = mx.reshape(
+        mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)),
+        (batch, nheads, dstate),
+    )  # (batch, nheads, dstate)
+    dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2)  # (batch, nheads, dim, dstate)
+    state = state * dA + dB * mx.expand_dims(x, axis=-1)  # (batch, nheads, dim, dstate)
+    out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
+    if D is not None:
+        out += (x * D).astype(out.dtype)
+    out = (out if z is None else out * nn.silu(z)).astype(x.dtype)
+    if not has_heads:
+        out = out.squeeze(1)
+    return out, state
+
+
 def ssd_update_state(
     ssm_state: mx.array,
     x: mx.array,
@@ -489,32 +521,24 @@ def ssd_update_state(
     dt_softplus: bool,
 ) -> tuple[mx.array, mx.array]:
     assert ssm_state.dtype == mx.float32
-    dtype = x.dtype
-    f = selective_state_update_ref
     hidden_size_per_head = x.shape[-1]
     d_state = B.shape[-1]
-    A = mx.broadcast_to(
-        A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)
-    ).astype(mx.float32)
-    dt = mx.broadcast_to(
-        dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)
-    )
-    dt_bias = mx.broadcast_to(
-        dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)
-    )
+    A = mx.broadcast_to(A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)).astype(mx.float32)
+    dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head))
+    dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head))
     D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
     assert ssm_state.dtype == mx.float32
-    out, ssm_state = f(
+    out, ssm_state = selective_state_update_ref(
         ssm_state,
-        x.astype(dtype),
-        dt.astype(dtype),
-        A.astype(mx.float32),
-        B.astype(dtype),
-        C.astype(dtype),
-        D.astype(mx.float32),
-        z.astype(dtype),
-        dt_bias.astype(mx.float32),
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D,
+        z,
+        dt_bias,
         dt_softplus=dt_softplus,
     )
     return out[:, None], ssm_state
@@ -550,7 +574,7 @@ def _ssd_chunk_scan_combined_naive(
                 A,
                 B[:, i],
                 C[:, i],
-                D,
+                D if D.ndim == 1 else D[:, i],
                 z=z[:, i],
                 dt_bias=dt_bias,
                 dt_softplus=dt_softplus,
@@ -570,14 +594,12 @@ def ssd_chunk_scan_combined(
     z: mx.array,
     dt_bias: mx.array,
     dt_softplus: bool,
-    return_final_states: bool,
     seq_idx: mx.array | None,
     ssm_state: mx.array | None,
-) -> tuple[mx.array, mx.array] | mx.array:
+) -> tuple[mx.array, mx.array]:
     if seq_idx is not None:
         assert seq_idx.dtype == mx.int32
         assert ssm_state is None
-        assert not return_final_states
     if ssm_state is not None:
         assert ssm_state.dtype == mx.float32
         assert seq_idx is None
@@ -595,7 +617,7 @@ def ssd_chunk_scan_combined(
         bsize, _, num_heads, channel = x.shape
         state = B.shape[-1]
         ssm_state = mx.zeros((bsize, num_heads, channel, state), dtype=mx.float32)
-    tmp = _ssd_chunk_scan_combined_naive(
+    tmp, ssm_state = _ssd_chunk_scan_combined_naive(
         x,
         dt,
         A,
@@ -608,10 +630,7 @@ def ssd_chunk_scan_combined(
         seq_idx=seq_idx,
         ssm_state=ssm_state,
     )
-    if return_final_states:
-        return tmp
-    else:
-        return tmp[0]
+    return tmp, ssm_state
 
 
 def _causal_conv1d(
@@ -642,9 +661,7 @@ def _causal_conv1d(
                 mx.zeros_like(conv_state),
                 conv_state,
             )
-            out[:, :, i : i + 1], conv_state = _causal_conv1d_update(
-                conv_state, weight, x[:, :, i : i + 1]
-            )
+            out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
         x = out
     if return_final_states:
         return x, conv_state
@@ -652,9 +669,61 @@ def _causal_conv1d(
         return x, None
 
 
-def _causal_conv1d_update(
-    conv_state: mx.array, weight: mx.array, xBC: mx.array
+def causal_conv1d_update(
+    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
 ) -> tuple[mx.array, mx.array]:
+    """
+    x: (batch, dim) or (batch, dim, seqlen)
+    conv_state: (batch, dim, state_len), where state_len >= width - 1
+    weight: (dim, width)
+    bias: (dim,)
+    cache_seqlens: (batch,), dtype int32.
+        If not None, the conv_state is treated as a circular buffer.
+        The conv_state will be updated by copying x to the conv_state starting at the index
+        @cache_seqlens % state_len before performing the convolution.
+
+    out: (batch, dim) or (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    unsqueeze = x.ndim == 2
+    if unsqueeze:
+        x = mx.expand_dims(x, -1)
+    batch, dim, seqlen = x.shape
+    width = weight.shape[1]
+    state_len = conv_state.shape[-1]
+    assert conv_state.shape == (batch, dim, state_len)
+    assert weight.shape == (dim, width)
+    if cache_seqlens is None:
+        x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)  # (batch, dim, state_len + seqlen)
+        conv_state = x_new[:, :, -state_len:]
+    else:
+        width_idx = mx.expand_dims(mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0) + mx.expand_dims(
+            cache_seqlens, axis=1
+        )
+        width_idx = mx.expand_dims(mx.remainder(width_idx, state_len), axis=1)
+        width_idx = mx.broadcast_to(width_idx, (width_idx.shape[0], dim, width_idx.shape[2]))
+        x_new = mx.concatenate([mx.take_along_axis(conv_state, width_idx, axis=2), x], axis=-1)
+        x_new = x_new.astype(weight.dtype)
+        copy_idx = mx.expand_dims(mx.arange(seqlen, dtype=mx.int64), axis=0) + mx.expand_dims(cache_seqlens, axis=1)
+        copy_idx = mx.expand_dims(mx.remainder(copy_idx, state_len), axis=1)
+        copy_idx = mx.broadcast_to(copy_idx, (copy_idx.shape[0], dim, copy_idx.shape[2]))
+        conv_state = mx.put_along_axis(conv_state, copy_idx, x.astype(conv_state.dtype), axis=2)
+    assert bias is None
+    # x_new: (N, C, L) -> (N, L, C)
+    out = mx.conv1d(
+        x_new.transpose(0, 2, 1),
+        mx.expand_dims(weight, axis=2),
+        padding=0,
+        groups=dim,
+    ).transpose(0, 2, 1)[:, :, -seqlen:]
+    if unsqueeze:
+        out = out.squeeze(-1)
+    return (out if activation is None else nn.silu(out)).astype(dtype_in), conv_state
+
+
+def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array) -> tuple[mx.array, mx.array]:
     dtype = conv_state.dtype
     xBC = xBC.astype(dtype)
     weight = weight.astype(dtype)
@@ -680,6 +749,7 @@ def is_mamba(config: ModelArgs, i: int) -> bool:
     return (i % config.mamba_step) != (config.mamba_step // 2)
 
 
+# Based on: https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206
 def causal_conv1d(x, weight, bias=None, activation=None):
     """
     MLX implementation of a causal depthwise 1D convolution.
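 
     A sketch of the per-channel recurrence, assuming x is (batch, dim, seqlen)
     and weight is (dim, width), with zero left-padding so that position t
     never sees the future:
 
         y[b, c, t] = bias[c] + sum_{k=0..width-1} weight[c, k] * x[b, c, t - (width - 1) + k]
 
     followed by SiLU when activation is "silu" or "swish".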
@@ -731,352 +801,10 @@ def causal_conv1d(x, weight, bias=None, activation=None): return y -def mamba_chunk_scan_combined( - x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - dt_softplus=False, - return_final_states=False, -): - """ - MLX implementation of the Mamba chunk-wise scan. - Args: - x (mx.array): Input sequence of shape (batch, seqlen, nheads, head_dim). - dt (mx.array or scalar): Time-step factor(s) for the state update. - A, B, C (mx.array): State-space parameters (see notes). - chunk_size (int): Length of chunks to split the sequence. - D (mx.array, optional): Optional direct output weights. - z (mx.array, optional): Optional gating input for output modulation. - dt_bias (mx.array, optional): Optional bias to add to dt. - initial_states (mx.array, optional): Initial state for the recurrence. - dt_softplus (bool): If True, apply softplus to dt. - return_final_states (bool): If True, return final state of sequence. - Returns: - mx.array (or tuple): Output sequence (batch, seqlen, output_dim), and final states if requested. - """ - # Ensure inputs are MLX arrays - x = mx.array(x) if not isinstance(x, mx.array) else x - A = mx.array(A) if not isinstance(A, mx.array) else A - B = mx.array(B) if not isinstance(B, mx.array) else B - C = mx.array(C) if not isinstance(C, mx.array) else C - if D is not None: - D = mx.array(D) if not isinstance(D, mx.array) else D - if z is not None: - z = mx.array(z) if not isinstance(z, mx.array) else z - if dt_bias is not None: - dt_bias = mx.array(dt_bias) if not isinstance(dt_bias, mx.array) else dt_bias - dt = mx.array(dt) if not isinstance(dt, mx.array) else dt - - batch, seq_len, nheads, head_dim = x.shape - - # If needed, apply softplus to dt to ensure positivity (as in original code) - if dt_softplus: - dt = mx.log(1 + mx.exp(dt)) # softplus: log(1 + exp(dt)) - if dt_bias is not None: - dt = dt + dt_bias # incorporate bias to dt if provided - - # Prepare initial state - state_dim = A.shape[ - -1 - ] # assume A is of shape (nheads, state_dim) for diagonal A or (nheads, state_dim, state_dim) - if initial_states is None: - # Initialize state to zero for each sequence in batch and each head​:contentReference[oaicite:3]{index=3} - state = mx.zeros((batch, nheads, state_dim), dtype=A.dtype) - else: - state = ( - mx.array(initial_states) - if not isinstance(initial_states, mx.array) - else initial_states - ) - - # Precompute exponent of A*dt for state update per step (assuming A is diagonal or elementwise applicable) - # If A is diagonal values per head (shape (nheads, state_dim)), we compute elementwise exponentials. 
- exp_dA = None - if A.ndim == 2 or A.ndim == 1: - # A is given as diagonal values per state - exp_dA = mx.exp(A * dt) # shape (nheads, state_dim) or (state_dim,) - if exp_dA.ndim == 2: - exp_dA = exp_dA.reshape( - (1, nheads, state_dim) - ) # shape (1, nheads, state_dim) for broadcasting - else: - # If A is a full matrix per head, use matrix exponential (if available in MLX) - exp_dA = mx.exp( - A * dt - ) # assuming MX can exponentiate matrix elementwise or use specialized routine - - # Output buffer - out_list = [] # will collect output chunks - - # Process sequence in chunks - num_chunks = (seq_len + chunk_size - 1) // chunk_size # ceiling division - for ch in range(num_chunks): - start = ch * chunk_size - end = min((ch + 1) * chunk_size, seq_len) - x_chunk = x[:, start:end, :, :] # shape (batch, chunk_len, nheads, head_dim) - chunk_len = x_chunk.shape[1] - - # If gating input z is provided (e.g., a per-head modulation), slice it for this chunk as well - if z is not None: - z_chunk = ( - z[:, start:end, :] if z.shape[1] == seq_len else z - ) # shape (batch, chunk_len, nheads) or (batch, chunk_len, output_dim) - - # Iterate through time steps within the chunk - # (This loop is on the chunk length, which is at most chunk_size for performance) - for t in range(chunk_len): - x_t = x_chunk[:, t, :, :] # (batch, nheads, head_dim) - # Compute state increment from input: B * u(t). - # If B is shape (nheads, state_dim, head_dim) or (nheads, state_dim) for 1D input: - if B.ndim == 3: - # Perform matrix multiplication for each head: - # out shape (batch, nheads, state_dim) - inc = mx.einsum("h n d, b h d -> b h n", B, x_t) - else: - # B is (nheads, state_dim) or (state_dim,) meaning one input per state - # In this case, head_dim should be 1 - inc = B.reshape((1, nheads, state_dim)) * x_t # broadcast multiply - # If dt is not already applied in B, multiply by dt (assuming continuous-time formulation, dt scaling) - inc = inc * dt - - # State update: s_{t+1} = exp(A*dt) * s_t + inc​:contentReference[oaicite:4]{index=4}. - state = ( - state * exp_dA + inc - ) # elementwise multiply if exp_dA broadcast shape (1, nheads, state_dim) - - # Compute output for this time step: y_t = C * state_t + (D * x_t if direct term exists) - if C.ndim == 3: - # C shape (nheads, output_dim, state_dim), do einsum for each head - y_t = mx.einsum("h d n, b h n -> b h d", C, state) - else: - # C shape (nheads, state_dim) or (state_dim,), output one value per head - # Multiply and sum over state_dim - y_t = mx.sum( - state * C.reshape((1, nheads, state_dim)), axis=-1, keepdims=True - ) # (batch, nheads, 1) - if D is not None: - # Add direct input contribution: D * x(t) - if D.ndim == 2: - # D shape (nheads, output_dim) - y_t += mx.einsum( - "h d, b h d0 -> b h d", - D, - x_t[..., None] if x_t.ndim == 3 else x_t, - ) - else: - # D shape (nheads,) or scalar - y_t += D.reshape((1, nheads, -1)) * ( - x_t if x_t.ndim == 3 else x_t[..., None] - ) - # Apply gating activation if provided (e.g., elementwise multiply by a sigmoidal function of z) - if z is not None: - # Example: if z is meant to gate outputs via a sigmoid activation (silu), as in some Mamba variants - # We'll assume z_chunk provides an additive bias or multiplier for output. 
- # Here we apply SiLU gating: output * sigmoid(z) - y_t = y_t * mx.sigmoid(z_chunk[:, t, :].reshape(y_t.shape)) - out_list.append(y_t) - # end of chunk loop - # end of all chunks - - # Concatenate outputs from all chunks and reshape to (batch, seq_len, output_dim_total) - y = ( - mx.concatenate(out_list, axis=1) if isinstance(out_list, list) else out_list - ) # list contains (batch, nheads, output_dim) at each time - # After concatenation, y shape is (batch, seq_len * nheads, output_dim) if each y_t was (batch, nheads, output_dim). - # We should reshape to (batch, seq_len, nheads*output_dim) for final output sequence. - if isinstance(y, mx.array): - # If output was built as MLX array directly: - out = y.reshape((batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1))) - else: - # If out_list was used and concatenated via Python, we might get a NumPy array; ensure it's MLX: - out = mx.array(y).reshape( - (batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1)) - ) - - if return_final_states: - # Return the final state as well (state holds final state after last chunk) - return out, state - return out - - -class PlamoPreTrainedModel(nn.Module): # type: ignore - config_class = ModelArgs - _no_split_modules: list[str] - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["PlamoDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - - def _init_weights(self, module: nn.Module) -> None: - std = 0.02 - if isinstance(module, nn.Linear): - module.weight = mx.random.normal( - loc=0.0, scale=std, shape=module.weight.shape - ) - if module.bias is not None: - module.bias = mx.zeros_like(module.bias) - elif isinstance(module, nn.Embedding): - module.weight = mx.random.normal( - loc=0.0, scale=std, shape=module.weight.shape - ) - if module.padding_idx is not None: - module.weight[module.padding_idx] = mx.zeros_like( - module.weight[module.padding_idx] - ) - - -def causal_conv1d_update( - x, conv_state, weight, bias=None, activation=None, cache_seqlens=None -) -> tuple[mx.array, mx.array]: - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. 
- - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.ndim == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = mx.concatenate([conv_state, x], axis=-1).astype( - weight.dtype - ) # (batch, dim, state_len + seqlen) - conv_state = x_new[:, :, -state_len:] - else: - width_idx = mx.expand_dims( - mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0 - ) + mx.expand_dims(cache_seqlens, axis=1) - width_idx = mx.expand_dims(mx.remainder(width_idx, state_len), axis=1) - width_idx = mx.broadcast_to( - width_idx, (width_idx.shape[0], dim, width_idx.shape[2]) - ) - x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1) - x_new = x_new.astype(weight.dtype) - copy_idx = mx.expand_dims( - mx.arange(seqlen, dtype=mx.int64), axis=0 - ) + mx.expand_dims(cache_seqlens, axis=1) - copy_idx = mx.expand_dims(mx.remainder(copy_idx, state_len), axis=1) - copy_idx = mx.broadcast_to( - copy_idx, (copy_idx.shape[0], dim, copy_idx.shape[2]) - ) - conv_state.scatter_(2, copy_idx, x) - assert bias is None - # x_new: (N, C, L) -> (N, L, C) - out = mx.conv1d( - x_new.transpose(0, 2, 1), - mx.expand_dims(weight, axis=2), - padding=0, - groups=dim, - ).transpose(0, 2, 1)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else nn.silu(out)).astype(dtype_in), conv_state - - -def selective_state_update_ref( - state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False -) -> tuple[mx.array, mx.array]: - """ - Argument: - state: (batch, dim, dstate) or (batch, nheads, dim, dstate) - x: (batch, dim) or (batch, nheads, dim) - dt: (batch, dim) or (batch, nheads, dim) - A: (dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate) or (batch, ngroups, dstate) - C: (batch, dstate) or (batch, ngroups, dstate) - D: (dim,) or (nheads, dim) - z: (batch, dim) or (batch, nheads, dim) - dt_bias: (dim,) or (nheads, dim) - Return: - out: (batch, dim) or (batch, nheads, dim) - """ - has_heads = state.ndim > 3 - if state.ndim == 3: - state = state.unsqueeze(1) - if x.ndim == 2: - x = x.unsqueeze(1) - if dt.ndim == 2: - dt = dt.unsqueeze(1) - if A.ndim == 2: - A = A.unsqueeze(0) - if B.ndim == 2: - B = B.unsqueeze(1) - if C.ndim == 2: - C = C.unsqueeze(1) - if D is not None and D.ndim == 1: - D = D.unsqueeze(0) - if z is not None and z.ndim == 2: - z = z.unsqueeze(1) - if dt_bias is not None and dt_bias.ndim == 1: - dt_bias = dt_bias.unsqueeze(0) - batch, nheads, dim, dstate = state.shape - assert x.shape == (batch, nheads, dim) - assert dt.shape == x.shape - assert A.shape == (nheads, dim, dstate) - ngroups = B.shape[1] - assert nheads % ngroups == 0, "nheads must be divisible by ngroups" - assert B.shape == (batch, ngroups, dstate) - assert C.shape == B.shape - if D is not None: - assert D.shape == (nheads, dim) - if z is not None: - assert z.shape == x.shape - if dt_bias is not None: - assert dt_bias.shape == (nheads, dim) - dt = dt + dt_bias - dt = nn.softplus(dt) if dt_softplus else dt - dA = mx.exp(mx.expand_dims(dt, axis=-1) * A) # (batch, nheads, dim, dstate) - B = mx.reshape( - mx.tile(mx.expand_dims(B, axis=2), (1, 1, nheads // ngroups, 1)), - (batch, nheads, dstate), - ) # (batch, nheads, dstate) - C = 
mx.reshape( - mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)), - (batch, nheads, dstate), - ) # (batch, nheads, dstate) - dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims( - B, axis=-2 - ) # (batch, nheads, dim, dstate) - state = state * dA + dB * mx.expand_dims(x, axis=-1) # (batch, dim, dstate - out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C) - if D is not None: - out += (x * D).astype(out.dtype) - out = (out if z is None else out * nn.silu(z)).astype(x.dtype) - if not has_heads: - out = out.squeeze(1) - return out, state - - def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array: max_len = max(q_len, kv_len) mask = mx.tril( - mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size), + mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size), # type: ignore k=window_size, ) return mask[-q_len:, -kv_len:] @@ -1106,16 +834,12 @@ class Attention(nn.Module): self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, bias=False, ) - self.o_proj = nn.Linear( - self.q_num_heads * self.v_dim, self.hidden_size, bias=False - ) + self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False) self.q_weight = mx.ones((self.q_num_heads, self.qk_dim)) self.k_weight = mx.ones((self.k_num_heads, self.qk_dim)) - self.rotary_emb = RotaryEmbedding( - self.qk_dim, max_position_embeddings=self.config.attention_window_size - ) + self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size) def __call__( self, @@ -1130,33 +854,21 @@ class Attention(nn.Module): query_states, key_states, value_states = mx.split( qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1 ) - query_states = query_states.reshape( - bsz, q_len, self.q_num_heads, self.qk_dim - ).transpose(0, 2, 1, 3) - key_states = key_states.reshape( - bsz, q_len, self.k_num_heads, self.qk_dim - ).transpose(0, 2, 1, 3) - value_states = value_states.reshape( - bsz, q_len, self.v_num_heads, self.v_dim - ).transpose(0, 2, 1, 3) + query_states = query_states.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3) + key_states = key_states.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3) + value_states = value_states.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3) attn_dtype = query_states.dtype - query_states = ( - _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None] - ) + query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None] key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None] if past_states is not None: # reuse k, v, self_attention key_states_new = key_states value_states_new = value_states - key_states, value_states = past_states.append_kv( - key_states, value_states, self.layer_idx - ) # type: ignore - past_states.update_attention( - key_states_new, value_states_new, self.layer_idx - ) + key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) # type: ignore + past_states.update_attention(key_states_new, value_states_new, self.layer_idx) kv_seq_len = key_states.shape[-2] position_ids = mx.arange(kv_seq_len, dtype=mx.int64)[None] @@ -1191,9 +903,7 @@ class Attention(nn.Module): ) else: if attention_mask.dtype == bool: - attention_mask = mx.where( - attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf") - ) + attention_mask = mx.where(attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf")) if len(attention_mask.shape) == 2: attention_mask = 
attention_mask[None, None] assert len(attention_mask.shape) == 4 @@ -1206,20 +916,14 @@ class Attention(nn.Module): ) # `generate` function creates attention mask that does not consider sliding window m_swa = m_swa[None, None] - attention_mask = attention_mask[ - :, :, -query_states.shape[2] :, -key_states.shape[2] : - ] + attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :] attention_mask = mx.where(m_swa, attention_mask, float("-inf")) # like AttentionMaskConverter._unmask_unattended in huggingface.transfoermers, # we need to attend to all tokens in masked rows for `scaled_dot_product_attention` bool_mask = mx.logical_not(mx.isneginf(attention_mask)) - valid_tokens = mx.sum(bool_mask, axis=-1).astype( - mx.bool_ - ) # (..., q_len) # type: ignore - attention_mask = mx.where( - valid_tokens[..., None], attention_mask, float(0.0) - ) + valid_tokens = mx.sum(bool_mask, axis=-1).astype(mx.bool_) # type: ignore # (..., q_len) + attention_mask = mx.where(valid_tokens[..., None], attention_mask, float(0.0)) attn_output = mx.fast.scaled_dot_product_attention( query_states, key_states, @@ -1254,9 +958,7 @@ class Mamba(nn.Module): self.intermediate_size = self.num_heads * self.hidden_size_per_head - self.in_proj = nn.Linear( - self.hidden_size, 2 * self.intermediate_size, bias=False - ) + self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False) self.conv1d = nn.Conv1d( in_channels=self.intermediate_size, out_channels=self.intermediate_size, @@ -1278,7 +980,7 @@ class Mamba(nn.Module): self.dt_bias = get_initial_dt_bias(self.num_heads) self.A_log = get_initial_A(self.num_heads) - self.D = mx.ones(self.num_heads) + self.D = mx.ones(self.num_heads, dtype=mx.float32) # TODO norm weight before gating like Mamba2 self.dt_norm_weight = mx.ones(self.dt_dim) @@ -1342,7 +1044,7 @@ class Mamba(nn.Module): ssm_state = c.ssm_state zx = self.in_proj(hidden_states) - zx = zx.reshape(bsize, length, self.num_heads, -1) + zx = zx.reshape(bsize, length, self.num_heads, -1).astype(mx.float32) # z: (bsize, length, num_heads, hidden_size_per_head) # x: (bsize, length, num_heads, hidden_size_per_head) z, x = mx.split( @@ -1354,44 +1056,31 @@ class Mamba(nn.Module): ) # conv - x = x.reshape(bsize, length, -1).transpose( - 0, 2, 1 - ) # (bsize, intermediate_size, length) + x = x.reshape(bsize, length, -1).transpose(0, 2, 1) # (bsize, intermediate_size, length) if bool_mask is not None: x = mx.where(bool_mask[:, None, :], x, 0.0) if is_update: assert conv_state is not None x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x) else: - x, conv_state = _causal_conv1d( - conv_state, self.conv1d.weight, x, seq_idx=seq_idx - ) - x = x.astype(hidden_states.dtype) + x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx) + x = x.astype(mx.float32) x = x.transpose(0, 2, 1) # (bsize, length, intermediate_size) x = x.reshape(bsize, length, -1) # x: (bsize, length, num_heads, hidden_size_per_head) # B: (bsize, length, 1, d_state) # C: (bsize, length, 1, d_state) # dt: (bsize, length, dt_dim) - BCdt = self.bcdt_proj(x) + BCdt = self.bcdt_proj(x).astype(mx.float32) x = x.reshape(bsize, length, self.num_heads, -1) B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1) B = B[:, :, None, :] C = C[:, :, None, :] A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,) - dt = ( - _rms_norm(dt, None, self.config.rms_norm_eps) - * self.dt_norm_weight[None, None, :] - ) - B = ( - _rms_norm(B, None, self.config.rms_norm_eps) - 
* self.B_norm_weight[None, None, None, :] - ) - C = ( - _rms_norm(C, None, self.config.rms_norm_eps) - * self.C_norm_weight[None, None, None, :] - ) + dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :] + B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :] + C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :] # (bsize, length, num_heads, 1) dt = self.dt_proj(dt)[..., None] @@ -1415,6 +1104,8 @@ class Mamba(nn.Module): x = mx.where(bool_mask[:, :, None, None], x, 0.0) # ssm + self.D = self.D.astype(mx.float32) + self.dt_bias = self.dt_bias.astype(mx.float32) if is_update: assert ssm_state is not None out, ssm_state = ssd_update_state( @@ -1430,7 +1121,7 @@ class Mamba(nn.Module): dt_softplus=True, ) else: - tmp = ssd_chunk_scan_combined( + out, ssm_state = ssd_chunk_scan_combined( x, dt.reshape(bsize, length, -1), A, @@ -1441,15 +1132,9 @@ class Mamba(nn.Module): z=z, dt_bias=self.dt_bias, dt_softplus=True, - return_final_states=past_states is not None, seq_idx=seq_idx, ssm_state=ssm_state, ) - if past_states is not None: - out, ssm_state = tmp - else: - assert isinstance(tmp, mx.array) - out = tmp y = self.out_proj(out.reshape(bsize, length, -1)) @@ -1467,9 +1152,7 @@ class MLP(nn.Module): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_up_proj = nn.Linear( - self.hidden_size, self.intermediate_size * 2, bias=False - ) + self.gate_up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) def __call__(self, x: mx.array) -> mx.array: @@ -1493,18 +1176,10 @@ class PlamoDecoderLayer(nn.Module): """ Notes: The model performance was degraded when setting all offsets to 1. 
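 
         With the offset form of RMSNorm used below, the normalized activation
         is scaled by (offset + weight) rather than weight alone, roughly
 
             y = x / sqrt(mean(x**2) + eps) * (offset + weight)
 
         so the post-mixer and post-mlp offsets (1/5 and 1/5**1.5) scale those
         branches down relative to the pre-norms.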
""" - self.pre_mixer_norm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, offset=1.0 - ) - self.post_mixer_norm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5 - ) - self.pre_mlp_norm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, offset=1.0 - ) - self.post_mlp_norm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5) - ) + self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0) + self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5) + self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0) + self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)) def __call__( self, @@ -1565,15 +1240,11 @@ class PlamoDecoder(nn.Module): self.gradient_checkpointing = False def __call__(self, x: DecoderInput) -> DecoderOutput: - all_hidden_states: Optional[tuple[mx.array, ...]] = ( - () if x.output_hidden_states else None - ) - all_self_attns: Optional[tuple[mx.array, ...]] = ( - () if x.output_attentions else None - ) + all_hidden_states: Optional[tuple[mx.array, ...]] = () if x.output_hidden_states else None + all_self_attns: Optional[tuple[mx.array, ...]] = () if x.output_attentions else None hidden_states = x.hidden_states - for decoder_layer in self.layers: + for layer_i, decoder_layer in enumerate(self.layers): if x.output_hidden_states: assert all_hidden_states is not None all_hidden_states += (hidden_states,) @@ -1653,10 +1324,8 @@ class BaseModelOutputWithPast(ModelOutput): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.last_hidden_state: mx.array = kwargs.pop("last_hidden_state") - self.past_key_values: Optional[PlamoCache] = kwargs.pop("past_key_values", None) - self.hidden_states: Optional[tuple[mx.array, ...]] = kwargs.pop( - "hidden_states", None - ) + self.past_key_values: Optional[tuple[tuple[mx.array]]] = kwargs.pop("past_key_values", None) + self.hidden_states: Optional[tuple[mx.array, ...]] = kwargs.pop("hidden_states", None) self.attentions: Optional[tuple[mx.array, ...]] = kwargs.pop("attentions", None) @@ -1693,15 +1362,36 @@ class CausalLMOutputWithPast(ModelOutput): self.loss: Optional[mx.array] = kwargs.pop("loss", None) self.logits: mx.array | None = kwargs.pop("logits", None) - self.past_key_values: Optional[tuple[tuple[mx.array]]] = kwargs.pop( - "past_key_values", None - ) - self.hidden_states: Optional[tuple[mx.array, ...]] = kwargs.pop( - "hidden_states", None - ) + self.past_key_values: Optional[tuple[tuple[mx.array]]] = kwargs.pop("past_key_values", None) + self.hidden_states: Optional[tuple[mx.array, ...]] = kwargs.pop("hidden_states", None) self.attentions: Optional[tuple[mx.array, ...]] = kwargs.pop("attentions", None) +class PlamoPreTrainedModel(nn.Module): # type: ignore + config_class = ModelArgs + _no_split_modules: list[str] + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PlamoDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + + def _init_weights(self, module: nn.Module) -> None: + std = 0.02 + if isinstance(module, nn.Linear): + module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape) + if module.bias is not None: + module.bias = mx.zeros_like(module.bias) + elif isinstance(module, 
nn.Embedding): + module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape) + if module.padding_idx is not None: + module.weight[module.padding_idx] = mx.zeros_like(module.weight[module.padding_idx]) + + class PlamoModel(PlamoPreTrainedModel): def __init__(self, config: ModelArgs): super().__init__(config) @@ -1752,13 +1442,9 @@ class PlamoModel(PlamoPreTrainedModel): else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] assert inputs_embeds is not None - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask @@ -1777,33 +1463,21 @@ class PlamoModel(PlamoPreTrainedModel): return_dict: Optional[bool] = None, ) -> Union[tuple, BaseModelOutputWithPast]: assert input_ids is not None - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") seq_length_with_past = seq_length past_key_values_length = 0 @@ -1834,7 +1508,7 @@ class PlamoModel(PlamoPreTrainedModel): if require_attn_mask and attention_mask is None: attention_mask = mx.ones( (batch_size, seq_length_with_past), - dtype=mx.bool_, + dtype=mx.bool_, # type: ignore ) if attention_mask is not None: attention_mask = self._prepare_decoder_attention_mask( @@ -1914,11 +1588,9 @@ class Model(PlamoPreTrainedModel): vocab_size = ((self.vocab_size + 15) // 16) * 16 if not config.tie_word_embeddings: - self.lm_head: nn.Module = nn.Linear( - config.hidden_size, vocab_size, bias=False - ) + self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False) - self._past_key_values: Optional[tuple[tuple[mx.array]]] = None + self._cache: Optional[PlamoCache] = None # Initialize weights and apply final processing # self.post_init() @@ -1947,9 +1619,17 @@ class Model(PlamoPreTrainedModel): weights[k] = v.moveaxis(2, 1) return weights + def make_cache(self) -> PlamoCache: + return PlamoCache(self.config) + def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array: + if 
self._cache is not None: + plamo_cache = self._cache + else: + plamo_cache = None output = self.forward( input_ids=inputs, + past_key_values=plamo_cache, use_cache=self.config.use_cache, return_dict=True, ) @@ -1958,7 +1638,8 @@ class Model(PlamoPreTrainedModel): f"Unexpected output type for causal language model: {type(output)} != CausalLMOutputWithPast" ) if output.past_key_values is not None: - self._past_key_values = output.past_key_values + # output.past_key_values is actually a PlamoCache object + self._cache = output.past_key_values # type: ignore if output.logits is not None: return output.logits else: @@ -1999,19 +1680,11 @@ class Model(PlamoPreTrainedModel): ```""" assert input_ids is not None - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model(