Use mlx's BaseModelArgs

Shunta Saito 2025-02-13 14:47:08 +09:00
parent 9a6e6541de
commit 40c7ce8048
2 changed files with 258 additions and 87 deletions

View File

@@ -5,9 +5,8 @@ from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from transformers import PretrainedConfig
-from mlx_lm.models.base import create_attention_mask
+from .base import BaseModelArgs, create_attention_mask


 def _is_first_token(mask: mx.array) -> mx.array:
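For context, mlx_lm's BaseModelArgs (in models/base.py) is roughly the following dataclass helper; from_dict filters an arbitrary config dict down to declared constructor parameters, which is what lets this file drop its transformers PretrainedConfig dependency (a sketch from memory; the upstream body may differ slightly):

import inspect
from dataclasses import dataclass

@dataclass
class BaseModelArgs:
    @classmethod
    def from_dict(cls, params):
        # Keep only keys that match a constructor parameter.
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )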
@@ -16,8 +15,12 @@ def _is_first_token(mask: mx.array) -> mx.array:
     mask = mask[:, :, :, -q_len:]
     cont = q_len != kv_len
     v = False if cont else True
-    out = mx.logical_not(mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_))
-    out = mx.concatenate([mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1)
+    out = mx.logical_not(
+        mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_)
+    )
+    out = mx.concatenate(
+        [mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1
+    )
     return out
@@ -34,13 +37,17 @@ def _swiglu(h: mx.array) -> mx.array:
 class RotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: int = 10000) -> None:
+    def __init__(
+        self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
+    ) -> None:
         super().__init__()
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        inv_freq = 1.0 / (self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim))
+        inv_freq = 1.0 / (
+            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+        )
         self._inv_freq = inv_freq

         # Build here to make `torch.jit.trace` work.
@@ -74,7 +81,9 @@ def _rotate_half(x: mx.array) -> mx.array:
     return mx.concatenate([-x2, x1], axis=-1)


-def _rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array) -> mx.array:
+def _rotary_pos_emb(
+    x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array
+) -> mx.array:
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
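_rotary_pos_emb then applies the standard RoPE combination x * cos + rotate_half(x) * sin. A minimal self-contained sketch of the same recipe (names and shapes here are illustrative, not this module's exact code):

import mlx.core as mx

def rope_reference(x, cos, sin, position_ids):
    # x: [batch, heads, seq, dim]; cos/sin: [seq_len, dim]
    cos = cos[position_ids][:, None]  # -> [batch, 1, seq, dim]
    sin = sin[position_ids][:, None]
    x1, x2 = mx.split(x, 2, axis=-1)
    rotated = mx.concatenate([-x2, x1], axis=-1)  # rotate_half
    return x * cos + rotated * sin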
@@ -90,8 +99,9 @@ class LinearType(str, enum.Enum):
     Fp8Retain = "fp8-retain"


-class ModelArgs(PretrainedConfig):  # type: ignore
-    model_type: str = "plamo"
+@dataclass
+class ModelArgs(BaseModelArgs):  # type: ignore
+    model_type: str = "plamo2"

     def __init__(
         self,
@@ -145,7 +155,9 @@ class ModelArgs(PretrainedConfig):  # type: ignore
         self.hidden_size_per_head = hidden_size_per_head
         self.num_key_value_heads = num_key_value_heads
         self.attention_window_size = attention_window_size
-        self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
+        self.full_attention_idx = (
+            full_attention_idx if full_attention_idx is not None else []
+        )

         self.mamba_d_state = mamba_d_state
         self.mamba_d_conv = mamba_d_conv
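Since BaseModelArgs.from_dict drops keys it does not declare, the checkpoint's config.json can be fed in unchanged and extra Hugging Face fields are ignored. A hypothetical usage:

import json

with open("config.json") as f:      # hypothetical checkpoint config
    params = json.load(f)
args = ModelArgs.from_dict(params)  # unknown keys are filtered out
print(args.model_type)              # "plamo2"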
@@ -221,9 +233,13 @@ class PlamoCache(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
-        self.cache: List[Optional[PlamoLayerCache]] = [None for _ in range(config.num_hidden_layers)]
+        self.cache: List[Optional[PlamoLayerCache]] = [
+            None for _ in range(config.num_hidden_layers)
+        ]

-    def append_kv(self, key: mx.array, value: mx.array, layer_idx: int) -> tuple[mx.array, mx.array]:
+    def append_kv(
+        self, key: mx.array, value: mx.array, layer_idx: int
+    ) -> tuple[mx.array, mx.array]:
         c = self.cache[layer_idx]
         if c is None:
             return key, value
@@ -239,9 +255,13 @@ class PlamoCache(nn.Module):
         _validate(c.key, key)
         _validate(c.value, value)
         assert key.shape[2] == value.shape[2]
-        return mx.concatenate([c.key, key], axis=2), mx.concatenate([c.value, value], axis=2)
+        return mx.concatenate([c.key, key], axis=2), mx.concatenate(
+            [c.value, value], axis=2
+        )

-    def update_attention(self, key_states: mx.array, value_states: mx.array, layer_idx: int) -> PlamoAttentionCache:
+    def update_attention(
+        self, key_states: mx.array, value_states: mx.array, layer_idx: int
+    ) -> PlamoAttentionCache:
         full_attn = layer_idx in self.config.full_attention_idx
         window_size = self.config.attention_window_size
@@ -265,7 +285,9 @@ class PlamoCache(nn.Module):
             c.value = v[:, :, -window_size:, :]
         return self.cache[layer_idx]  # type: ignore

-    def update_mamba(self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int) -> PlamoMambaCache:
+    def update_mamba(
+        self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int
+    ) -> PlamoMambaCache:
         if self.cache[layer_idx] is None:
             self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state)
         else:
@@ -305,7 +327,9 @@ class PlamoCache(nn.Module):
     def get_max_length(self) -> int | None:
         return None

-    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+    def get_usable_length(
+        self, new_seq_length: int, layer_idx: Optional[int] = 0
+    ) -> int:
         """Given the sequence length of the new inputs, returns the usable length of the cache."""
         # Cache without size limit -> all cache is usable
         # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
@ -361,7 +385,9 @@ class DecoderOutput(NamedTuple):
# Copied from transformers.models.bart.modeling_bart._make_causal_mask # Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0) -> mx.array: def _make_causal_mask(
input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0
) -> mx.array:
""" """
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
@@ -372,26 +398,36 @@ def _make_causal_mask(input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_ke
     mask = mask.astype(dtype)

     if past_key_values_length > 0:
-        mask = mx.concatenate([mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1)
-    return mx.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
+        mask = mx.concatenate(
+            [mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1
+        )
+    return mx.broadcast_to(
+        mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)
+    )


 # Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None) -> mx.array:
+def _expand_mask(
+    mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None
+) -> mx.array:
     """
     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
     """
     bsz, src_len = mask.shape
     tgt_len = tgt_len if tgt_len is not None else src_len

-    expanded_mask = mx.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)).astype(dtype)
+    expanded_mask = mx.broadcast_to(
+        mask[:, None, None, :], (bsz, 1, tgt_len, src_len)
+    ).astype(dtype)

     inverted_mask = 1.0 - expanded_mask
     return mx.where(inverted_mask.astype(mx.bool_), float("-inf"), inverted_mask)  # type: ignore


-def _rms_norm(hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0) -> mx.array:
+def _rms_norm(
+    hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0
+) -> mx.array:
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.astype(mx.float32)
     variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
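Past this hunk's context, _rms_norm normalizes by the root mean square and, in the PLaMo variant, scales by (offset + weight) rather than weight alone: y = x / sqrt(mean(x^2) + eps) * (offset + weight). The tail of the function is presumably along these lines (a sketch under that assumption; the body is not shown in this diff):

    hidden_states = hidden_states * mx.rsqrt(variance + eps)
    if weight is None:
        return hidden_states.astype(input_dtype)
    return (hidden_states * (offset + weight)).astype(input_dtype)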
@@ -415,13 +451,18 @@ class RMSNorm(nn.Module):
         self.offset = offset

     def __call__(self, hidden_states: mx.array) -> mx.array:
-        return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
+        return _rms_norm(
+            hidden_states, self.weight, self.variance_epsilon, offset=self.offset
+        )


 def get_initial_dt_bias(num_heads: int) -> mx.array:
     dt_min = 0.001
     dt_max = 0.1
-    dt = mx.exp(mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
+    dt = mx.exp(
+        mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min))
+        + math.log(dt_min)
+    )
     dt = mx.clip(dt, a_min=1e-4, a_max=None)
     inv_dt = dt + mx.log(-mx.expm1(-dt))
     return inv_dt
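inv_dt = dt + log(-expm1(-dt)) is the inverse of softplus, the same trick the reference Mamba code uses so that a later softplus of the bias lands back in [dt_min, dt_max]. A quick check:

import mlx.core as mx

dt = mx.array([0.001, 0.01, 0.1])
inv_dt = dt + mx.log(-mx.expm1(-dt))  # inverse softplus
roundtrip = mx.log1p(mx.exp(inv_dt))  # softplus(inv_dt)
print(roundtrip)                      # ~ [0.001, 0.01, 0.1]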
@@ -450,9 +491,15 @@ def ssd_update_state(
     hidden_size_per_head = x.shape[-1]
     d_state = B.shape[-1]
-    A = mx.broadcast_to(A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)).astype(mx.float32)
-    dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head))
-    dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head))
+    A = mx.broadcast_to(
+        A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)
+    ).astype(mx.float32)
+    dt = mx.broadcast_to(
+        dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)
+    )
+    dt_bias = mx.broadcast_to(
+        dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)
+    )
     D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
     assert ssm_state.dtype == mx.float32
     out, ssm_state = f(
@ -592,7 +639,9 @@ def _causal_conv1d(
mx.zeros_like(conv_state), mx.zeros_like(conv_state),
conv_state, conv_state,
) )
out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1]) out[:, :, i : i + 1], conv_state = _causal_conv1d_update(
conv_state, weight, x[:, :, i : i + 1]
)
x = out x = out
if return_final_states: if return_final_states:
return x, conv_state return x, conv_state
@@ -600,7 +649,9 @@ def _causal_conv1d(
         return x, None


-def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array) -> tuple[mx.array, mx.array]:
+def _causal_conv1d_update(
+    conv_state: mx.array, weight: mx.array, xBC: mx.array
+) -> tuple[mx.array, mx.array]:
     dtype = conv_state.dtype
     xBC = xBC.astype(dtype)
     weight = weight.astype(dtype)
@@ -729,12 +780,18 @@ def mamba_chunk_scan_combined(
         dt = dt + dt_bias  # incorporate bias to dt if provided

     # Prepare initial state
-    state_dim = A.shape[-1]  # assume A is of shape (nheads, state_dim) for diagonal A or (nheads, state_dim, state_dim)
+    state_dim = A.shape[
+        -1
+    ]  # assume A is of shape (nheads, state_dim) for diagonal A or (nheads, state_dim, state_dim)
     if initial_states is None:
         # Initialize state to zero for each sequence in batch and each head
         state = mx.zeros((batch, nheads, state_dim), dtype=A.dtype)
     else:
-        state = mx.array(initial_states) if not isinstance(initial_states, mx.array) else initial_states
+        state = (
+            mx.array(initial_states)
+            if not isinstance(initial_states, mx.array)
+            else initial_states
+        )

     # Precompute exponent of A*dt for state update per step (assuming A is diagonal or elementwise applicable)
     # If A is diagonal values per head (shape (nheads, state_dim)), we compute elementwise exponentials.
@@ -743,10 +800,14 @@ def mamba_chunk_scan_combined(
         # A is given as diagonal values per state
         exp_dA = mx.exp(A * dt)  # shape (nheads, state_dim) or (state_dim,)
         if exp_dA.ndim == 2:
-            exp_dA = exp_dA.reshape((1, nheads, state_dim))  # shape (1, nheads, state_dim) for broadcasting
+            exp_dA = exp_dA.reshape(
+                (1, nheads, state_dim)
+            )  # shape (1, nheads, state_dim) for broadcasting
     else:
         # If A is a full matrix per head, use matrix exponential (if available in MLX)
-        exp_dA = mx.exp(A * dt)  # assuming MX can exponentiate matrix elementwise or use specialized routine
+        exp_dA = mx.exp(
+            A * dt
+        )  # assuming MX can exponentiate matrix elementwise or use specialized routine

     # Output buffer
     out_list = []  # will collect output chunks
@ -783,7 +844,9 @@ def mamba_chunk_scan_combined(
inc = inc * dt inc = inc * dt
# State update: s_{t+1} = exp(A*dt) * s_t + inc​:contentReference[oaicite:4]{index=4}. # State update: s_{t+1} = exp(A*dt) * s_t + inc​:contentReference[oaicite:4]{index=4}.
state = state * exp_dA + inc # elementwise multiply if exp_dA broadcast shape (1, nheads, state_dim) state = (
state * exp_dA + inc
) # elementwise multiply if exp_dA broadcast shape (1, nheads, state_dim)
# Compute output for this time step: y_t = C * state_t + (D * x_t if direct term exists) # Compute output for this time step: y_t = C * state_t + (D * x_t if direct term exists)
if C.ndim == 3: if C.ndim == 3:
@@ -792,7 +855,9 @@ def mamba_chunk_scan_combined(
         else:
             # C shape (nheads, state_dim) or (state_dim,), output one value per head
             # Multiply and sum over state_dim
-            y_t = mx.sum(state * C.reshape((1, nheads, state_dim)), axis=-1, keepdims=True)  # (batch, nheads, 1)
+            y_t = mx.sum(
+                state * C.reshape((1, nheads, state_dim)), axis=-1, keepdims=True
+            )  # (batch, nheads, 1)
         if D is not None:
             # Add direct input contribution: D * x(t)
             if D.ndim == 2:
@@ -804,7 +869,9 @@ def mamba_chunk_scan_combined(
                 )
             else:
                 # D shape (nheads,) or scalar
-                y_t += D.reshape((1, nheads, -1)) * (x_t if x_t.ndim == 3 else x_t[..., None])
+                y_t += D.reshape((1, nheads, -1)) * (
+                    x_t if x_t.ndim == 3 else x_t[..., None]
+                )
         # Apply gating activation if provided (e.g., elementwise multiply by a sigmoidal function of z)
         if z is not None:
             # Example: if z is meant to gate outputs via a sigmoid activation (silu), as in some Mamba variants
@@ -826,7 +893,9 @@ def mamba_chunk_scan_combined(
         out = y.reshape((batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1)))
     else:
         # If out_list was used and concatenated via Python, we might get a NumPy array; ensure it's MLX:
-        out = mx.array(y).reshape((batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1)))
+        out = mx.array(y).reshape(
+            (batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1))
+        )

     if return_final_states:
         # Return the final state as well (state holds final state after last chunk)
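Each step of this scan realizes the diagonal SSM recurrence s_t = exp(A·dt)·s_{t-1} + dt·B·x_t with readout y_t = C·s_t + D·x_t. A single-head toy reference for intuition (hypothetical shapes, separate from the batched implementation above):

import mlx.core as mx

def ssm_scan_reference(x, A, B, C, dt, D=None):
    # x: [T]; A, B, C: [state_dim]; dt: scalar step size
    s = mx.zeros_like(A)
    ys = []
    for t in range(x.shape[0]):
        s = mx.exp(A * dt) * s + dt * B * x[t]  # state update
        y = mx.sum(C * s)                       # readout
        if D is not None:
            y = y + D * x[t]                    # skip connection
        ys.append(y)
    return mx.stack(ys)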
@@ -850,16 +919,24 @@ class PlamoPreTrainedModel(nn.Module):  # type: ignore
     def _init_weights(self, module: nn.Module) -> None:
         std = 0.02
         if isinstance(module, nn.Linear):
-            module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape)
+            module.weight = mx.random.normal(
+                loc=0.0, scale=std, shape=module.weight.shape
+            )
             if module.bias is not None:
                 module.bias = mx.zeros_like(module.bias)
         elif isinstance(module, nn.Embedding):
-            module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape)
+            module.weight = mx.random.normal(
+                loc=0.0, scale=std, shape=module.weight.shape
+            )
             if module.padding_idx is not None:
-                module.weight[module.padding_idx] = mx.zeros_like(module.weight[module.padding_idx])
+                module.weight[module.padding_idx] = mx.zeros_like(
+                    module.weight[module.padding_idx]
+                )


-def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+def causal_conv1d_update(
+    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
+):
     """
     x: (batch, dim) or (batch, dim, seqlen)
     conv_state: (batch, dim, state_len), where state_len >= width - 1
@@ -884,13 +961,21 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
     assert conv_state.shape == (batch, dim, state_len)
     assert weight.shape == (dim, width)
     if cache_seqlens is None:
-        x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)  # (batch, dim, state_len + seqlen)
+        x_new = mx.concatenate([conv_state, x], axis=-1).astype(
+            weight.dtype
+        )  # (batch, dim, state_len + seqlen)
         conv_state = x_new[:, :, -state_len:]
     else:
-        width_idx = mx.expand_dims(mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0) + cache_seqlens.unsqueeze(1)
+        width_idx = mx.expand_dims(
+            mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0
+        ) + cache_seqlens.unsqueeze(1)
         width_idx = mx.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
-        x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1).astype(weight.dtype)
-        copy_idx = mx.expand_dims(mx.arange(seqlen, dtype=mx.int64), axis=0) + cache_seqlens.unsqueeze(1)
+        x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1).astype(
+            weight.dtype
+        )
+        copy_idx = mx.expand_dims(
+            mx.arange(seqlen, dtype=mx.int64), axis=0
+        ) + cache_seqlens.unsqueeze(1)
         copy_idx = mx.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
         conv_state.scatter_(2, copy_idx, x)
     assert bias is None
@ -900,9 +985,7 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
mx.expand_dims(weight, axis=2), mx.expand_dims(weight, axis=2),
padding=0, padding=0,
groups=dim, groups=dim,
).transpose( ).transpose(0, 2, 1)[:, :, -seqlen:]
0, 2, 1
)[:, :, -seqlen:]
if unsqueeze: if unsqueeze:
out = out.squeeze(-1) out = out.squeeze(-1)
return (out if activation is None else nn.silu(out)).astype(dtype_in), conv_state return (out if activation is None else nn.silu(out)).astype(dtype_in), conv_state
@@ -969,7 +1052,9 @@ def selective_state_update_ref(
         mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)),
         (batch, nheads, dstate),
     )  # (batch, nheads, dstate)
-    dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2)  # (batch, nheads, dim, dstate)
+    dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(
+        B, axis=-2
+    )  # (batch, nheads, dim, dstate)
     state = state * dA + dB * mx.expand_dims(x, axis=-1)  # (batch, dim, dstate)
     out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
     if D is not None:
@@ -1013,12 +1098,16 @@ class Attention(nn.Module):
             self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
             bias=False,
         )
-        self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(
+            self.q_num_heads * self.v_dim, self.hidden_size, bias=False
+        )

         self.q_weight = mx.ones((self.q_num_heads, self.qk_dim))
         self.k_weight = mx.ones((self.k_num_heads, self.qk_dim))

-        self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size)
+        self.rotary_emb = RotaryEmbedding(
+            self.qk_dim, max_position_embeddings=self.config.attention_window_size
+        )

     def __call__(
         self,
@@ -1033,21 +1122,33 @@ class Attention(nn.Module):
         query_states, key_states, value_states = mx.split(
             qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1
         )
-        query_states = query_states.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
-        key_states = key_states.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
-        value_states = value_states.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
+        query_states = query_states.reshape(
+            bsz, q_len, self.q_num_heads, self.qk_dim
+        ).transpose(0, 2, 1, 3)
+        key_states = key_states.reshape(
+            bsz, q_len, self.k_num_heads, self.qk_dim
+        ).transpose(0, 2, 1, 3)
+        value_states = value_states.reshape(
+            bsz, q_len, self.v_num_heads, self.v_dim
+        ).transpose(0, 2, 1, 3)

         attn_dtype = query_states.dtype

-        query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
+        query_states = (
+            _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
+        )
         key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]

         if past_states is not None:
             # reuse k, v, self_attention
             key_states_new = key_states
             value_states_new = value_states
-            key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx)  # type: ignore
-            past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
+            key_states, value_states = past_states.append_kv(
+                key_states, value_states, self.layer_idx
+            )  # type: ignore
+            past_states.update_attention(
+                key_states_new, value_states_new, self.layer_idx
+            )

         kv_seq_len = key_states.shape[-2]
         position_ids = mx.arange(kv_seq_len, dtype=mx.int64)[None]
@@ -1082,7 +1183,9 @@ class Attention(nn.Module):
             )
         else:
             if attention_mask.dtype == bool:
-                attention_mask = mx.where(attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
+                attention_mask = mx.where(
+                    attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf")
+                )
             if len(attention_mask.shape) == 2:
                 attention_mask = attention_mask[None, None]
             assert len(attention_mask.shape) == 4
@@ -1095,14 +1198,20 @@ class Attention(nn.Module):
                 )
                 # `generate` function creates attention mask that does not consider sliding window
                 m_swa = m_swa[None, None]
-                attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
+                attention_mask = attention_mask[
+                    :, :, -query_states.shape[2] :, -key_states.shape[2] :
+                ]
                 attention_mask = mx.where(m_swa, attention_mask, float("-inf"))

             # like AttentionMaskConverter._unmask_unattended in huggingface.transformers,
             # we need to attend to all tokens in masked rows for `scaled_dot_product_attention`
             bool_mask = mx.logical_not(mx.isneginf(attention_mask))
-            valid_tokens = mx.sum(bool_mask, axis=-1).astype(mx.bool_)  # (..., q_len)  # type: ignore
-            attention_mask = mx.where(valid_tokens[..., None], attention_mask, float(0.0))
+            valid_tokens = mx.sum(bool_mask, axis=-1).astype(
+                mx.bool_
+            )  # (..., q_len)  # type: ignore
+            attention_mask = mx.where(
+                valid_tokens[..., None], attention_mask, float(0.0)
+            )

         attn_output = mx.fast.scaled_dot_product_attention(
             query_states,
             key_states,
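The valid_tokens step mirrors AttentionMaskConverter._unmask_unattended: a row that is entirely -inf would make softmax return NaN, so fully masked rows are zeroed out (attend everywhere; the output is discarded anyway). A toy illustration of the same idea:

import mlx.core as mx

mask = mx.array([[0.0, float("-inf")],
                 [float("-inf"), float("-inf")]])  # second row fully masked
valid = mx.sum(mx.logical_not(mx.isneginf(mask)), axis=-1).astype(mx.bool_)
mask = mx.where(valid[..., None], mask, 0.0)  # fully masked row -> all zeros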
@@ -1137,7 +1246,9 @@ class Mamba(nn.Module):
         self.intermediate_size = self.num_heads * self.hidden_size_per_head

-        self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
+        self.in_proj = nn.Linear(
+            self.hidden_size, 2 * self.intermediate_size, bias=False
+        )
         self.conv1d = nn.Conv1d(
             in_channels=self.intermediate_size,
             out_channels=self.intermediate_size,
@@ -1235,14 +1346,18 @@ class Mamba(nn.Module):
         )

         # conv
-        x = x.reshape(bsize, length, -1).transpose(0, 2, 1)  # (bsize, intermediate_size, length)
+        x = x.reshape(bsize, length, -1).transpose(
+            0, 2, 1
+        )  # (bsize, intermediate_size, length)
         if bool_mask is not None:
             x = mx.where(bool_mask[:, None, :], x, 0.0)
         if is_update:
             assert conv_state is not None
             x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
         else:
-            x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
+            x, conv_state = _causal_conv1d(
+                conv_state, self.conv1d.weight, x, seq_idx=seq_idx
+            )
         x = x.astype(hidden_states.dtype)
         x = x.transpose(0, 2, 1)  # (bsize, length, intermediate_size)
         x = x.reshape(bsize, length, -1)
@@ -1257,9 +1372,18 @@ class Mamba(nn.Module):
         C = C[:, :, None, :]
         A = -mx.exp(self.A_log.astype(mx.float32))  # (num_heads,)
-        dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
-        B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
-        C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
+        dt = (
+            _rms_norm(dt, None, self.config.rms_norm_eps)
+            * self.dt_norm_weight[None, None, :]
+        )
+        B = (
+            _rms_norm(B, None, self.config.rms_norm_eps)
+            * self.B_norm_weight[None, None, None, :]
+        )
+        C = (
+            _rms_norm(C, None, self.config.rms_norm_eps)
+            * self.C_norm_weight[None, None, None, :]
+        )

         # (bsize, length, num_heads, 1)
         dt = self.dt_proj(dt)[..., None]
@@ -1335,7 +1459,9 @@ class MLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
+        self.gate_up_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size * 2, bias=False
+        )
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

     def __call__(self, x: mx.array) -> mx.array:
@@ -1359,10 +1485,18 @@ class PlamoDecoderLayer(nn.Module):
         """
         Notes: The model performance was degraded when setting all offsets to 1.
         """
-        self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
-        self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
-        self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
-        self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
+        self.pre_mixer_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
+        )
+        self.post_mixer_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5
+        )
+        self.pre_mlp_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
+        )
+        self.post_mlp_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)
+        )

     def __call__(
         self,
@@ -1423,8 +1557,12 @@ class PlamoDecoder(nn.Module):
         self.gradient_checkpointing = False

     def __call__(self, x: DecoderInput) -> DecoderOutput:
-        all_hidden_states: Optional[Tuple[mx.array, ...]] = () if x.output_hidden_states else None
-        all_self_attns: Optional[Tuple[mx.array, ...]] = () if x.output_attentions else None
+        all_hidden_states: Optional[Tuple[mx.array, ...]] = (
+            () if x.output_hidden_states else None
+        )
+        all_self_attns: Optional[Tuple[mx.array, ...]] = (
+            () if x.output_attentions else None
+        )
         hidden_states = x.hidden_states

         for decoder_layer in self.layers:
@@ -1579,9 +1717,13 @@ class PlamoModel(PlamoPreTrainedModel):
         else:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
             assert inputs_embeds is not None
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(
+                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )
             combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+                expanded_attn_mask
+                if combined_attention_mask is None
+                else expanded_attn_mask + combined_attention_mask
             )

         return combined_attention_mask
@@ -1600,21 +1742,33 @@ class PlamoModel(PlamoPreTrainedModel):
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         assert input_ids is not None
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+            )
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape
         else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+            )

         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -1781,11 +1935,19 @@ class Model(PlamoPreTrainedModel):
         ```"""
         assert input_ids is not None
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(

View File

@@ -19,7 +19,6 @@ from typing import (
     Dict,
     Generator,
     List,
-    NamedTuple,
     Optional,
     Tuple,
     Type,
@@ -44,7 +43,6 @@ from transformers import PreTrainedTokenizer
 # Local imports
 from .models import cache
-from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
 from .tuner.utils import load_adapters, nparams
@@ -721,6 +719,11 @@ def load_model(
     weights = {}
     for wf in weight_files:
         weights.update(mx.load(wf))
+    if "lm_head.weight" not in weights:
+        weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+    for k in weights.keys():
+        if "conv1d.weight" in k:
+            weights[k] = weights[k].transpose(0, 2, 1)

     model_class, model_args_class = get_model_classes(config=config)
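Two checkpoint fix-ups happen here: a missing lm_head.weight falls back to the tied embedding matrix, and Conv1d kernels are permuted because PyTorch stores them as (out_channels, in_channels/groups, kernel_size) while MLX's nn.Conv1d expects the kernel axis in the middle. A toy illustration of the permutation (shapes are hypothetical):

import mlx.core as mx

w_torch = mx.zeros((4096, 1, 4))    # depthwise kernel, PyTorch layout
w_mlx = w_torch.transpose(0, 2, 1)  # MLX layout: (channels, kernel, in/groups)
print(w_mlx.shape)                  # (4096, 4, 1)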
@@ -1050,6 +1053,12 @@ def convert(
     model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

     weights = dict(tree_flatten(model.parameters()))
+    if "lm_head.weight" not in weights:
+        weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+    for k in weights.keys():
+        if "conv1d.weight" in k:
+            weights[k] = weights[k].transpose(0, 2, 1)
+
     dtype = getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}
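With the weights normalized at both load and convert time, a converted PLaMo-2 checkpoint should then work through the usual mlx_lm entry points, roughly as follows (the model path is a placeholder):

from mlx_lm import load, generate

model, tokenizer = load("./plamo2-mlx")  # placeholder path to a converted model
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))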