mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:54:39 +08:00

Use mlx's BaseModelArgs

parent 9a6e6541de
commit 40c7ce8048
@@ -5,9 +5,8 @@ from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
-from transformers import PretrainedConfig

-from mlx_lm.models.base import create_attention_mask
+from .base import BaseModelArgs, create_attention_mask


 def _is_first_token(mask: mx.array) -> mx.array:
@@ -16,8 +15,12 @@ def _is_first_token(mask: mx.array) -> mx.array:
     mask = mask[:, :, :, -q_len:]
     cont = q_len != kv_len
     v = False if cont else True
-    out = mx.logical_not(mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_))
-    out = mx.concatenate([mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1)
+    out = mx.logical_not(
+        mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_)
+    )
+    out = mx.concatenate(
+        [mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1
+    )
     return out

@@ -34,13 +37,17 @@ def _swiglu(h: mx.array) -> mx.array:


 class RotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: int = 10000) -> None:
+    def __init__(
+        self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
+    ) -> None:
         super().__init__()

         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        inv_freq = 1.0 / (self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim))
+        inv_freq = 1.0 / (
+            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+        )
         self._inv_freq = inv_freq

         # Build here to make `torch.jit.trace` work.
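
Note: the inv_freq expression above is the standard rotary-embedding frequency
schedule, theta_i = base ** (-2i / dim). A plain-Python equivalent of the
mx.arange expression, for illustration only:

    def rope_inv_freq(dim: int, base: float = 10000.0) -> list:
        # Same values as 1.0 / (base ** (mx.arange(0, dim, 2) / dim)):
        # low dimensions rotate fast (theta_0 == 1.0), high ones slowly.
        return [1.0 / base ** (2 * i / dim) for i in range(dim // 2)]

    assert rope_inv_freq(8)[0] == 1.0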
@@ -74,7 +81,9 @@ def _rotate_half(x: mx.array) -> mx.array:
     return mx.concatenate([-x2, x1], axis=-1)


-def _rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array) -> mx.array:
+def _rotary_pos_emb(
+    x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array
+) -> mx.array:
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -90,8 +99,9 @@ class LinearType(str, enum.Enum):
     Fp8Retain = "fp8-retain"


-class ModelArgs(PretrainedConfig):  # type: ignore
-    model_type: str = "plamo"
+@dataclass
+class ModelArgs(BaseModelArgs):  # type: ignore
+    model_type: str = "plamo2"

     def __init__(
         self,
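
Note: BaseModelArgs is the small dataclass helper from mlx_lm.models.base. A
minimal sketch of the behavior relied on here (not the verbatim mlx-lm
source): its from_dict classmethod builds the dataclass from a raw config
dict, silently dropping keys the subclass does not declare.

    import inspect
    from dataclasses import dataclass

    @dataclass
    class BaseModelArgs:
        @classmethod
        def from_dict(cls, params: dict):
            # Keep only keys that match declared dataclass fields.
            return cls(
                **{
                    k: v
                    for k, v in params.items()
                    if k in inspect.signature(cls).parameters
                }
            )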
@@ -145,7 +155,9 @@ class ModelArgs(PretrainedConfig):  # type: ignore
         self.hidden_size_per_head = hidden_size_per_head
         self.num_key_value_heads = num_key_value_heads
         self.attention_window_size = attention_window_size
-        self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
+        self.full_attention_idx = (
+            full_attention_idx if full_attention_idx is not None else []
+        )

         self.mamba_d_state = mamba_d_state
         self.mamba_d_conv = mamba_d_conv
@@ -221,9 +233,13 @@ class PlamoCache(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
-        self.cache: List[Optional[PlamoLayerCache]] = [None for _ in range(config.num_hidden_layers)]
+        self.cache: List[Optional[PlamoLayerCache]] = [
+            None for _ in range(config.num_hidden_layers)
+        ]

-    def append_kv(self, key: mx.array, value: mx.array, layer_idx: int) -> tuple[mx.array, mx.array]:
+    def append_kv(
+        self, key: mx.array, value: mx.array, layer_idx: int
+    ) -> tuple[mx.array, mx.array]:
         c = self.cache[layer_idx]
         if c is None:
             return key, value
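
Note: append_kv concatenates along axis 2, the sequence axis of the
(batch, heads, seq, head_dim) layout, so cached and new positions line up. A
quick shape check:

    import mlx.core as mx

    cached = mx.zeros((1, 2, 5, 4))   # 5 cached key positions
    new = mx.zeros((1, 2, 1, 4))      # 1 freshly computed position
    merged = mx.concatenate([cached, new], axis=2)
    assert merged.shape[2] == 6       # the sequence axis grew by one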
@@ -239,9 +255,13 @@ class PlamoCache(nn.Module):
         _validate(c.key, key)
         _validate(c.value, value)
         assert key.shape[2] == value.shape[2]
-        return mx.concatenate([c.key, key], axis=2), mx.concatenate([c.value, value], axis=2)
+        return mx.concatenate([c.key, key], axis=2), mx.concatenate(
+            [c.value, value], axis=2
+        )

-    def update_attention(self, key_states: mx.array, value_states: mx.array, layer_idx: int) -> PlamoAttentionCache:
+    def update_attention(
+        self, key_states: mx.array, value_states: mx.array, layer_idx: int
+    ) -> PlamoAttentionCache:
         full_attn = layer_idx in self.config.full_attention_idx
         window_size = self.config.attention_window_size

@@ -265,7 +285,9 @@ class PlamoCache(nn.Module):
             c.value = v[:, :, -window_size:, :]
         return self.cache[layer_idx]  # type: ignore

-    def update_mamba(self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int) -> PlamoMambaCache:
+    def update_mamba(
+        self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int
+    ) -> PlamoMambaCache:
         if self.cache[layer_idx] is None:
             self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state)
         else:
@@ -305,7 +327,9 @@ class PlamoCache(nn.Module):
     def get_max_length(self) -> int | None:
         return None

-    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+    def get_usable_length(
+        self, new_seq_length: int, layer_idx: Optional[int] = 0
+    ) -> int:
         """Given the sequence length of the new inputs, returns the usable length of the cache."""
         # Cache without size limit -> all cache is usable
         # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
@@ -361,7 +385,9 @@ class DecoderOutput(NamedTuple):


 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0) -> mx.array:
+def _make_causal_mask(
+    input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0
+) -> mx.array:
     """
     Make causal mask used for bi-directional self-attention.
     """
@@ -372,26 +398,36 @@ def _make_causal_mask(input_ids_shape: Tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0) -> mx.array:
     mask = mask.astype(dtype)

     if past_key_values_length > 0:
-        mask = mx.concatenate([mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1)
-    return mx.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
+        mask = mx.concatenate(
+            [mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1
+        )
+    return mx.broadcast_to(
+        mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)
+    )


 # Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None) -> mx.array:
+def _expand_mask(
+    mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None
+) -> mx.array:
     """
     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
     """
     bsz, src_len = mask.shape
     tgt_len = tgt_len if tgt_len is not None else src_len

-    expanded_mask = mx.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)).astype(dtype)
+    expanded_mask = mx.broadcast_to(
+        mask[:, None, None, :], (bsz, 1, tgt_len, src_len)
+    ).astype(dtype)

     inverted_mask = 1.0 - expanded_mask

     return mx.where(inverted_mask.astype(mx.bool_), float("-inf"), inverted_mask)  # type: ignore


-def _rms_norm(hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0) -> mx.array:
+def _rms_norm(
+    hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0
+) -> mx.array:
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.astype(mx.float32)
     variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
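
Note: _make_causal_mask builds the usual additive causal mask, prepending
zero columns for cached positions so that new tokens attend to the whole
past. A small worked example of the same idea (illustrative, not the
function body):

    import mlx.core as mx

    tgt_len, past = 3, 2
    idx = mx.arange(tgt_len)
    block = mx.where(idx[:, None] >= idx[None, :], 0.0, float("-inf"))
    mask = mx.concatenate([mx.zeros((tgt_len, past)), block], axis=-1)
    # Row i is 0 for the 2 cached keys and in-block keys j <= i, -inf after.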
@@ -415,13 +451,18 @@ class RMSNorm(nn.Module):
         self.offset = offset

     def __call__(self, hidden_states: mx.array) -> mx.array:
-        return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
+        return _rms_norm(
+            hidden_states, self.weight, self.variance_epsilon, offset=self.offset
+        )


 def get_initial_dt_bias(num_heads: int) -> mx.array:
     dt_min = 0.001
     dt_max = 0.1
-    dt = mx.exp(mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
+    dt = mx.exp(
+        mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min))
+        + math.log(dt_min)
+    )
     dt = mx.clip(dt, a_min=1e-4, a_max=None)
     inv_dt = dt + mx.log(-mx.expm1(-dt))
     return inv_dt
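
Note: get_initial_dt_bias samples dt log-uniformly in [dt_min, dt_max] and
stores the softplus inverse, matching the usual Mamba convention of applying
softplus to the bias downstream: dt + log(-expm1(-dt)) == log(expm1(dt)). A
quick numeric check of the identity:

    import math

    dt = 0.01
    inv_dt = dt + math.log(-math.expm1(-dt))               # inverse softplus
    assert abs(math.log1p(math.exp(inv_dt)) - dt) < 1e-9   # softplus(inv_dt) == dt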
@@ -450,9 +491,15 @@ def ssd_update_state(

     hidden_size_per_head = x.shape[-1]
     d_state = B.shape[-1]
-    A = mx.broadcast_to(A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)).astype(mx.float32)
-    dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head))
-    dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head))
+    A = mx.broadcast_to(
+        A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)
+    ).astype(mx.float32)
+    dt = mx.broadcast_to(
+        dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)
+    )
+    dt_bias = mx.broadcast_to(
+        dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)
+    )
     D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
     assert ssm_state.dtype == mx.float32
     out, ssm_state = f(
@@ -592,7 +639,9 @@ def _causal_conv1d(
                 mx.zeros_like(conv_state),
                 conv_state,
             )
-            out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
+            out[:, :, i : i + 1], conv_state = _causal_conv1d_update(
+                conv_state, weight, x[:, :, i : i + 1]
+            )
     x = out
     if return_final_states:
         return x, conv_state
@@ -600,7 +649,9 @@ def _causal_conv1d(
     return x, None


-def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array) -> tuple[mx.array, mx.array]:
+def _causal_conv1d_update(
+    conv_state: mx.array, weight: mx.array, xBC: mx.array
+) -> tuple[mx.array, mx.array]:
     dtype = conv_state.dtype
     xBC = xBC.astype(dtype)
     weight = weight.astype(dtype)
@@ -729,12 +780,18 @@ def mamba_chunk_scan_combined(
     dt = dt + dt_bias  # incorporate bias to dt if provided

     # Prepare initial state
-    state_dim = A.shape[-1]  # assume A is of shape (nheads, state_dim) for diagonal A or (nheads, state_dim, state_dim)
+    state_dim = A.shape[
+        -1
+    ]  # assume A is of shape (nheads, state_dim) for diagonal A or (nheads, state_dim, state_dim)
     if initial_states is None:
         # Initialize state to zero for each sequence in batch and each head
         state = mx.zeros((batch, nheads, state_dim), dtype=A.dtype)
     else:
-        state = mx.array(initial_states) if not isinstance(initial_states, mx.array) else initial_states
+        state = (
+            mx.array(initial_states)
+            if not isinstance(initial_states, mx.array)
+            else initial_states
+        )

     # Precompute exponent of A*dt for state update per step (assuming A is diagonal or elementwise applicable)
     # If A is diagonal values per head (shape (nheads, state_dim)), we compute elementwise exponentials.
@@ -743,10 +800,14 @@ def mamba_chunk_scan_combined(
         # A is given as diagonal values per state
        exp_dA = mx.exp(A * dt)  # shape (nheads, state_dim) or (state_dim,)
         if exp_dA.ndim == 2:
-            exp_dA = exp_dA.reshape((1, nheads, state_dim))  # shape (1, nheads, state_dim) for broadcasting
+            exp_dA = exp_dA.reshape(
+                (1, nheads, state_dim)
+            )  # shape (1, nheads, state_dim) for broadcasting
     else:
         # If A is a full matrix per head, use matrix exponential (if available in MLX)
-        exp_dA = mx.exp(A * dt)  # assuming MX can exponentiate matrix elementwise or use specialized routine
+        exp_dA = mx.exp(
+            A * dt
+        )  # assuming MX can exponentiate matrix elementwise or use specialized routine

     # Output buffer
     out_list = []  # will collect output chunks
@@ -783,7 +844,9 @@ def mamba_chunk_scan_combined(
         inc = inc * dt

         # State update: s_{t+1} = exp(A*dt) * s_t + inc.
-        state = state * exp_dA + inc  # elementwise multiply if exp_dA broadcast shape (1, nheads, state_dim)
+        state = (
+            state * exp_dA + inc
+        )  # elementwise multiply if exp_dA broadcast shape (1, nheads, state_dim)

         # Compute output for this time step: y_t = C * state_t + (D * x_t if direct term exists)
         if C.ndim == 3:
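
Note: per time step, the scan above reduces to the diagonal state-space
recurrence s_t = exp(A * dt) * s_{t-1} + dt * B * x_t with readout
y_t = C * s_t. A toy scalar version showing the geometric decay
(illustrative only):

    import math

    A, B, C, dt = -1.0, 1.0, 1.0, 0.1
    state, ys = 0.0, []
    for x_t in [1.0, 0.0, 0.0, 0.0]:   # impulse input
        state = math.exp(A * dt) * state + dt * B * x_t
        ys.append(C * state)
    # After the impulse, ys decays by a factor of exp(A * dt) per step.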
@@ -792,7 +855,9 @@ def mamba_chunk_scan_combined(
         else:
             # C shape (nheads, state_dim) or (state_dim,), output one value per head
             # Multiply and sum over state_dim
-            y_t = mx.sum(state * C.reshape((1, nheads, state_dim)), axis=-1, keepdims=True)  # (batch, nheads, 1)
+            y_t = mx.sum(
+                state * C.reshape((1, nheads, state_dim)), axis=-1, keepdims=True
+            )  # (batch, nheads, 1)
         if D is not None:
             # Add direct input contribution: D * x(t)
             if D.ndim == 2:
@@ -804,7 +869,9 @@ def mamba_chunk_scan_combined(
                 )
             else:
                 # D shape (nheads,) or scalar
-                y_t += D.reshape((1, nheads, -1)) * (x_t if x_t.ndim == 3 else x_t[..., None])
+                y_t += D.reshape((1, nheads, -1)) * (
+                    x_t if x_t.ndim == 3 else x_t[..., None]
+                )
         # Apply gating activation if provided (e.g., elementwise multiply by a sigmoidal function of z)
         if z is not None:
             # Example: if z is meant to gate outputs via a sigmoid activation (silu), as in some Mamba variants
@@ -826,7 +893,9 @@ def mamba_chunk_scan_combined(
         out = y.reshape((batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1)))
     else:
         # If out_list was used and concatenated via Python, we might get a NumPy array; ensure it's MLX:
-        out = mx.array(y).reshape((batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1)))
+        out = mx.array(y).reshape(
+            (batch, -1, nheads * (y.shape[-1] if y.ndim > 2 else 1))
+        )

     if return_final_states:
         # Return the final state as well (state holds final state after last chunk)
@@ -850,16 +919,24 @@ class PlamoPreTrainedModel(nn.Module):  # type: ignore
     def _init_weights(self, module: nn.Module) -> None:
         std = 0.02
         if isinstance(module, nn.Linear):
-            module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape)
+            module.weight = mx.random.normal(
+                loc=0.0, scale=std, shape=module.weight.shape
+            )
             if module.bias is not None:
                 module.bias = mx.zeros_like(module.bias)
         elif isinstance(module, nn.Embedding):
-            module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape)
+            module.weight = mx.random.normal(
+                loc=0.0, scale=std, shape=module.weight.shape
+            )
             if module.padding_idx is not None:
-                module.weight[module.padding_idx] = mx.zeros_like(module.weight[module.padding_idx])
+                module.weight[module.padding_idx] = mx.zeros_like(
+                    module.weight[module.padding_idx]
+                )


-def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+def causal_conv1d_update(
+    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
+):
     """
     x: (batch, dim) or (batch, dim, seqlen)
     conv_state: (batch, dim, state_len), where state_len >= width - 1
@@ -884,13 +961,21 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
     assert conv_state.shape == (batch, dim, state_len)
     assert weight.shape == (dim, width)
     if cache_seqlens is None:
-        x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)  # (batch, dim, state_len + seqlen)
+        x_new = mx.concatenate([conv_state, x], axis=-1).astype(
+            weight.dtype
+        )  # (batch, dim, state_len + seqlen)
         conv_state = x_new[:, :, -state_len:]
     else:
-        width_idx = mx.expand_dims(mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0) + cache_seqlens.unsqueeze(1)
+        width_idx = mx.expand_dims(
+            mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0
+        ) + cache_seqlens.unsqueeze(1)
         width_idx = mx.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
-        x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1).astype(weight.dtype)
-        copy_idx = mx.expand_dims(mx.arange(seqlen, dtype=mx.int64), axis=0) + cache_seqlens.unsqueeze(1)
+        x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1).astype(
+            weight.dtype
+        )
+        copy_idx = mx.expand_dims(
+            mx.arange(seqlen, dtype=mx.int64), axis=0
+        ) + cache_seqlens.unsqueeze(1)
         copy_idx = mx.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
         conv_state.scatter_(2, copy_idx, x)
     assert bias is None
@@ -900,9 +985,7 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
         mx.expand_dims(weight, axis=2),
         padding=0,
         groups=dim,
-    ).transpose(
-        0, 2, 1
-    )[:, :, -seqlen:]
+    ).transpose(0, 2, 1)[:, :, -seqlen:]
     if unsqueeze:
         out = out.squeeze(-1)
     return (out if activation is None else nn.silu(out)).astype(dtype_in), conv_state
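
Note: the update path keeps a rolling buffer of the last width - 1 inputs
per channel and runs a depthwise (groups=dim) convolution over
[state, x_new]. A one-channel pure-Python sketch of that bookkeeping:

    def conv_step(state, w, x_new):
        window = state + [x_new]          # length == width
        y = sum(wi * xi for wi, xi in zip(w, window))
        return y, window[1:]              # output and the shifted state

    state = [0.0, 0.0, 0.0]               # width - 1 zeros at the start
    y, state = conv_step(state, [0.25, 0.25, 0.25, 0.25], 1.0)
    assert y == 0.25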
@@ -969,7 +1052,9 @@ def selective_state_update_ref(
         mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)),
         (batch, nheads, dstate),
     )  # (batch, nheads, dstate)
-    dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2)  # (batch, nheads, dim, dstate)
+    dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(
+        B, axis=-2
+    )  # (batch, nheads, dim, dstate)
     state = state * dA + dB * mx.expand_dims(x, axis=-1)  # (batch, dim, dstate)
     out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
     if D is not None:
@@ -1013,12 +1098,16 @@ class Attention(nn.Module):
             self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
             bias=False,
         )
-        self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(
+            self.q_num_heads * self.v_dim, self.hidden_size, bias=False
+        )

         self.q_weight = mx.ones((self.q_num_heads, self.qk_dim))
         self.k_weight = mx.ones((self.k_num_heads, self.qk_dim))

-        self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size)
+        self.rotary_emb = RotaryEmbedding(
+            self.qk_dim, max_position_embeddings=self.config.attention_window_size
+        )

     def __call__(
         self,
@@ -1033,21 +1122,33 @@ class Attention(nn.Module):
         query_states, key_states, value_states = mx.split(
             qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1
         )
-        query_states = query_states.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
-        key_states = key_states.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
-        value_states = value_states.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
+        query_states = query_states.reshape(
+            bsz, q_len, self.q_num_heads, self.qk_dim
+        ).transpose(0, 2, 1, 3)
+        key_states = key_states.reshape(
+            bsz, q_len, self.k_num_heads, self.qk_dim
+        ).transpose(0, 2, 1, 3)
+        value_states = value_states.reshape(
+            bsz, q_len, self.v_num_heads, self.v_dim
+        ).transpose(0, 2, 1, 3)

         attn_dtype = query_states.dtype

-        query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
+        query_states = (
+            _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
+        )
         key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]

         if past_states is not None:
             # reuse k, v, self_attention
             key_states_new = key_states
             value_states_new = value_states
-            key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx)  # type: ignore
-            past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
+            key_states, value_states = past_states.append_kv(
+                key_states, value_states, self.layer_idx
+            )  # type: ignore
+            past_states.update_attention(
+                key_states_new, value_states_new, self.layer_idx
+            )

         kv_seq_len = key_states.shape[-2]
         position_ids = mx.arange(kv_seq_len, dtype=mx.int64)[None]
@@ -1082,7 +1183,9 @@ class Attention(nn.Module):
             )
         else:
             if attention_mask.dtype == bool:
-                attention_mask = mx.where(attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
+                attention_mask = mx.where(
+                    attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf")
+                )
             if len(attention_mask.shape) == 2:
                 attention_mask = attention_mask[None, None]
             assert len(attention_mask.shape) == 4
@@ -1095,14 +1198,20 @@ class Attention(nn.Module):
             )
             # `generate` function creates attention mask that does not consider sliding window
             m_swa = m_swa[None, None]
-            attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
+            attention_mask = attention_mask[
+                :, :, -query_states.shape[2] :, -key_states.shape[2] :
+            ]
             attention_mask = mx.where(m_swa, attention_mask, float("-inf"))

         # like AttentionMaskConverter._unmask_unattended in huggingface.transformers,
         # we need to attend to all tokens in masked rows for `scaled_dot_product_attention`
         bool_mask = mx.logical_not(mx.isneginf(attention_mask))
-        valid_tokens = mx.sum(bool_mask, axis=-1).astype(mx.bool_)  # (..., q_len)  # type: ignore
-        attention_mask = mx.where(valid_tokens[..., None], attention_mask, float(0.0))
+        valid_tokens = mx.sum(bool_mask, axis=-1).astype(
+            mx.bool_
+        )  # (..., q_len)  # type: ignore
+        attention_mask = mx.where(
+            valid_tokens[..., None], attention_mask, float(0.0)
+        )
         attn_output = mx.fast.scaled_dot_product_attention(
             query_states,
             key_states,
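
Note: a row whose entries are all -inf would softmax to NaN inside
scaled_dot_product_attention, which is why fully masked rows are rewritten
to zeros (attend uniformly; the output of such rows is discarded anyway).
A minimal illustration of the guard:

    import mlx.core as mx

    row = mx.array([float("-inf"), float("-inf")])  # fully masked row
    valid = mx.sum(mx.logical_not(mx.isneginf(row)), axis=-1).astype(mx.bool_)
    safe = mx.where(valid, row, 0.0)                # -> [0.0, 0.0]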
@@ -1137,7 +1246,9 @@ class Mamba(nn.Module):

         self.intermediate_size = self.num_heads * self.hidden_size_per_head

-        self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
+        self.in_proj = nn.Linear(
+            self.hidden_size, 2 * self.intermediate_size, bias=False
+        )
         self.conv1d = nn.Conv1d(
             in_channels=self.intermediate_size,
             out_channels=self.intermediate_size,
@@ -1235,14 +1346,18 @@ class Mamba(nn.Module):
         )

         # conv
-        x = x.reshape(bsize, length, -1).transpose(0, 2, 1)  # (bsize, intermediate_size, length)
+        x = x.reshape(bsize, length, -1).transpose(
+            0, 2, 1
+        )  # (bsize, intermediate_size, length)
         if bool_mask is not None:
             x = mx.where(bool_mask[:, None, :], x, 0.0)
         if is_update:
             assert conv_state is not None
             x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
         else:
-            x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
+            x, conv_state = _causal_conv1d(
+                conv_state, self.conv1d.weight, x, seq_idx=seq_idx
+            )
         x = x.astype(hidden_states.dtype)
         x = x.transpose(0, 2, 1)  # (bsize, length, intermediate_size)
         x = x.reshape(bsize, length, -1)
@@ -1257,9 +1372,18 @@ class Mamba(nn.Module):
         C = C[:, :, None, :]

         A = -mx.exp(self.A_log.astype(mx.float32))  # (num_heads,)
-        dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
-        B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
-        C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
+        dt = (
+            _rms_norm(dt, None, self.config.rms_norm_eps)
+            * self.dt_norm_weight[None, None, :]
+        )
+        B = (
+            _rms_norm(B, None, self.config.rms_norm_eps)
+            * self.B_norm_weight[None, None, None, :]
+        )
+        C = (
+            _rms_norm(C, None, self.config.rms_norm_eps)
+            * self.C_norm_weight[None, None, None, :]
+        )

         # (bsize, length, num_heads, 1)
         dt = self.dt_proj(dt)[..., None]
@@ -1335,7 +1459,9 @@ class MLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
+        self.gate_up_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size * 2, bias=False
+        )
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

     def __call__(self, x: mx.array) -> mx.array:
@@ -1359,10 +1485,18 @@ class PlamoDecoderLayer(nn.Module):
         """
         Notes: The model performance was degraded when setting all offsets to 1.
         """
-        self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
-        self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
-        self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
-        self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
+        self.pre_mixer_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
+        )
+        self.post_mixer_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5
+        )
+        self.pre_mlp_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
+        )
+        self.post_mlp_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)
+        )

     def __call__(
         self,
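
Note: the varying offsets make sense under the assumption (from PLaMo's
published modeling code) that _rms_norm applies the learned scale as
(offset + weight): with zero-initialized weights, offset=1.0 is an identity
scale, while 1/5 and 1/(5**1.5) damp the post-mixer and post-MLP residual
branches at initialization. A sketch under that assumption:

    import mlx.core as mx

    def rms_norm_offset(x, weight, eps=1e-6, offset=1.0):
        variance = mx.power(x.astype(mx.float32), 2).mean(-1, keepdims=True)
        return x * mx.rsqrt(variance + eps) * (offset + weight)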
@@ -1423,8 +1557,12 @@ class PlamoDecoder(nn.Module):
         self.gradient_checkpointing = False

     def __call__(self, x: DecoderInput) -> DecoderOutput:
-        all_hidden_states: Optional[Tuple[mx.array, ...]] = () if x.output_hidden_states else None
-        all_self_attns: Optional[Tuple[mx.array, ...]] = () if x.output_attentions else None
+        all_hidden_states: Optional[Tuple[mx.array, ...]] = (
+            () if x.output_hidden_states else None
+        )
+        all_self_attns: Optional[Tuple[mx.array, ...]] = (
+            () if x.output_attentions else None
+        )
         hidden_states = x.hidden_states

         for decoder_layer in self.layers:
@@ -1579,9 +1717,13 @@ class PlamoModel(PlamoPreTrainedModel):
         else:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
             assert inputs_embeds is not None
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(
+                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )
             combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+                expanded_attn_mask
+                if combined_attention_mask is None
+                else expanded_attn_mask + combined_attention_mask
             )

         return combined_attention_mask
@@ -1600,21 +1742,33 @@ class PlamoModel(PlamoPreTrainedModel):
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         assert input_ids is not None
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache

-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+            )
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape
         else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+            )

         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -1781,11 +1935,19 @@ class Model(PlamoPreTrainedModel):
         ```"""
         assert input_ids is not None

-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
@@ -19,7 +19,6 @@ from typing import (
     Dict,
     Generator,
     List,
-    NamedTuple,
     Optional,
     Tuple,
     Type,
@@ -44,7 +43,6 @@ from transformers import PreTrainedTokenizer

 # Local imports
 from .models import cache
-from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
 from .tuner.utils import load_adapters, nparams
@@ -721,6 +719,11 @@ def load_model(
     weights = {}
     for wf in weight_files:
         weights.update(mx.load(wf))
+    if "lm_head.weight" not in weights:
+        weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+    for k in weights.keys():
+        if "conv1d.weight" in k:
+            weights[k] = weights[k].transpose(0, 2, 1)

     model_class, model_args_class = get_model_classes(config=config)

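
Note: the conv1d transpose is needed because PyTorch Conv1d checkpoints
store weights as (out_channels, in_channels // groups, kernel_size), while
MLX's nn.Conv1d expects (out_channels, kernel_size, in_channels // groups).
Illustratively, for a depthwise convolution (groups == channels):

    w_torch = (1024, 1, 4)   # (out, in // groups, kernel) in the checkpoint
    w_mlx = (1024, 4, 1)     # after weights[k].transpose(0, 2, 1)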
@@ -1050,6 +1053,12 @@ def convert(
     model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

     weights = dict(tree_flatten(model.parameters()))
+    if "lm_head.weight" not in weights:
+        weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+    for k in weights.keys():
+        if "conv1d.weight" in k:
+            weights[k] = weights[k].transpose(0, 2, 1)
+
     dtype = getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}

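
Note: with this patch, the lm_head tie-in and conv1d transpose run during
both load and convert. A hypothetical invocation (model name and output
path are placeholders):

    from mlx_lm.utils import convert

    convert("pfnet/plamo-2-1b", mlx_path="plamo-2-1b-mlx")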