diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index db141c63..fe037b0e 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -5,9 +5,8 @@ from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
-from transformers import PretrainedConfig

-from mlx_lm.models.base import create_attention_mask
+from .base import BaseModelArgs, create_attention_mask


 def _is_first_token(mask: mx.array) -> mx.array:
@@ -16,8 +15,12 @@ def _is_first_token(mask: mx.array) -> mx.array:
     mask = mask[:, :, :, -q_len:]
     cont = q_len != kv_len
     v = False if cont else True
-    out = mx.logical_not(mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_))
-    out = mx.concatenate([mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1)
+    out = mx.logical_not(
+        mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_)
+    )
+    out = mx.concatenate(
+        [mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1
+    )
     return out


@@ -34,13 +37,17 @@ def _swiglu(h: mx.array) -> mx.array:


 class RotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: int = 10000) -> None:
+    def __init__(
+        self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
+    ) -> None:
         super().__init__()
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base

-        inv_freq = 1.0 / (self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim))
+        inv_freq = 1.0 / (
+            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+        )
         self._inv_freq = inv_freq

         # Build here to make `torch.jit.trace` work.
@@ -74,7 +81,9 @@ def _rotate_half(x: mx.array) -> mx.array:
     return mx.concatenate([-x2, x1], axis=-1)


-def _rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array) -> mx.array:
+def _rotary_pos_emb(
+    x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array
+) -> mx.array:
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -90,8 +99,9 @@ class LinearType(str, enum.Enum):
     Fp8Retain = "fp8-retain"


-class ModelArgs(PretrainedConfig):  # type: ignore
-    model_type: str = "plamo"
+@dataclass
+class ModelArgs(BaseModelArgs):  # type: ignore
+    model_type: str = "plamo2"

     def __init__(
         self,
@@ -145,7 +155,9 @@ class ModelArgs(PretrainedConfig):  # type: ignore
         self.hidden_size_per_head = hidden_size_per_head
         self.num_key_value_heads = num_key_value_heads
         self.attention_window_size = attention_window_size
-        self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
+        self.full_attention_idx = (
+            full_attention_idx if full_attention_idx is not None else []
+        )

         self.mamba_d_state = mamba_d_state
         self.mamba_d_conv = mamba_d_conv
@@ -221,9 +233,13 @@ class PlamoCache(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
-        self.cache: List[Optional[PlamoLayerCache]] = [None for _ in range(config.num_hidden_layers)]
+        self.cache: List[Optional[PlamoLayerCache]] = [
+            None for _ in range(config.num_hidden_layers)
+        ]

-    def append_kv(self, key: mx.array, value: mx.array, layer_idx: int) -> tuple[mx.array, mx.array]:
+    def append_kv(
+        self, key: mx.array, value: mx.array, layer_idx: int
+    ) -> tuple[mx.array, mx.array]:
         c = self.cache[layer_idx]
         if c is None:
             return key, value
@@ -239,9 +255,13 @@ class PlamoCache(nn.Module):
         _validate(c.key, key)
         _validate(c.value, value)
         assert key.shape[2] == value.shape[2]
-        return mx.concatenate([c.key, key], axis=2), mx.concatenate([c.value, value], axis=2)
+        return mx.concatenate([c.key, key], axis=2), mx.concatenate(
+            [c.value, value], axis=2
+        )

-    def update_attention(self, key_states: mx.array, value_states: mx.array, layer_idx: int) -> PlamoAttentionCache:
+    def update_attention(
+        self, key_states: mx.array, value_states: mx.array, layer_idx: int
+    ) -> PlamoAttentionCache:
         full_attn = layer_idx in self.config.full_attention_idx
         window_size = self.config.attention_window_size

@@ -265,7 +285,9 @@ class PlamoCache(nn.Module):
             c.value = v[:, :, -window_size:, :]
         return self.cache[layer_idx]  # type: ignore

-    def update_mamba(self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int) -> PlamoMambaCache:
+    def update_mamba(
+        self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int
+    ) -> PlamoMambaCache:
         if self.cache[layer_idx] is None:
             self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state)
         else:
@@ -305,7 +327,9 @@ class PlamoCache(nn.Module):
     def get_max_length(self) -> int | None:
         return None

-    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+    def get_usable_length(
+        self, new_seq_length: int, layer_idx: Optional[int] = 0
+    ) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
@@ -361,7 +385,9 @@ class DecoderOutput(NamedTuple):


 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0) -> mx.array:
+def _make_causal_mask(
+    input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0
+) -> mx.array:
     """
     Make causal mask used for bi-directional self-attention.
""" @@ -372,26 +398,36 @@ def _make_causal_mask(input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_ke mask = mask.astype(dtype) if past_key_values_length > 0: - mask = mx.concatenate([mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1) - return mx.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)) + mask = mx.concatenate( + [mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1 + ) + return mx.broadcast_to( + mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length) + ) # Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None) -> mx.array: +def _expand_mask( + mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None +) -> mx.array: """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.shape tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = mx.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)).astype(dtype) + expanded_mask = mx.broadcast_to( + mask[:, None, None, :], (bsz, 1, tgt_len, src_len) + ).astype(dtype) inverted_mask = 1.0 - expanded_mask return mx.where(inverted_mask.astype(mx.bool_), float("-inf"), inverted_mask) # type: ignore -def _rms_norm(hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0) -> mx.array: +def _rms_norm( + hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0 +) -> mx.array: input_dtype = hidden_states.dtype hidden_states = hidden_states.astype(mx.float32) variance = mx.power(hidden_states, 2).mean(-1, keepdims=True) @@ -415,13 +451,18 @@ class RMSNorm(nn.Module): self.offset = offset def __call__(self, hidden_states: mx.array) -> mx.array: - return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset) + return _rms_norm( + hidden_states, self.weight, self.variance_epsilon, offset=self.offset + ) def get_initial_dt_bias(num_heads: int) -> mx.array: dt_min = 0.001 dt_max = 0.1 - dt = mx.exp(mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) + dt = mx.exp( + mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) dt = mx.clip(dt, a_min=1e-4, a_max=None) inv_dt = dt + mx.log(-mx.expm1(-dt)) return inv_dt @@ -450,9 +491,15 @@ def ssd_update_state( hidden_size_per_head = x.shape[-1] d_state = B.shape[-1] - A = mx.broadcast_to(A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)).astype(mx.float32) - dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)) - dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)) + A = mx.broadcast_to( + A[:, None, None], (A.shape[0], hidden_size_per_head, d_state) + ).astype(mx.float32) + dt = mx.broadcast_to( + dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head) + ) + dt_bias = mx.broadcast_to( + dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head) + ) D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head)) assert ssm_state.dtype == mx.float32 out, ssm_state = f( @@ -592,7 +639,9 @@ def _causal_conv1d( mx.zeros_like(conv_state), conv_state, ) - out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1]) + out[:, :, i : i + 1], conv_state = _causal_conv1d_update( + conv_state, weight, x[:, :, i : i + 1] + ) x = out if 
         return x, conv_state
@@ -600,7 +649,9 @@ def _causal_conv1d(
         return x, None


-def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array) -> tuple[mx.array, mx.array]:
+def _causal_conv1d_update(
+    conv_state: mx.array, weight: mx.array, xBC: mx.array
+) -> tuple[mx.array, mx.array]:
     dtype = conv_state.dtype
     xBC = xBC.astype(dtype)
     weight = weight.astype(dtype)
@@ -729,12 +780,18 @@ def mamba_chunk_scan_combined(
         dt = dt + dt_bias  # incorporate bias to dt if provided

     # Prepare initial state
-    state_dim = A.shape[-1]  # assume A is of shape (nheads, state_dim) for diagonal A or (nheads, state_dim, state_dim)
+    state_dim = A.shape[
+        -1
+    ]  # assume A is of shape (nheads, state_dim) for diagonal A or (nheads, state_dim, state_dim)
     if initial_states is None:
         # Initialize state to zero for each sequence in batch and each head
         state = mx.zeros((batch, nheads, state_dim), dtype=A.dtype)
     else:
-        state = mx.array(initial_states) if not isinstance(initial_states, mx.array) else initial_states
+        state = (
+            mx.array(initial_states)
+            if not isinstance(initial_states, mx.array)
+            else initial_states
+        )

     # Precompute exponent of A*dt for state update per step (assuming A is diagonal or elementwise applicable)
     # If A is diagonal values per head (shape (nheads, state_dim)), we compute elementwise exponentials.
@@ -743,10 +800,14 @@ def mamba_chunk_scan_combined(
         # A is given as diagonal values per state
         exp_dA = mx.exp(A * dt)  # shape (nheads, state_dim) or (state_dim,)
         if exp_dA.ndim == 2:
-            exp_dA = exp_dA.reshape((1, nheads, state_dim))  # shape (1, nheads, state_dim) for broadcasting
+            exp_dA = exp_dA.reshape(
+                (1, nheads, state_dim)
+            )  # shape (1, nheads, state_dim) for broadcasting
     else:
         # If A is a full matrix per head, use matrix exponential (if available in MLX)
-        exp_dA = mx.exp(A * dt)  # assuming MX can exponentiate matrix elementwise or use specialized routine
+        exp_dA = mx.exp(
+            A * dt
+        )  # assuming MX can exponentiate matrix elementwise or use specialized routine

     # Output buffer
     out_list = []  # will collect output chunks
@@ -783,7 +844,9 @@ def mamba_chunk_scan_combined(
             inc = inc * dt

         # State update: s_{t+1} = exp(A*dt) * s_t + inc.
-        state = state * exp_dA + inc  # elementwise multiply if exp_dA broadcast shape (1, nheads, state_dim)
+        state = (
+            state * exp_dA + inc
+        )  # elementwise multiply if exp_dA broadcast shape (1, nheads, state_dim)

         # Compute output for this time step: y_t = C * state_t + (D * x_t if direct term exists)
         if C.ndim == 3:
@@ -792,7 +855,9 @@ def mamba_chunk_scan_combined(
         else:
             # C shape (nheads, state_dim) or (state_dim,), output one value per head
             # Multiply and sum over state_dim
-            y_t = mx.sum(state * C.reshape((1, nheads, state_dim)), axis=-1, keepdims=True)  # (batch, nheads, 1)
+            y_t = mx.sum(
+                state * C.reshape((1, nheads, state_dim)), axis=-1, keepdims=True
+            )  # (batch, nheads, 1)
         if D is not None:
             # Add direct input contribution: D * x(t)
             if D.ndim == 2:
@@ -804,7 +869,9 @@ def mamba_chunk_scan_combined(
                 )
             else:
                 # D shape (nheads,) or scalar
-                y_t += D.reshape((1, nheads, -1)) * (x_t if x_t.ndim == 3 else x_t[..., None])
+                y_t += D.reshape((1, nheads, -1)) * (
+                    x_t if x_t.ndim == 3 else x_t[..., None]
+                )
         # Apply gating activation if provided (e.g., elementwise multiply by a sigmoidal function of z)
         if z is not None:
             # Example: if z is meant to gate outputs via a sigmoid activation (silu), as in some Mamba variants
@@ -826,7 +893,9 @@ def mamba_chunk_scan_combined(
         out = y.reshape((batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1)))
     else:
         # If out_list was used and concatenated via Python, we might get a NumPy array; ensure it's MLX:
-        out = mx.array(y).reshape((batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1)))
+        out = mx.array(y).reshape(
+            (batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1))
+        )

     if return_final_states:
         # Return the final state as well (state holds final state after last chunk)
@@ -850,16 +919,24 @@ class PlamoPreTrainedModel(nn.Module):  # type: ignore
     def _init_weights(self, module: nn.Module) -> None:
         std = 0.02
         if isinstance(module, nn.Linear):
-            module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape)
+            module.weight = mx.random.normal(
+                loc=0.0, scale=std, shape=module.weight.shape
+            )
             if module.bias is not None:
                 module.bias = mx.zeros_like(module.bias)
         elif isinstance(module, nn.Embedding):
-            module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape)
+            module.weight = mx.random.normal(
+                loc=0.0, scale=std, shape=module.weight.shape
+            )
             if module.padding_idx is not None:
-                module.weight[module.padding_idx] = mx.zeros_like(module.weight[module.padding_idx])
+                module.weight[module.padding_idx] = mx.zeros_like(
+                    module.weight[module.padding_idx]
+                )


-def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+def causal_conv1d_update(
+    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
+):
     """
     x: (batch, dim) or (batch, dim, seqlen)
     conv_state: (batch, dim, state_len), where state_len >= width - 1
@@ -884,13 +961,21 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
     assert conv_state.shape == (batch, dim, state_len)
     assert weight.shape == (dim, width)
     if cache_seqlens is None:
-        x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)  # (batch, dim, state_len + seqlen)
+        x_new = mx.concatenate([conv_state, x], axis=-1).astype(
+            weight.dtype
+        )  # (batch, dim, state_len + seqlen)
         conv_state = x_new[:, :, -state_len:]
     else:
-        width_idx = mx.expand_dims(mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0) + cache_seqlens.unsqueeze(1)
+        width_idx = mx.expand_dims(
+            mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0
+        ) + cache_seqlens.unsqueeze(1)
         width_idx = mx.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
-        x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1).astype(weight.dtype)
-        copy_idx = mx.expand_dims(mx.arange(seqlen, dtype=mx.int64), axis=0) + cache_seqlens.unsqueeze(1)
+        x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1).astype(
+            weight.dtype
+        )
+        copy_idx = mx.expand_dims(
+            mx.arange(seqlen, dtype=mx.int64), axis=0
+        ) + cache_seqlens.unsqueeze(1)
         copy_idx = mx.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
         conv_state.scatter_(2, copy_idx, x)
     assert bias is None
@@ -900,9 +985,7 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
         mx.expand_dims(weight, axis=2),
         padding=0,
         groups=dim,
-    ).transpose(
-        0, 2, 1
-    )[:, :, -seqlen:]
+    ).transpose(0, 2, 1)[:, :, -seqlen:]
     if unsqueeze:
         out = out.squeeze(-1)
     return (out if activation is None else nn.silu(out)).astype(dtype_in), conv_state
@@ -969,7 +1052,9 @@ def selective_state_update_ref(
             mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)),
             (batch, nheads, dstate),
         )  # (batch, nheads, dstate)
-    dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2)  # (batch, nheads, dim, dstate)
+    dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(
+        B, axis=-2
+    )  # (batch, nheads, dim, dstate)
     state = state * dA + dB * mx.expand_dims(x, axis=-1)  # (batch, dim, dstate)
     out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
     if D is not None:
@@ -1013,12 +1098,16 @@ class Attention(nn.Module):
             self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
             bias=False,
         )
-        self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(
+            self.q_num_heads * self.v_dim, self.hidden_size, bias=False
+        )

         self.q_weight = mx.ones((self.q_num_heads, self.qk_dim))
         self.k_weight = mx.ones((self.k_num_heads, self.qk_dim))

-        self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size)
+        self.rotary_emb = RotaryEmbedding(
+            self.qk_dim, max_position_embeddings=self.config.attention_window_size
+        )

     def __call__(
         self,
@@ -1033,21 +1122,33 @@ class Attention(nn.Module):
         query_states, key_states, value_states = mx.split(
             qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1
         )
-        query_states = query_states.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
-        key_states = key_states.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
-        value_states = value_states.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
+        query_states = query_states.reshape(
+            bsz, q_len, self.q_num_heads, self.qk_dim
+        ).transpose(0, 2, 1, 3)
+        key_states = key_states.reshape(
+            bsz, q_len, self.k_num_heads, self.qk_dim
+        ).transpose(0, 2, 1, 3)
+        value_states = value_states.reshape(
+            bsz, q_len, self.v_num_heads, self.v_dim
+        ).transpose(0, 2, 1, 3)

         attn_dtype = query_states.dtype

-        query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
+        query_states = (
+            _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
+        )
         key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]

         if past_states is not None:
             # reuse k, v, self_attention
             key_states_new = key_states
             value_states_new = value_states
-            key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx)  # type: ignore
-            past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
+            key_states, value_states = past_states.append_kv(
+                key_states, value_states, self.layer_idx
+            )  # type: ignore
+            past_states.update_attention(
+                key_states_new, value_states_new, self.layer_idx
+            )

         kv_seq_len = key_states.shape[-2]
         position_ids = mx.arange(kv_seq_len, dtype=mx.int64)[None]
@@ -1082,7 +1183,9 @@ class Attention(nn.Module):
             )
         else:
             if attention_mask.dtype == bool:
-                attention_mask = mx.where(attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
+                attention_mask = mx.where(
+                    attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf")
+                )
             if len(attention_mask.shape) == 2:
                 attention_mask = attention_mask[None, None]
             assert len(attention_mask.shape) == 4
@@ -1095,14 +1198,20 @@ class Attention(nn.Module):
                 )
                 # `generate` function creates attention mask that does not consider sliding window
                 m_swa = m_swa[None, None]
-                attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
+                attention_mask = attention_mask[
+                    :, :, -query_states.shape[2] :, -key_states.shape[2] :
+                ]
                 attention_mask = mx.where(m_swa, attention_mask, float("-inf"))

            # like AttentionMaskConverter._unmask_unattended in huggingface.transformers,
            # we need to attend to all tokens in masked rows for `scaled_dot_product_attention`
             bool_mask = mx.logical_not(mx.isneginf(attention_mask))
-            valid_tokens = mx.sum(bool_mask, axis=-1).astype(mx.bool_)  # (..., q_len)  # type: ignore
-            attention_mask = mx.where(valid_tokens[..., None], attention_mask, float(0.0))
+            valid_tokens = mx.sum(bool_mask, axis=-1).astype(
+                mx.bool_
+            )  # (..., q_len)  # type: ignore
+            attention_mask = mx.where(
+                valid_tokens[..., None], attention_mask, float(0.0)
+            )
         attn_output = mx.fast.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -1137,7 +1246,9 @@ class Mamba(nn.Module):

         self.intermediate_size = self.num_heads * self.hidden_size_per_head

-        self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
+        self.in_proj = nn.Linear(
+            self.hidden_size, 2 * self.intermediate_size, bias=False
+        )
         self.conv1d = nn.Conv1d(
             in_channels=self.intermediate_size,
             out_channels=self.intermediate_size,
@@ -1235,14 +1346,18 @@ class Mamba(nn.Module):
         )

         # conv
-        x = x.reshape(bsize, length, -1).transpose(0, 2, 1)  # (bsize, intermediate_size, length)
+        x = x.reshape(bsize, length, -1).transpose(
+            0, 2, 1
+        )  # (bsize, intermediate_size, length)
         if bool_mask is not None:
             x = mx.where(bool_mask[:, None, :], x, 0.0)
         if is_update:
             assert conv_state is not None
             x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
         else:
-            x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
+            x, conv_state = _causal_conv1d(
+                conv_state, self.conv1d.weight, x, seq_idx=seq_idx
+            )
         x = x.astype(hidden_states.dtype)
         x = x.transpose(0, 2, 1)  # (bsize, length, intermediate_size)
         x = x.reshape(bsize, length, -1)
@@ -1257,9 +1372,18 @@ class Mamba(nn.Module):
         C = C[:, :, None, :]

         A = -mx.exp(self.A_log.astype(mx.float32))  # (num_heads,)
-        dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
-        B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
-        C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
+        dt = (
+            _rms_norm(dt, None, self.config.rms_norm_eps)
+            * self.dt_norm_weight[None, None, :]
+        )
+        B = (
+            _rms_norm(B, None, self.config.rms_norm_eps)
+            * self.B_norm_weight[None, None, None, :]
+        )
+        C = (
+            _rms_norm(C, None, self.config.rms_norm_eps)
+            * self.C_norm_weight[None, None, None, :]
+        )

         # (bsize, length, num_heads, 1)
         dt = self.dt_proj(dt)[..., None]
@@ -1335,7 +1459,9 @@ class MLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
+        self.gate_up_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size * 2, bias=False
+        )
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

     def __call__(self, x: mx.array) -> mx.array:
@@ -1359,10 +1485,18 @@ class PlamoDecoderLayer(nn.Module):
         """
        Notes: The model performance was degraded when setting all offsets to 1.
        """
-        self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
-        self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
-        self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
-        self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
+        self.pre_mixer_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
+        )
+        self.post_mixer_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5
+        )
+        self.pre_mlp_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
+        )
+        self.post_mlp_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)
+        )

     def __call__(
         self,
@@ -1423,8 +1557,12 @@ class PlamoDecoder(nn.Module):
         self.gradient_checkpointing = False

     def __call__(self, x: DecoderInput) -> DecoderOutput:
-        all_hidden_states: Optional[Tuple[mx.array, ...]] = () if x.output_hidden_states else None
-        all_self_attns: Optional[Tuple[mx.array, ...]] = () if x.output_attentions else None
+        all_hidden_states: Optional[Tuple[mx.array, ...]] = (
+            () if x.output_hidden_states else None
+        )
+        all_self_attns: Optional[Tuple[mx.array, ...]] = (
+            () if x.output_attentions else None
+        )
         hidden_states = x.hidden_states

         for decoder_layer in self.layers:
@@ -1579,9 +1717,13 @@ class PlamoModel(PlamoPreTrainedModel):
         else:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
             assert inputs_embeds is not None
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(
+                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )
             combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+                expanded_attn_mask
+                if combined_attention_mask is None
+                else expanded_attn_mask + combined_attention_mask
             )

         return combined_attention_mask
@@ -1600,21 +1742,33 @@ class PlamoModel(PlamoPreTrainedModel):
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         assert input_ids is not None
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+            )
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape
         else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+            )

         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -1781,11 +1935,19 @@ class Model(PlamoPreTrainedModel):
         ```"""

         assert input_ids is not None
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 78a2e802..6201b9dd 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -19,7 +19,6 @@ from typing import (
     Dict,
     Generator,
     List,
-    NamedTuple,
     Optional,
     Tuple,
     Type,
@@ -44,7 +43,6 @@ from transformers import PreTrainedTokenizer

 # Local imports
 from .models import cache
-from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
 from .tuner.utils import load_adapters, nparams
@@ -721,6 +719,11 @@ def load_model(
     weights = {}
     for wf in weight_files:
         weights.update(mx.load(wf))
+    if "lm_head.weight" not in weights:
+        weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+    for k in weights.keys():
+        if "conv1d.weight" in k:
+            weights[k] = weights[k].transpose(0, 2, 1)

     model_class, model_args_class = get_model_classes(config=config)

@@ -1050,6 +1053,12 @@ def convert(
     model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

     weights = dict(tree_flatten(model.parameters()))
+    if "lm_head.weight" not in weights:
+        weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+    for k in weights.keys():
+        if "conv1d.weight" in k:
+            weights[k] = weights[k].transpose(0, 2, 1)
+
     dtype = getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}
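Reviewer note: for diagonal `A`, the Python loop in `mamba_chunk_scan_combined` above amounts to the discretized SSM recurrence s_{t+1} = exp(A*dt) * s_t + (dt*B) * x_t with readout y_t = C.s_t + D*x_t. A minimal single-step sketch of that recurrence, assuming illustrative shapes (`ssm_step` is hypothetical and not part of this PR):

```python
import mlx.core as mx


def ssm_step(state, x_t, A, dt, B, C, D):
    # state: (batch, nheads, dstate); x_t: (batch, nheads, 1)
    # A, B, C: (nheads, dstate); dt, D: (nheads, 1)
    exp_dA = mx.exp(A * dt)  # per-dimension decay exp(A*dt) for diagonal A
    state = state * exp_dA[None] + (B * dt)[None] * x_t  # s' = exp(A*dt)*s + dt*B*x
    y_t = mx.sum(state * C[None], axis=-1, keepdims=True) + D[None] * x_t
    return y_t, state


# One update step with batch=2, nheads=4, dstate=16
b, h, n = 2, 4, 16
state = mx.zeros((b, h, n))
x_t = mx.random.normal((b, h, 1))
A = -mx.ones((h, n))  # negative diagonal A keeps the recurrence stable
dt = 0.01 * mx.ones((h, 1))
B, C = mx.random.normal((h, n)), mx.random.normal((h, n))
D = mx.ones((h, 1))
y_t, state = ssm_step(state, x_t, A, dt, B, C, D)
```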
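On the loader side, the additions to `load_model` and `convert` cover checkpoints whose output head is tied to the input embeddings, and remap convolution kernels: PyTorch stores `Conv1d` weights as `(out_channels, in_channels/groups, kernel_size)` while `mlx.nn.Conv1d` expects `(out_channels, kernel_size, in_channels/groups)`, hence the `transpose(0, 2, 1)`. A small sketch of both remappings (hypothetical shapes, not this PR's exact code path):

```python
import mlx.core as mx

# Tied head: fall back to the embedding matrix when no lm_head weight is shipped.
weights = {"model.embed_tokens.weight": mx.zeros((1000, 64))}
if "lm_head.weight" not in weights:
    weights["lm_head.weight"] = weights["model.embed_tokens.weight"]

# Depthwise conv kernel exported from PyTorch: (out_channels, in/groups, kernel)
torch_style = mx.random.normal((512, 1, 4))
mlx_style = torch_style.transpose(0, 2, 1)  # -> (out_channels, kernel, in/groups)
assert mlx_style.shape == (512, 4, 1)
```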