mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 07:30:06 +08:00
Use mlx's BaseModelArgs
This commit is contained in:
parent
9a6e6541de
commit
40c7ce8048
@ -5,9 +5,8 @@ from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from mlx_lm.models.base import create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
def _is_first_token(mask: mx.array) -> mx.array:
|
||||
@ -16,8 +15,12 @@ def _is_first_token(mask: mx.array) -> mx.array:
|
||||
mask = mask[:, :, :, -q_len:]
|
||||
cont = q_len != kv_len
|
||||
v = False if cont else True
|
||||
out = mx.logical_not(mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_))
|
||||
out = mx.concatenate([mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1)
|
||||
out = mx.logical_not(
|
||||
mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_)
|
||||
)
|
||||
out = mx.concatenate(
|
||||
[mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@ -34,13 +37,17 @@ def _swiglu(h: mx.array) -> mx.array:
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, max_position_embeddings: int = 2048, base: int = 10000) -> None:
|
||||
def __init__(
|
||||
self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim))
|
||||
inv_freq = 1.0 / (
|
||||
self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
|
||||
)
|
||||
self._inv_freq = inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
@ -74,7 +81,9 @@ def _rotate_half(x: mx.array) -> mx.array:
|
||||
return mx.concatenate([-x2, x1], axis=-1)
|
||||
|
||||
|
||||
def _rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array) -> mx.array:
|
||||
def _rotary_pos_emb(
|
||||
x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array
|
||||
) -> mx.array:
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
@ -90,8 +99,9 @@ class LinearType(str, enum.Enum):
|
||||
Fp8Retain = "fp8-retain"
|
||||
|
||||
|
||||
class ModelArgs(PretrainedConfig): # type: ignore
|
||||
model_type: str = "plamo"
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs): # type: ignore
|
||||
model_type: str = "plamo2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -145,7 +155,9 @@ class ModelArgs(PretrainedConfig): # type: ignore
|
||||
self.hidden_size_per_head = hidden_size_per_head
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.attention_window_size = attention_window_size
|
||||
self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
|
||||
self.full_attention_idx = (
|
||||
full_attention_idx if full_attention_idx is not None else []
|
||||
)
|
||||
|
||||
self.mamba_d_state = mamba_d_state
|
||||
self.mamba_d_conv = mamba_d_conv
|
||||
@ -221,9 +233,13 @@ class PlamoCache(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.cache: List[Optional[PlamoLayerCache]] = [None for _ in range(config.num_hidden_layers)]
|
||||
self.cache: List[Optional[PlamoLayerCache]] = [
|
||||
None for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def append_kv(self, key: mx.array, value: mx.array, layer_idx: int) -> tuple[mx.array, mx.array]:
|
||||
def append_kv(
|
||||
self, key: mx.array, value: mx.array, layer_idx: int
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
c = self.cache[layer_idx]
|
||||
if c is None:
|
||||
return key, value
|
||||
@ -239,9 +255,13 @@ class PlamoCache(nn.Module):
|
||||
_validate(c.key, key)
|
||||
_validate(c.value, value)
|
||||
assert key.shape[2] == value.shape[2]
|
||||
return mx.concatenate([c.key, key], axis=2), mx.concatenate([c.value, value], axis=2)
|
||||
return mx.concatenate([c.key, key], axis=2), mx.concatenate(
|
||||
[c.value, value], axis=2
|
||||
)
|
||||
|
||||
def update_attention(self, key_states: mx.array, value_states: mx.array, layer_idx: int) -> PlamoAttentionCache:
|
||||
def update_attention(
|
||||
self, key_states: mx.array, value_states: mx.array, layer_idx: int
|
||||
) -> PlamoAttentionCache:
|
||||
full_attn = layer_idx in self.config.full_attention_idx
|
||||
window_size = self.config.attention_window_size
|
||||
|
||||
@ -265,7 +285,9 @@ class PlamoCache(nn.Module):
|
||||
c.value = v[:, :, -window_size:, :]
|
||||
return self.cache[layer_idx] # type: ignore
|
||||
|
||||
def update_mamba(self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int) -> PlamoMambaCache:
|
||||
def update_mamba(
|
||||
self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int
|
||||
) -> PlamoMambaCache:
|
||||
if self.cache[layer_idx] is None:
|
||||
self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state)
|
||||
else:
|
||||
@ -305,7 +327,9 @@ class PlamoCache(nn.Module):
|
||||
def get_max_length(self) -> int | None:
|
||||
return None
|
||||
|
||||
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
||||
def get_usable_length(
|
||||
self, new_seq_length: int, layer_idx: Optional[int] = 0
|
||||
) -> int:
|
||||
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
||||
# Cache without size limit -> all cache is usable
|
||||
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
||||
@ -361,7 +385,9 @@ class DecoderOutput(NamedTuple):
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||
def _make_causal_mask(input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0) -> mx.array:
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0
|
||||
) -> mx.array:
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
@ -372,26 +398,36 @@ def _make_causal_mask(input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_ke
|
||||
mask = mask.astype(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = mx.concatenate([mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1)
|
||||
return mx.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
|
||||
mask = mx.concatenate(
|
||||
[mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1
|
||||
)
|
||||
return mx.broadcast_to(
|
||||
mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||
def _expand_mask(mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None) -> mx.array:
|
||||
def _expand_mask(
|
||||
mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None
|
||||
) -> mx.array:
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.shape
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mx.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)).astype(dtype)
|
||||
expanded_mask = mx.broadcast_to(
|
||||
mask[:, None, None, :], (bsz, 1, tgt_len, src_len)
|
||||
).astype(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return mx.where(inverted_mask.astype(mx.bool_), float("-inf"), inverted_mask) # type: ignore
|
||||
|
||||
|
||||
def _rms_norm(hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0) -> mx.array:
|
||||
def _rms_norm(
|
||||
hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0
|
||||
) -> mx.array:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.astype(mx.float32)
|
||||
variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
|
||||
@ -415,13 +451,18 @@ class RMSNorm(nn.Module):
|
||||
self.offset = offset
|
||||
|
||||
def __call__(self, hidden_states: mx.array) -> mx.array:
|
||||
return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
|
||||
return _rms_norm(
|
||||
hidden_states, self.weight, self.variance_epsilon, offset=self.offset
|
||||
)
|
||||
|
||||
|
||||
def get_initial_dt_bias(num_heads: int) -> mx.array:
|
||||
dt_min = 0.001
|
||||
dt_max = 0.1
|
||||
dt = mx.exp(mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
|
||||
dt = mx.exp(
|
||||
mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min))
|
||||
+ math.log(dt_min)
|
||||
)
|
||||
dt = mx.clip(dt, a_min=1e-4, a_max=None)
|
||||
inv_dt = dt + mx.log(-mx.expm1(-dt))
|
||||
return inv_dt
|
||||
@ -450,9 +491,15 @@ def ssd_update_state(
|
||||
|
||||
hidden_size_per_head = x.shape[-1]
|
||||
d_state = B.shape[-1]
|
||||
A = mx.broadcast_to(A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)).astype(mx.float32)
|
||||
dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head))
|
||||
dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head))
|
||||
A = mx.broadcast_to(
|
||||
A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)
|
||||
).astype(mx.float32)
|
||||
dt = mx.broadcast_to(
|
||||
dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)
|
||||
)
|
||||
dt_bias = mx.broadcast_to(
|
||||
dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)
|
||||
)
|
||||
D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
|
||||
assert ssm_state.dtype == mx.float32
|
||||
out, ssm_state = f(
|
||||
@ -592,7 +639,9 @@ def _causal_conv1d(
|
||||
mx.zeros_like(conv_state),
|
||||
conv_state,
|
||||
)
|
||||
out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
|
||||
out[:, :, i : i + 1], conv_state = _causal_conv1d_update(
|
||||
conv_state, weight, x[:, :, i : i + 1]
|
||||
)
|
||||
x = out
|
||||
if return_final_states:
|
||||
return x, conv_state
|
||||
@ -600,7 +649,9 @@ def _causal_conv1d(
|
||||
return x, None
|
||||
|
||||
|
||||
def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array) -> tuple[mx.array, mx.array]:
|
||||
def _causal_conv1d_update(
|
||||
conv_state: mx.array, weight: mx.array, xBC: mx.array
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
dtype = conv_state.dtype
|
||||
xBC = xBC.astype(dtype)
|
||||
weight = weight.astype(dtype)
|
||||
@ -729,12 +780,18 @@ def mamba_chunk_scan_combined(
|
||||
dt = dt + dt_bias # incorporate bias to dt if provided
|
||||
|
||||
# Prepare initial state
|
||||
state_dim = A.shape[-1] # assume A is of shape (nheads, state_dim) for diagonal A or (nheads, state_dim, state_dim)
|
||||
state_dim = A.shape[
|
||||
-1
|
||||
] # assume A is of shape (nheads, state_dim) for diagonal A or (nheads, state_dim, state_dim)
|
||||
if initial_states is None:
|
||||
# Initialize state to zero for each sequence in batch and each head​:contentReference[oaicite:3]{index=3}
|
||||
state = mx.zeros((batch, nheads, state_dim), dtype=A.dtype)
|
||||
else:
|
||||
state = mx.array(initial_states) if not isinstance(initial_states, mx.array) else initial_states
|
||||
state = (
|
||||
mx.array(initial_states)
|
||||
if not isinstance(initial_states, mx.array)
|
||||
else initial_states
|
||||
)
|
||||
|
||||
# Precompute exponent of A*dt for state update per step (assuming A is diagonal or elementwise applicable)
|
||||
# If A is diagonal values per head (shape (nheads, state_dim)), we compute elementwise exponentials.
|
||||
@ -743,10 +800,14 @@ def mamba_chunk_scan_combined(
|
||||
# A is given as diagonal values per state
|
||||
exp_dA = mx.exp(A * dt) # shape (nheads, state_dim) or (state_dim,)
|
||||
if exp_dA.ndim == 2:
|
||||
exp_dA = exp_dA.reshape((1, nheads, state_dim)) # shape (1, nheads, state_dim) for broadcasting
|
||||
exp_dA = exp_dA.reshape(
|
||||
(1, nheads, state_dim)
|
||||
) # shape (1, nheads, state_dim) for broadcasting
|
||||
else:
|
||||
# If A is a full matrix per head, use matrix exponential (if available in MLX)
|
||||
exp_dA = mx.exp(A * dt) # assuming MX can exponentiate matrix elementwise or use specialized routine
|
||||
exp_dA = mx.exp(
|
||||
A * dt
|
||||
) # assuming MX can exponentiate matrix elementwise or use specialized routine
|
||||
|
||||
# Output buffer
|
||||
out_list = [] # will collect output chunks
|
||||
@ -783,7 +844,9 @@ def mamba_chunk_scan_combined(
|
||||
inc = inc * dt
|
||||
|
||||
# State update: s_{t+1} = exp(A*dt) * s_t + inc​:contentReference[oaicite:4]{index=4}.
|
||||
state = state * exp_dA + inc # elementwise multiply if exp_dA broadcast shape (1, nheads, state_dim)
|
||||
state = (
|
||||
state * exp_dA + inc
|
||||
) # elementwise multiply if exp_dA broadcast shape (1, nheads, state_dim)
|
||||
|
||||
# Compute output for this time step: y_t = C * state_t + (D * x_t if direct term exists)
|
||||
if C.ndim == 3:
|
||||
@ -792,7 +855,9 @@ def mamba_chunk_scan_combined(
|
||||
else:
|
||||
# C shape (nheads, state_dim) or (state_dim,), output one value per head
|
||||
# Multiply and sum over state_dim
|
||||
y_t = mx.sum(state * C.reshape((1, nheads, state_dim)), axis=-1, keepdims=True) # (batch, nheads, 1)
|
||||
y_t = mx.sum(
|
||||
state * C.reshape((1, nheads, state_dim)), axis=-1, keepdims=True
|
||||
) # (batch, nheads, 1)
|
||||
if D is not None:
|
||||
# Add direct input contribution: D * x(t)
|
||||
if D.ndim == 2:
|
||||
@ -804,7 +869,9 @@ def mamba_chunk_scan_combined(
|
||||
)
|
||||
else:
|
||||
# D shape (nheads,) or scalar
|
||||
y_t += D.reshape((1, nheads, -1)) * (x_t if x_t.ndim == 3 else x_t[..., None])
|
||||
y_t += D.reshape((1, nheads, -1)) * (
|
||||
x_t if x_t.ndim == 3 else x_t[..., None]
|
||||
)
|
||||
# Apply gating activation if provided (e.g., elementwise multiply by a sigmoidal function of z)
|
||||
if z is not None:
|
||||
# Example: if z is meant to gate outputs via a sigmoid activation (silu), as in some Mamba variants
|
||||
@ -826,7 +893,9 @@ def mamba_chunk_scan_combined(
|
||||
out = y.reshape((batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1)))
|
||||
else:
|
||||
# If out_list was used and concatenated via Python, we might get a NumPy array; ensure it's MLX:
|
||||
out = mx.array(y).reshape((batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1)))
|
||||
out = mx.array(y).reshape(
|
||||
(batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1))
|
||||
)
|
||||
|
||||
if return_final_states:
|
||||
# Return the final state as well (state holds final state after last chunk)
|
||||
@ -850,16 +919,24 @@ class PlamoPreTrainedModel(nn.Module): # type: ignore
|
||||
def _init_weights(self, module: nn.Module) -> None:
|
||||
std = 0.02
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape)
|
||||
module.weight = mx.random.normal(
|
||||
loc=0.0, scale=std, shape=module.weight.shape
|
||||
)
|
||||
if module.bias is not None:
|
||||
module.bias = mx.zeros_like(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape)
|
||||
module.weight = mx.random.normal(
|
||||
loc=0.0, scale=std, shape=module.weight.shape
|
||||
)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx] = mx.zeros_like(module.weight[module.padding_idx])
|
||||
module.weight[module.padding_idx] = mx.zeros_like(
|
||||
module.weight[module.padding_idx]
|
||||
)
|
||||
|
||||
|
||||
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
|
||||
def causal_conv1d_update(
|
||||
x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
|
||||
):
|
||||
"""
|
||||
x: (batch, dim) or (batch, dim, seqlen)
|
||||
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
||||
@ -884,13 +961,21 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
|
||||
assert conv_state.shape == (batch, dim, state_len)
|
||||
assert weight.shape == (dim, width)
|
||||
if cache_seqlens is None:
|
||||
x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype) # (batch, dim, state_len + seqlen)
|
||||
x_new = mx.concatenate([conv_state, x], axis=-1).astype(
|
||||
weight.dtype
|
||||
) # (batch, dim, state_len + seqlen)
|
||||
conv_state = x_new[:, :, -state_len:]
|
||||
else:
|
||||
width_idx = mx.expand_dims(mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0) + cache_seqlens.unsqueeze(1)
|
||||
width_idx = mx.expand_dims(
|
||||
mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0
|
||||
) + cache_seqlens.unsqueeze(1)
|
||||
width_idx = mx.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
||||
x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1).astype(weight.dtype)
|
||||
copy_idx = mx.expand_dims(mx.arange(seqlen, dtype=mx.int64), axis=0) + cache_seqlens.unsqueeze(1)
|
||||
x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1).astype(
|
||||
weight.dtype
|
||||
)
|
||||
copy_idx = mx.expand_dims(
|
||||
mx.arange(seqlen, dtype=mx.int64), axis=0
|
||||
) + cache_seqlens.unsqueeze(1)
|
||||
copy_idx = mx.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
||||
conv_state.scatter_(2, copy_idx, x)
|
||||
assert bias is None
|
||||
@ -900,9 +985,7 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
|
||||
mx.expand_dims(weight, axis=2),
|
||||
padding=0,
|
||||
groups=dim,
|
||||
).transpose(
|
||||
0, 2, 1
|
||||
)[:, :, -seqlen:]
|
||||
).transpose(0, 2, 1)[:, :, -seqlen:]
|
||||
if unsqueeze:
|
||||
out = out.squeeze(-1)
|
||||
return (out if activation is None else nn.silu(out)).astype(dtype_in), conv_state
|
||||
@ -969,7 +1052,9 @@ def selective_state_update_ref(
|
||||
mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)),
|
||||
(batch, nheads, dstate),
|
||||
) # (batch, nheads, dstate)
|
||||
dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2) # (batch, nheads, dim, dstate)
|
||||
dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(
|
||||
B, axis=-2
|
||||
) # (batch, nheads, dim, dstate)
|
||||
state = state * dA + dB * mx.expand_dims(x, axis=-1) # (batch, dim, dstate
|
||||
out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
|
||||
if D is not None:
|
||||
@ -1013,12 +1098,16 @@ class Attention(nn.Module):
|
||||
self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
|
||||
self.o_proj = nn.Linear(
|
||||
self.q_num_heads * self.v_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
self.q_weight = mx.ones((self.q_num_heads, self.qk_dim))
|
||||
self.k_weight = mx.ones((self.k_num_heads, self.qk_dim))
|
||||
|
||||
self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.qk_dim, max_position_embeddings=self.config.attention_window_size
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -1033,21 +1122,33 @@ class Attention(nn.Module):
|
||||
query_states, key_states, value_states = mx.split(
|
||||
qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1
|
||||
)
|
||||
query_states = query_states.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
|
||||
key_states = key_states.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
|
||||
value_states = value_states.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
|
||||
query_states = query_states.reshape(
|
||||
bsz, q_len, self.q_num_heads, self.qk_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
key_states = key_states.reshape(
|
||||
bsz, q_len, self.k_num_heads, self.qk_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
value_states = value_states.reshape(
|
||||
bsz, q_len, self.v_num_heads, self.v_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
attn_dtype = query_states.dtype
|
||||
|
||||
query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
|
||||
query_states = (
|
||||
_rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
|
||||
)
|
||||
key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
|
||||
|
||||
if past_states is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states_new = key_states
|
||||
value_states_new = value_states
|
||||
key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) # type: ignore
|
||||
past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
|
||||
key_states, value_states = past_states.append_kv(
|
||||
key_states, value_states, self.layer_idx
|
||||
) # type: ignore
|
||||
past_states.update_attention(
|
||||
key_states_new, value_states_new, self.layer_idx
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
position_ids = mx.arange(kv_seq_len, dtype=mx.int64)[None]
|
||||
@ -1082,7 +1183,9 @@ class Attention(nn.Module):
|
||||
)
|
||||
else:
|
||||
if attention_mask.dtype == bool:
|
||||
attention_mask = mx.where(attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
|
||||
attention_mask = mx.where(
|
||||
attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf")
|
||||
)
|
||||
if len(attention_mask.shape) == 2:
|
||||
attention_mask = attention_mask[None, None]
|
||||
assert len(attention_mask.shape) == 4
|
||||
@ -1095,14 +1198,20 @@ class Attention(nn.Module):
|
||||
)
|
||||
# `generate` function creates attention mask that does not consider sliding window
|
||||
m_swa = m_swa[None, None]
|
||||
attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
|
||||
attention_mask = attention_mask[
|
||||
:, :, -query_states.shape[2] :, -key_states.shape[2] :
|
||||
]
|
||||
attention_mask = mx.where(m_swa, attention_mask, float("-inf"))
|
||||
|
||||
# like AttentionMaskConverter._unmask_unattended in huggingface.transfoermers,
|
||||
# we need to attend to all tokens in masked rows for `scaled_dot_product_attention`
|
||||
bool_mask = mx.logical_not(mx.isneginf(attention_mask))
|
||||
valid_tokens = mx.sum(bool_mask, axis=-1).astype(mx.bool_) # (..., q_len) # type: ignore
|
||||
attention_mask = mx.where(valid_tokens[..., None], attention_mask, float(0.0))
|
||||
valid_tokens = mx.sum(bool_mask, axis=-1).astype(
|
||||
mx.bool_
|
||||
) # (..., q_len) # type: ignore
|
||||
attention_mask = mx.where(
|
||||
valid_tokens[..., None], attention_mask, float(0.0)
|
||||
)
|
||||
attn_output = mx.fast.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
@ -1137,7 +1246,9 @@ class Mamba(nn.Module):
|
||||
|
||||
self.intermediate_size = self.num_heads * self.hidden_size_per_head
|
||||
|
||||
self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
|
||||
self.in_proj = nn.Linear(
|
||||
self.hidden_size, 2 * self.intermediate_size, bias=False
|
||||
)
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=self.intermediate_size,
|
||||
out_channels=self.intermediate_size,
|
||||
@ -1235,14 +1346,18 @@ class Mamba(nn.Module):
|
||||
)
|
||||
|
||||
# conv
|
||||
x = x.reshape(bsize, length, -1).transpose(0, 2, 1) # (bsize, intermediate_size, length)
|
||||
x = x.reshape(bsize, length, -1).transpose(
|
||||
0, 2, 1
|
||||
) # (bsize, intermediate_size, length)
|
||||
if bool_mask is not None:
|
||||
x = mx.where(bool_mask[:, None, :], x, 0.0)
|
||||
if is_update:
|
||||
assert conv_state is not None
|
||||
x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
|
||||
else:
|
||||
x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
|
||||
x, conv_state = _causal_conv1d(
|
||||
conv_state, self.conv1d.weight, x, seq_idx=seq_idx
|
||||
)
|
||||
x = x.astype(hidden_states.dtype)
|
||||
x = x.transpose(0, 2, 1) # (bsize, length, intermediate_size)
|
||||
x = x.reshape(bsize, length, -1)
|
||||
@ -1257,9 +1372,18 @@ class Mamba(nn.Module):
|
||||
C = C[:, :, None, :]
|
||||
|
||||
A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,)
|
||||
dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
|
||||
B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
|
||||
C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
|
||||
dt = (
|
||||
_rms_norm(dt, None, self.config.rms_norm_eps)
|
||||
* self.dt_norm_weight[None, None, :]
|
||||
)
|
||||
B = (
|
||||
_rms_norm(B, None, self.config.rms_norm_eps)
|
||||
* self.B_norm_weight[None, None, None, :]
|
||||
)
|
||||
C = (
|
||||
_rms_norm(C, None, self.config.rms_norm_eps)
|
||||
* self.C_norm_weight[None, None, None, :]
|
||||
)
|
||||
|
||||
# (bsize, length, num_heads, 1)
|
||||
dt = self.dt_proj(dt)[..., None]
|
||||
@ -1335,7 +1459,9 @@ class MLP(nn.Module):
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
|
||||
self.gate_up_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size * 2, bias=False
|
||||
)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
@ -1359,10 +1485,18 @@ class PlamoDecoderLayer(nn.Module):
|
||||
"""
|
||||
Notes: The model performance was degraded when setting all offsets to 1.
|
||||
"""
|
||||
self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
|
||||
self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
|
||||
self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
|
||||
self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
|
||||
self.pre_mixer_norm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, offset=1.0
|
||||
)
|
||||
self.post_mixer_norm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5
|
||||
)
|
||||
self.pre_mlp_norm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, offset=1.0
|
||||
)
|
||||
self.post_mlp_norm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -1423,8 +1557,12 @@ class PlamoDecoder(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def __call__(self, x: DecoderInput) -> DecoderOutput:
|
||||
all_hidden_states: Optional[Tuple[mx.array, ...]] = () if x.output_hidden_states else None
|
||||
all_self_attns: Optional[Tuple[mx.array, ...]] = () if x.output_attentions else None
|
||||
all_hidden_states: Optional[Tuple[mx.array, ...]] = (
|
||||
() if x.output_hidden_states else None
|
||||
)
|
||||
all_self_attns: Optional[Tuple[mx.array, ...]] = (
|
||||
() if x.output_attentions else None
|
||||
)
|
||||
hidden_states = x.hidden_states
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
@ -1579,9 +1717,13 @@ class PlamoModel(PlamoPreTrainedModel):
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
assert inputs_embeds is not None
|
||||
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
@ -1600,21 +1742,33 @@ class PlamoModel(PlamoPreTrainedModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
assert input_ids is not None
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
@ -1781,11 +1935,19 @@ class Model(PlamoPreTrainedModel):
|
||||
```"""
|
||||
assert input_ids is not None
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
|
@ -19,7 +19,6 @@ from typing import (
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
@ -44,7 +43,6 @@ from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models import cache
|
||||
from .sample_utils import make_logits_processors, make_sampler
|
||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||
from .tuner.utils import dequantize as dequantize_model
|
||||
from .tuner.utils import load_adapters, nparams
|
||||
@ -721,6 +719,11 @@ def load_model(
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf))
|
||||
if "lm_head.weight" not in weights:
|
||||
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
|
||||
for k in weights.keys():
|
||||
if "conv1d.weight" in k:
|
||||
weights[k] = weights[k].transpose(0, 2, 1)
|
||||
|
||||
model_class, model_args_class = get_model_classes(config=config)
|
||||
|
||||
@ -1050,6 +1053,12 @@ def convert(
|
||||
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
||||
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
if "lm_head.weight" not in weights:
|
||||
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
|
||||
for k in weights.keys():
|
||||
if "conv1d.weight" in k:
|
||||
weights[k] = weights[k].transpose(0, 2, 1)
|
||||
|
||||
dtype = getattr(mx, dtype)
|
||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user