mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 01:46:09 +08:00)

Fix channel-first weights to channel-last for correct use of MLX's conv1d

commit 66dd97ed3d (parent 81917d41d5)
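
The gist of the change, sketched under the assumption that the checkpoint stores PyTorch-style channel-first Conv1d weights; `to_mlx_conv1d_weight` is a hypothetical helper used only for illustration, not part of this commit:

import mlx.core as mx

# PyTorch-style Conv1d weights are channel-first: (out_channels, in_channels / groups, kernel_size).
# MLX's nn.Conv1d expects channel-last weights: (out_channels, kernel_size, in_channels / groups).
def to_mlx_conv1d_weight(channel_first_weight: mx.array) -> mx.array:
    # (C_out, C_in/groups, K) -> (C_out, K, C_in/groups)
    return channel_first_weight.transpose(0, 2, 1)

# With weights stored channel-last, the runtime code below no longer needs
# weight.transpose(0, 2, 1) inside _causal_conv1d, and the depthwise taps in
# _causal_conv1d_update are read as weight[:, :, 0] instead of weight[:, 0, :].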
@@ -285,7 +285,7 @@ class PlamoCache(nn.Module):
         assert layer_idx < len(self.cache)
         layer_cache = self.cache[layer_idx]
         return layer_cache  # type: ignore

     @property
     def state(self):
         return self.cache
@@ -528,6 +528,7 @@ def ssd_update_state(
     dt_softplus: bool,
 ) -> tuple[mx.array, mx.array]:
     assert ssm_state.dtype == mx.float32
+    dtype = x.dtype

     hidden_size_per_head = x.shape[-1]
     d_state = B.shape[-1]
@@ -535,17 +536,16 @@ def ssd_update_state(
     dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head))
     dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head))
     D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
-    assert ssm_state.dtype == mx.float32
     out, ssm_state = selective_state_update_ref(
         ssm_state,
-        x,
-        dt,
-        A,
-        B,
-        C,
-        D,
-        z,
-        dt_bias,
+        x.astype(dtype),
+        dt.astype(dtype),
+        A.astype(mx.float32),
+        B.astype(dtype),
+        C.astype(dtype),
+        D.astype(mx.float32),
+        z.astype(dtype),
+        dt_bias.astype(mx.float32),
         dt_softplus=dt_softplus,
     )
     return out[:, None], ssm_state
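
After this hunk, the sequence-side inputs to `selective_state_update_ref` (`x`, `dt`, `B`, `C`, `z`) follow `x.dtype`, while the state-side tensors (`A`, `D`, `dt_bias`, and `ssm_state` itself) are kept in float32. A minimal sketch of that precision convention only (not the real SSD update; shapes are made up):

import mlx.core as mx

bsize, num_heads, head_dim, d_state = 1, 4, 8, 16  # illustrative shapes only

x = mx.random.normal((bsize, num_heads, head_dim)).astype(mx.bfloat16)
ssm_state = mx.zeros((bsize, num_heads, head_dim, d_state), dtype=mx.float32)
A = -mx.ones((num_heads,))  # decay parameters stay in float32
dt = mx.random.uniform(shape=(bsize, num_heads, head_dim))

dtype = x.dtype                                         # activations follow x
dA = mx.exp(dt.astype(mx.float32) * A[None, :, None])   # float32 recurrence math
ssm_state = ssm_state * dA[..., None]                   # recurrent state stays float32
out = (ssm_state.sum(axis=-1) * x.astype(mx.float32)).astype(dtype)  # cast back at the end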
@@ -601,12 +601,14 @@ def ssd_chunk_scan_combined(
     z: mx.array,
     dt_bias: mx.array,
     dt_softplus: bool,
+    return_final_states: bool,
     seq_idx: mx.array | None,
     ssm_state: mx.array | None,
-) -> tuple[mx.array, mx.array]:
+) -> tuple[mx.array, mx.array] | mx.array:
     if seq_idx is not None:
         assert seq_idx.dtype == mx.int32
         assert ssm_state is None
+        assert not return_final_states
     if ssm_state is not None:
         assert ssm_state.dtype == mx.float32
         assert seq_idx is None
@@ -637,7 +639,10 @@ def ssd_chunk_scan_combined(
         seq_idx=seq_idx,
         ssm_state=ssm_state,
     )
-    return tmp, ssm_state
+    if return_final_states:
+        return tmp, ssm_state
+    else:
+        return tmp


 def _causal_conv1d(
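
With `return_final_states` in the signature, `ssd_chunk_scan_combined` now returns either `(out, ssm_state)` or just `out`; the new call site in `Mamba.__call__` (added later in this diff) dispatches on `past_states is not None` accordingly. A toy stand-in showing only the return convention (`chunk_scan_like` is made up for illustration):

import mlx.core as mx

def chunk_scan_like(x: mx.array, state: mx.array, return_final_states: bool):
    # Mirror of the new convention: only hand back the state when asked for it.
    out = x + state.sum()
    if return_final_states:
        return out, state
    return out

x, state = mx.zeros((2, 3)), mx.ones((4,))
out = chunk_scan_like(x, state, return_final_states=False)        # mx.array only
out, state = chunk_scan_like(x, state, return_final_states=True)  # (out, final state)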
@@ -650,7 +655,7 @@ def _causal_conv1d(
     if seq_idx is not None:
         assert seq_idx.dtype == mx.int32
         assert conv_state is None
-    weight = weight.transpose(0, 2, 1).astype(dtype)
+    weight = weight.astype(dtype)
     x = x.astype(dtype)

     return_final_states = conv_state is not None
@@ -664,7 +669,7 @@ def _causal_conv1d(
     for i in range(length):
         if i != 0 and seq_idx is not None:
             conv_state = mx.where(
-                mx.array(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
+                seq_idx[:, i - 1][:, None, None] != seq_idx[:, i][:, None, None],
                 mx.zeros_like(conv_state),
                 conv_state,
             )
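
Both forms of the boundary test above produce a `(bsize, 1, 1)` boolean mask that zeroes `conv_state` wherever `seq_idx` changes between adjacent positions; the new form simply broadcasts the comparison directly instead of wrapping its result in `mx.array`. A quick equivalence check with a made-up `seq_idx`:

import mlx.core as mx

seq_idx = mx.array([[0, 0, 1, 1]], dtype=mx.int32)  # (bsize=1, length=4), made up
i = 2  # first position of the second sequence

old_mask = mx.array(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None]
new_mask = seq_idx[:, i - 1][:, None, None] != seq_idx[:, i][:, None, None]

assert old_mask.shape == new_mask.shape == (1, 1, 1)
assert old_mask.item() == new_mask.item()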
@@ -738,24 +743,12 @@ def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array)
     x, conv_state = causal_conv1d_update(
         x=xBC,
         conv_state=conv_state,
-        weight=weight[:, 0, :],
+        weight=weight[:, :, 0],
         activation="silu",
     )
     return x, conv_state


-def is_mamba(config: ModelArgs, i: int) -> bool:
-    if not config.mamba_enabled:
-        return False
-    assert config.mamba_step > 1
-    assert i < config.num_hidden_layers
-
-    if config.num_hidden_layers <= (config.mamba_step // 2):
-        # use attention in last layer
-        return i != config.num_hidden_layers - 1
-    return (i % config.mamba_step) != (config.mamba_step // 2)
-
-
 # Based on: https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206
 def causal_conv1d(x, weight, bias=None, activation=None):
     """
@@ -808,6 +801,213 @@ def causal_conv1d(x, weight, bias=None, activation=None):
     return y


+class Mamba(nn.Module):
+    def __init__(self, config: ModelArgs, layer_idx: int) -> None:
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.d_state = config.mamba_d_state
+        self.d_conv = config.mamba_d_conv
+        self.chunk_size = config.mamba_chunk_size
+        self.num_heads = config.mamba_num_heads
+        # TODO add mamba_hidden_size_per_head config (?)
+        self.hidden_size_per_head = config.hidden_size_per_head
+
+        self.intermediate_size = self.num_heads * self.hidden_size_per_head
+
+        self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
+        self.conv1d = nn.Conv1d(
+            in_channels=self.intermediate_size,
+            out_channels=self.intermediate_size,
+            bias=False,  # TODO the original implementation uses bias
+            kernel_size=self.d_conv,
+            groups=self.intermediate_size,
+            padding=0,
+        )
+        self.dt_dim = max(64, self.hidden_size // 16)
+        # Notes:
+        # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
+        # but it may degrade the ability of content-length extrapolation.
+        self.bcdt_proj = nn.Linear(
+            self.intermediate_size,
+            self.dt_dim + 2 * self.d_state,
+            bias=False,
+        )
+        self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)
+
+        self.dt_bias = get_initial_dt_bias(self.num_heads)
+        self.A_log = get_initial_A(self.num_heads)
+        self.D = mx.ones(self.num_heads, dtype=mx.float32)
+
+        # TODO norm weight before gating like Mamba2
+        self.dt_norm_weight = mx.ones(self.dt_dim)
+        self.B_norm_weight = mx.ones(self.d_state)
+        self.C_norm_weight = mx.ones(self.d_state)
+
+        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+    def _no_weight_decay_param_names(self) -> set[str]:
+        return set(["D", "dt_bias", "A_log"])
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        attention_mask: Optional[mx.array] = None,
+        past_states: Optional[PlamoCache] = None,
+    ) -> tuple[mx.array, Optional[PlamoCache]]:
+        bsize, length, _ = hidden_states.shape
+        is_update = length == 1 and past_states is not None
+
+        bool_mask: mx.array | None = None
+        seq_idx: mx.array | None = None
+        if attention_mask is not None:
+            if len(attention_mask.shape) == 2:
+                attention_mask = mx.broadcast_to(
+                    attention_mask[None, None],
+                    (bsize, 1, attention_mask.shape[0], attention_mask.shape[1]),
+                )
+            assert len(attention_mask.shape) == 4
+
+            if past_states is None:
+                # TODO: support seq_idx with cache
+                bool_mask_4d = mx.array(attention_mask == 0, dtype=mx.bool_)  # type: ignore
+                is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
+                seq_idx = mx.cumsum(is_first_token, axis=-1) - 1
+                seq_idx = seq_idx.astype(mx.int32)
+
+            # `generate` function creates attention mask that contains past tokens,
+            # but mamba does not use them
+            attention_mask = attention_mask[:, 0, -length:, -length:]
+            bool_mask = mx.array(mx.diagonal(attention_mask, axis1=-2, axis2=-1) == 0)
+
+        conv_state: mx.array | None
+        ssm_state: mx.array | None
+        if past_states is None:
+            conv_state = None
+            ssm_state = None
+        elif past_states[self.layer_idx] is None:
+            conv_state = mx.zeros(
+                (bsize, self.intermediate_size, self.d_conv - 1),
+                dtype=hidden_states.dtype,
+            )
+            ssm_state = mx.zeros(
+                (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
+                dtype=mx.float32,
+            )
+        else:
+            c = past_states[self.layer_idx]
+            assert isinstance(c, PlamoMambaCache)
+            conv_state = c.conv_state
+            ssm_state = c.ssm_state
+
+        zx = self.in_proj(hidden_states)
+        zx = zx.reshape(bsize, length, self.num_heads, -1)
+        # z: (bsize, length, num_heads, hidden_size_per_head)
+        # x: (bsize, length, num_heads, hidden_size_per_head)
+        z, x = mx.split(
+            zx,
+            [
+                self.hidden_size_per_head,
+            ],
+            axis=-1,
+        )
+
+        # conv
+        x = x.reshape(bsize, length, -1).transpose(0, 2, 1)  # (bsize, intermediate_size, length)
+        if bool_mask is not None:
+            x = mx.where(bool_mask[:, None, :], x, 0.0)
+        if is_update:
+            assert conv_state is not None
+            x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
+        else:
+            x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
+        x = x.astype(hidden_states.dtype)
+        x = x.transpose(0, 2, 1)  # (bsize, length, intermediate_size)
+        x = x.reshape(bsize, length, -1)
+        # x: (bsize, length, num_heads, hidden_size_per_head)
+        # B: (bsize, length, 1, d_state)
+        # C: (bsize, length, 1, d_state)
+        # dt: (bsize, length, dt_dim)
+        BCdt = self.bcdt_proj(x)
+        x = x.reshape(bsize, length, self.num_heads, -1)
+        B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
+        B = B[:, :, None, :]
+        C = C[:, :, None, :]
+
+        A = -mx.exp(self.A_log.astype(mx.float32))  # (num_heads,)
+        dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
+        B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
+        C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
+
+        # (bsize, length, num_heads, 1)
+        dt = self.dt_proj(dt)[..., None]
+
+        # TODO it may not be required
+        B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3]))
+        C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3]))
+
+        if bool_mask is not None:
+            """
+            state will be updates by following:
+            ```
+            dt = softplus(dt)
+            dA = exp(dt * A)
+            state_next = state * dA + dB * x
+            ```
+            To avoid updating state, we set dt to -inf and x to 0
+            because `softplus(-inf) = 0` and `exp(0) = 1`
+            """
+            dt = mx.where(bool_mask[:, :, None, None], dt, float("-inf"))
+            x = mx.where(bool_mask[:, :, None, None], x, 0.0)

+        # ssm
+        if is_update:
+            assert ssm_state is not None
+            out, ssm_state = ssd_update_state(
+                ssm_state,
+                x[:, 0],
+                dt[:, 0].reshape(bsize, -1),
+                A,
+                B[:, 0],
+                C[:, 0],
+                D=self.D,
+                z=z[:, 0],
+                dt_bias=self.dt_bias,
+                dt_softplus=True,
+            )
+        else:
+            tmp = ssd_chunk_scan_combined(
+                x,
+                dt.reshape(bsize, length, -1),
+                A,
+                B,
+                C,
+                self.chunk_size,
+                D=self.D,
+                z=z,
+                dt_bias=self.dt_bias,
+                dt_softplus=True,
+                return_final_states=past_states is not None,
+                seq_idx=seq_idx,
+                ssm_state=ssm_state,
+            )
+            if past_states is not None:
+                out, ssm_state = tmp
+            else:
+                assert isinstance(tmp, mx.array)
+                out = tmp
+
+        y = self.out_proj(out.reshape(bsize, length, -1))
+
+        if past_states is not None:
+            assert ssm_state is not None
+            assert conv_state is not None
+            past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
+
+        return y, past_states
+
+
 def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
     max_len = max(q_len, kv_len)
     mask = mx.tril(
@@ -950,209 +1150,6 @@ class Attention(nn.Module):
         return attn_output, attn_weights, past_states


-class Mamba(nn.Module):
-    def __init__(self, config: ModelArgs, layer_idx: int) -> None:
-        super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-        self.hidden_size = config.hidden_size
-        self.d_state = config.mamba_d_state
-        self.d_conv = config.mamba_d_conv
-        self.chunk_size = config.mamba_chunk_size
-        self.num_heads = config.mamba_num_heads
-        # TODO add mamba_hidden_size_per_head config (?)
-        self.hidden_size_per_head = config.hidden_size_per_head
-
-        self.intermediate_size = self.num_heads * self.hidden_size_per_head
-
-        self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
-        self.conv1d = nn.Conv1d(
-            in_channels=self.intermediate_size,
-            out_channels=self.intermediate_size,
-            bias=False,  # TODO the original implementation uses bias
-            kernel_size=self.d_conv,
-            groups=self.intermediate_size,
-            padding=0,
-        )
-        self.dt_dim = max(64, self.hidden_size // 16)
-        # Notes:
-        # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
-        # but it may degrade the ability of content-length extrapolation.
-        self.bcdt_proj = nn.Linear(
-            self.intermediate_size,
-            self.dt_dim + 2 * self.d_state,
-            bias=False,
-        )
-        self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)
-
-        self.dt_bias = get_initial_dt_bias(self.num_heads)
-        self.A_log = get_initial_A(self.num_heads)
-        self.D = mx.ones(self.num_heads, dtype=mx.float32)
-
-        # TODO norm weight before gating like Mamba2
-        self.dt_norm_weight = mx.ones(self.dt_dim)
-        self.B_norm_weight = mx.ones(self.d_state)
-        self.C_norm_weight = mx.ones(self.d_state)
-
-        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-
-    def _no_weight_decay_param_names(self) -> set[str]:
-        return set(["D", "dt_bias", "A_log"])
-
-    def __call__(
-        self,
-        hidden_states: mx.array,
-        attention_mask: Optional[mx.array] = None,
-        past_states: Optional[PlamoCache] = None,
-    ) -> tuple[mx.array, Optional[PlamoCache]]:
-        bsize, length, _ = hidden_states.shape
-        is_update = length == 1 and past_states is not None
-
-        bool_mask: mx.array | None = None
-        seq_idx: mx.array | None = None
-        if attention_mask is not None:
-            if len(attention_mask.shape) == 2:
-                attention_mask = mx.broadcast_to(
-                    attention_mask[None, None],
-                    (bsize, 1, attention_mask.shape[0], attention_mask.shape[1]),
-                )
-            assert len(attention_mask.shape) == 4
-
-            if past_states is None:
-                # TODO: support seq_idx with cache
-                bool_mask_4d = mx.array(attention_mask == 0, dtype=mx.bool_)  # type: ignore
-                is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
-                seq_idx = mx.cumsum(is_first_token, axis=-1) - 1
-                seq_idx = seq_idx.astype(mx.int32)
-
-            # `generate` function creates attention mask that contains past tokens,
-            # but mamba does not use them
-            attention_mask = attention_mask[:, 0, -length:, -length:]
-            bool_mask = mx.array(mx.diagonal(attention_mask, axis1=-2, axis2=-1) == 0)
-
-        conv_state: mx.array | None
-        ssm_state: mx.array | None
-        if past_states is None:
-            conv_state = None
-            ssm_state = None
-        elif past_states[self.layer_idx] is None:
-            conv_state = mx.zeros(
-                (bsize, self.intermediate_size, self.d_conv - 1),
-                dtype=hidden_states.dtype,
-            )
-            ssm_state = mx.zeros(
-                (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
-                dtype=mx.float32,
-            )
-        else:
-            c = past_states[self.layer_idx]
-            assert isinstance(c, PlamoMambaCache)
-            conv_state = c.conv_state
-            ssm_state = c.ssm_state
-
-        zx = self.in_proj(hidden_states)
-        zx = zx.reshape(bsize, length, self.num_heads, -1).astype(mx.float32)
-        # z: (bsize, length, num_heads, hidden_size_per_head)
-        # x: (bsize, length, num_heads, hidden_size_per_head)
-        z, x = mx.split(
-            zx,
-            [
-                self.hidden_size_per_head,
-            ],
-            axis=-1,
-        )
-
-        # conv
-        x = x.reshape(bsize, length, -1).transpose(0, 2, 1)  # (bsize, intermediate_size, length)
-        if bool_mask is not None:
-            x = mx.where(bool_mask[:, None, :], x, 0.0)
-        if is_update:
-            assert conv_state is not None
-            x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
-        else:
-            x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
-        x = x.astype(mx.float32)
-        x = x.transpose(0, 2, 1)  # (bsize, length, intermediate_size)
-        x = x.reshape(bsize, length, -1)
-        # x: (bsize, length, num_heads, hidden_size_per_head)
-        # B: (bsize, length, 1, d_state)
-        # C: (bsize, length, 1, d_state)
-        # dt: (bsize, length, dt_dim)
-        BCdt = self.bcdt_proj(x).astype(mx.float32)
-        x = x.reshape(bsize, length, self.num_heads, -1)
-        B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
-        B = B[:, :, None, :]
-        C = C[:, :, None, :]
-
-        A = -mx.exp(self.A_log.astype(mx.float32))  # (num_heads,)
-        dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
-        B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
-        C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
-
-        # (bsize, length, num_heads, 1)
-        dt = self.dt_proj(dt)[..., None]
-
-        # TODO it may not be required
-        B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3]))
-        C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3]))
-
-        if bool_mask is not None:
-            """
-            state will be updates by following:
-            ```
-            dt = softplus(dt)
-            dA = exp(dt * A)
-            state_next = state * dA + dB * x
-            ```
-            To avoid updating state, we set dt to -inf and x to 0
-            because `softplus(-inf) = 0` and `exp(0) = 1`
-            """
-            dt = mx.where(bool_mask[:, :, None, None], dt, float("-inf"))
-            x = mx.where(bool_mask[:, :, None, None], x, 0.0)
-
-        # ssm
-        self.D = self.D.astype(mx.float32)
-        self.dt_bias = self.dt_bias.astype(mx.float32)
-        if is_update:
-            assert ssm_state is not None
-            out, ssm_state = ssd_update_state(
-                ssm_state,
-                x[:, 0],
-                dt[:, 0].reshape(bsize, -1),
-                A,
-                B[:, 0],
-                C[:, 0],
-                D=self.D,
-                z=z[:, 0],
-                dt_bias=self.dt_bias,
-                dt_softplus=True,
-            )
-        else:
-            out, ssm_state = ssd_chunk_scan_combined(
-                x,
-                dt.reshape(bsize, length, -1),
-                A,
-                B,
-                C,
-                self.chunk_size,
-                D=self.D,
-                z=z,
-                dt_bias=self.dt_bias,
-                dt_softplus=True,
-                seq_idx=seq_idx,
-                ssm_state=ssm_state,
-            )
-
-        y = self.out_proj(out.reshape(bsize, length, -1))
-
-        if past_states is not None:
-            assert ssm_state is not None
-            assert conv_state is not None
-            past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
-
-        return y, past_states
-
-
 class MLP(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
@@ -1236,6 +1233,18 @@ class PlamoDecoderLayer(nn.Module):
         return outputs  # type: ignore


+def is_mamba(config: ModelArgs, i: int) -> bool:
+    if not config.mamba_enabled:
+        return False
+    assert config.mamba_step > 1
+    assert i < config.num_hidden_layers
+
+    if config.num_hidden_layers <= (config.mamba_step // 2):
+        # use attention in last layer
+        return i != config.num_hidden_layers - 1
+    return (i % config.mamba_step) != (config.mamba_step // 2)
+
+
 class PlamoDecoder(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
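
For reference, the schedule this helper produces: with a hypothetical config (`TinyConfig` below is a made-up stand-in for `ModelArgs`) and `mamba_step = 2`, attention lands on every layer with `i % 2 == 1` and Mamba everywhere else, while a model no deeper than `mamba_step // 2` keeps attention only in its last layer.

from dataclasses import dataclass

@dataclass
class TinyConfig:  # hypothetical stand-in for ModelArgs, illustration only
    mamba_enabled: bool = True
    mamba_step: int = 2
    num_hidden_layers: int = 8

def is_mamba(config: TinyConfig, i: int) -> bool:  # copied from the hunk above
    if not config.mamba_enabled:
        return False
    assert config.mamba_step > 1
    assert i < config.num_hidden_layers
    if config.num_hidden_layers <= (config.mamba_step // 2):
        # use attention in last layer
        return i != config.num_hidden_layers - 1
    return (i % config.mamba_step) != (config.mamba_step // 2)

print([is_mamba(TinyConfig(), i) for i in range(8)])
# -> [True, False, True, False, True, False, True, False]  (Mamba on even layers, attention on odd)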
@@ -1545,6 +1554,7 @@ class PlamoModel(PlamoPreTrainedModel):
                 self.gradient_checkpointing,
             )
         )
+
         assert isinstance(out, DecoderOutput)
         hidden_states = out.hidden_states
         all_hidden_states = out.all_hidden_states