Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-28 16:16:27 +08:00)
Fix channel-first weights to channel-last for correct use of MLX's conv1d
parent 81917d41d5
commit 66dd97ed3d
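As context for the diff below: MLX's `nn.Conv1d` stores its weight channel-last, `(out_channels, kernel_size, in_channels // groups)`, while the original PyTorch checkpoint stores it channel-first, `(out_channels, in_channels // groups, kernel_size)`. A minimal sketch of the kind of one-time conversion this change implies (the helper name `to_mlx_conv1d_weight` is illustrative, not from the repository):

```python
import mlx.core as mx


def to_mlx_conv1d_weight(w: mx.array) -> mx.array:
    # Channel-first PyTorch layout (out_channels, in_channels // groups, kernel_size)
    # -> channel-last layout expected by mlx.nn.Conv1d
    #    (out_channels, kernel_size, in_channels // groups).
    return w.transpose(0, 2, 1)
```

With the weights presumably converted once at load time, `_causal_conv1d` no longer needs to transpose on every call, and the depthwise update path reads the per-channel kernels as `weight[:, :, 0]` instead of `weight[:, 0, :]`.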
@@ -285,7 +285,7 @@ class PlamoCache(nn.Module):
        assert layer_idx < len(self.cache)
        layer_cache = self.cache[layer_idx]
        return layer_cache # type: ignore


    @property
    def state(self):
        return self.cache
@@ -528,6 +528,7 @@ def ssd_update_state(
    dt_softplus: bool,
) -> tuple[mx.array, mx.array]:
    assert ssm_state.dtype == mx.float32
    dtype = x.dtype

    hidden_size_per_head = x.shape[-1]
    d_state = B.shape[-1]
@@ -535,17 +536,16 @@ def ssd_update_state(
    dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head))
    dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head))
    D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
    assert ssm_state.dtype == mx.float32
    out, ssm_state = selective_state_update_ref(
        ssm_state,
        x,
        dt,
        A,
        B,
        C,
        D,
        z,
        dt_bias,
        x.astype(dtype),
        dt.astype(dtype),
        A.astype(mx.float32),
        B.astype(dtype),
        C.astype(dtype),
        D.astype(mx.float32),
        z.astype(dtype),
        dt_bias.astype(mx.float32),
        dt_softplus=dt_softplus,
    )
    return out[:, None], ssm_state
@@ -601,12 +601,14 @@ def ssd_chunk_scan_combined(
    z: mx.array,
    dt_bias: mx.array,
    dt_softplus: bool,
    return_final_states: bool,
    seq_idx: mx.array | None,
    ssm_state: mx.array | None,
) -> tuple[mx.array, mx.array]:
) -> tuple[mx.array, mx.array] | mx.array:
    if seq_idx is not None:
        assert seq_idx.dtype == mx.int32
        assert ssm_state is None
        assert not return_final_states
    if ssm_state is not None:
        assert ssm_state.dtype == mx.float32
        assert seq_idx is None
@@ -637,7 +639,10 @@ def ssd_chunk_scan_combined(
        seq_idx=seq_idx,
        ssm_state=ssm_state,
    )
    return tmp, ssm_state
    if return_final_states:
        return tmp, ssm_state
    else:
        return tmp


def _causal_conv1d(
@@ -650,7 +655,7 @@ def _causal_conv1d(
    if seq_idx is not None:
        assert seq_idx.dtype == mx.int32
        assert conv_state is None
    weight = weight.transpose(0, 2, 1).astype(dtype)
    weight = weight.astype(dtype)
    x = x.astype(dtype)

    return_final_states = conv_state is not None
@@ -664,7 +669,7 @@ def _causal_conv1d(
    for i in range(length):
        if i != 0 and seq_idx is not None:
            conv_state = mx.where(
                mx.array(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
                seq_idx[:, i - 1][:, None, None] != seq_idx[:, i][:, None, None],
                mx.zeros_like(conv_state),
                conv_state,
            )
@@ -738,24 +743,12 @@ def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array)
    x, conv_state = causal_conv1d_update(
        x=xBC,
        conv_state=conv_state,
        weight=weight[:, 0, :],
        weight=weight[:, :, 0],
        activation="silu",
    )
    return x, conv_state


def is_mamba(config: ModelArgs, i: int) -> bool:
    if not config.mamba_enabled:
        return False
    assert config.mamba_step > 1
    assert i < config.num_hidden_layers

    if config.num_hidden_layers <= (config.mamba_step // 2):
        # use attention in last layer
        return i != config.num_hidden_layers - 1
    return (i % config.mamba_step) != (config.mamba_step // 2)


# Based on: https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206
def causal_conv1d(x, weight, bias=None, activation=None):
    """
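For reference, a minimal sketch (with made-up sizes, illustration only) of the layout the new `weight[:, :, 0]` indexing assumes: MLX's grouped `nn.Conv1d` stores its weight channel-last, so for a depthwise convolution the trailing input-channel axis has size 1 and the kernel taps sit on the middle axis.

```python
import mlx.nn as nn

channels, kernel_size = 8, 4  # illustrative sizes only
conv = nn.Conv1d(channels, channels, kernel_size, groups=channels, bias=False)

# Channel-last layout: (out_channels, kernel_size, in_channels // groups)
assert conv.weight.shape == (channels, kernel_size, 1)

# Hence weight[:, :, 0] (not weight[:, 0, :]) yields the per-channel kernels.
per_channel_kernels = conv.weight[:, :, 0]  # shape (channels, kernel_size)
```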
@@ -808,6 +801,213 @@ def causal_conv1d(x, weight, bias=None, activation=None):
    return y


class Mamba(nn.Module):
    def __init__(self, config: ModelArgs, layer_idx: int) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.d_state = config.mamba_d_state
        self.d_conv = config.mamba_d_conv
        self.chunk_size = config.mamba_chunk_size
        self.num_heads = config.mamba_num_heads
        # TODO add mamba_hidden_size_per_head config (?)
        self.hidden_size_per_head = config.hidden_size_per_head

        self.intermediate_size = self.num_heads * self.hidden_size_per_head

        self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=False, # TODO the original implementation uses bias
            kernel_size=self.d_conv,
            groups=self.intermediate_size,
            padding=0,
        )
        self.dt_dim = max(64, self.hidden_size // 16)
        # Notes:
        # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
        # but it may degrade the ability of content-length extrapolation.
        self.bcdt_proj = nn.Linear(
            self.intermediate_size,
            self.dt_dim + 2 * self.d_state,
            bias=False,
        )
        self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)

        self.dt_bias = get_initial_dt_bias(self.num_heads)
        self.A_log = get_initial_A(self.num_heads)
        self.D = mx.ones(self.num_heads, dtype=mx.float32)

        # TODO norm weight before gating like Mamba2
        self.dt_norm_weight = mx.ones(self.dt_dim)
        self.B_norm_weight = mx.ones(self.d_state)
        self.C_norm_weight = mx.ones(self.d_state)

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def _no_weight_decay_param_names(self) -> set[str]:
        return set(["D", "dt_bias", "A_log"])

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        past_states: Optional[PlamoCache] = None,
    ) -> tuple[mx.array, Optional[PlamoCache]]:
        bsize, length, _ = hidden_states.shape
        is_update = length == 1 and past_states is not None

        bool_mask: mx.array | None = None
        seq_idx: mx.array | None = None
        if attention_mask is not None:
            if len(attention_mask.shape) == 2:
                attention_mask = mx.broadcast_to(
                    attention_mask[None, None],
                    (bsize, 1, attention_mask.shape[0], attention_mask.shape[1]),
                )
            assert len(attention_mask.shape) == 4

            if past_states is None:
                # TODO: support seq_idx with cache
                bool_mask_4d = mx.array(attention_mask == 0, dtype=mx.bool_) # type: ignore
                is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
                seq_idx = mx.cumsum(is_first_token, axis=-1) - 1
                seq_idx = seq_idx.astype(mx.int32)

            # `generate` function creates attention mask that contains past tokens,
            # but mamba does not use them
            attention_mask = attention_mask[:, 0, -length:, -length:]
            bool_mask = mx.array(mx.diagonal(attention_mask, axis1=-2, axis2=-1) == 0)

        conv_state: mx.array | None
        ssm_state: mx.array | None
        if past_states is None:
            conv_state = None
            ssm_state = None
        elif past_states[self.layer_idx] is None:
            conv_state = mx.zeros(
                (bsize, self.intermediate_size, self.d_conv - 1),
                dtype=hidden_states.dtype,
            )
            ssm_state = mx.zeros(
                (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
                dtype=mx.float32,
            )
        else:
            c = past_states[self.layer_idx]
            assert isinstance(c, PlamoMambaCache)
            conv_state = c.conv_state
            ssm_state = c.ssm_state

        zx = self.in_proj(hidden_states)
        zx = zx.reshape(bsize, length, self.num_heads, -1)
        # z: (bsize, length, num_heads, hidden_size_per_head)
        # x: (bsize, length, num_heads, hidden_size_per_head)
        z, x = mx.split(
            zx,
            [
                self.hidden_size_per_head,
            ],
            axis=-1,
        )

        # conv
        x = x.reshape(bsize, length, -1).transpose(0, 2, 1) # (bsize, intermediate_size, length)
        if bool_mask is not None:
            x = mx.where(bool_mask[:, None, :], x, 0.0)
        if is_update:
            assert conv_state is not None
            x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
        else:
            x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
        x = x.astype(hidden_states.dtype)
        x = x.transpose(0, 2, 1) # (bsize, length, intermediate_size)
        x = x.reshape(bsize, length, -1)
        # x: (bsize, length, num_heads, hidden_size_per_head)
        # B: (bsize, length, 1, d_state)
        # C: (bsize, length, 1, d_state)
        # dt: (bsize, length, dt_dim)
        BCdt = self.bcdt_proj(x)
        x = x.reshape(bsize, length, self.num_heads, -1)
        B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
        B = B[:, :, None, :]
        C = C[:, :, None, :]

        A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,)
        dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
        B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
        C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]

        # (bsize, length, num_heads, 1)
        dt = self.dt_proj(dt)[..., None]

        # TODO it may not be required
        B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3]))
        C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3]))

        if bool_mask is not None:
            """
            state will be updates by following:
            ```
            dt = softplus(dt)
            dA = exp(dt * A)
            state_next = state * dA + dB * x
            ```
            To avoid updating state, we set dt to -inf and x to 0
            because `softplus(-inf) = 0` and `exp(0) = 1`
            """
            dt = mx.where(bool_mask[:, :, None, None], dt, float("-inf"))
            x = mx.where(bool_mask[:, :, None, None], x, 0.0)

        # ssm
        if is_update:
            assert ssm_state is not None
            out, ssm_state = ssd_update_state(
                ssm_state,
                x[:, 0],
                dt[:, 0].reshape(bsize, -1),
                A,
                B[:, 0],
                C[:, 0],
                D=self.D,
                z=z[:, 0],
                dt_bias=self.dt_bias,
                dt_softplus=True,
            )
        else:
            tmp = ssd_chunk_scan_combined(
                x,
                dt.reshape(bsize, length, -1),
                A,
                B,
                C,
                self.chunk_size,
                D=self.D,
                z=z,
                dt_bias=self.dt_bias,
                dt_softplus=True,
                return_final_states=past_states is not None,
                seq_idx=seq_idx,
                ssm_state=ssm_state,
            )
            if past_states is not None:
                out, ssm_state = tmp
            else:
                assert isinstance(tmp, mx.array)
                out = tmp

        y = self.out_proj(out.reshape(bsize, length, -1))

        if past_states is not None:
            assert ssm_state is not None
            assert conv_state is not None
            past_states.update_mamba(conv_state, ssm_state, self.layer_idx)

        return y, past_states


def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
    max_len = max(q_len, kv_len)
    mask = mx.tril(
@@ -950,209 +1150,6 @@ class Attention(nn.Module):
        return attn_output, attn_weights, past_states


class Mamba(nn.Module):
    def __init__(self, config: ModelArgs, layer_idx: int) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.d_state = config.mamba_d_state
        self.d_conv = config.mamba_d_conv
        self.chunk_size = config.mamba_chunk_size
        self.num_heads = config.mamba_num_heads
        # TODO add mamba_hidden_size_per_head config (?)
        self.hidden_size_per_head = config.hidden_size_per_head

        self.intermediate_size = self.num_heads * self.hidden_size_per_head

        self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=False, # TODO the original implementation uses bias
            kernel_size=self.d_conv,
            groups=self.intermediate_size,
            padding=0,
        )
        self.dt_dim = max(64, self.hidden_size // 16)
        # Notes:
        # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
        # but it may degrade the ability of content-length extrapolation.
        self.bcdt_proj = nn.Linear(
            self.intermediate_size,
            self.dt_dim + 2 * self.d_state,
            bias=False,
        )
        self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)

        self.dt_bias = get_initial_dt_bias(self.num_heads)
        self.A_log = get_initial_A(self.num_heads)
        self.D = mx.ones(self.num_heads, dtype=mx.float32)

        # TODO norm weight before gating like Mamba2
        self.dt_norm_weight = mx.ones(self.dt_dim)
        self.B_norm_weight = mx.ones(self.d_state)
        self.C_norm_weight = mx.ones(self.d_state)

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def _no_weight_decay_param_names(self) -> set[str]:
        return set(["D", "dt_bias", "A_log"])

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        past_states: Optional[PlamoCache] = None,
    ) -> tuple[mx.array, Optional[PlamoCache]]:
        bsize, length, _ = hidden_states.shape
        is_update = length == 1 and past_states is not None

        bool_mask: mx.array | None = None
        seq_idx: mx.array | None = None
        if attention_mask is not None:
            if len(attention_mask.shape) == 2:
                attention_mask = mx.broadcast_to(
                    attention_mask[None, None],
                    (bsize, 1, attention_mask.shape[0], attention_mask.shape[1]),
                )
            assert len(attention_mask.shape) == 4

            if past_states is None:
                # TODO: support seq_idx with cache
                bool_mask_4d = mx.array(attention_mask == 0, dtype=mx.bool_) # type: ignore
                is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
                seq_idx = mx.cumsum(is_first_token, axis=-1) - 1
                seq_idx = seq_idx.astype(mx.int32)

            # `generate` function creates attention mask that contains past tokens,
            # but mamba does not use them
            attention_mask = attention_mask[:, 0, -length:, -length:]
            bool_mask = mx.array(mx.diagonal(attention_mask, axis1=-2, axis2=-1) == 0)

        conv_state: mx.array | None
        ssm_state: mx.array | None
        if past_states is None:
            conv_state = None
            ssm_state = None
        elif past_states[self.layer_idx] is None:
            conv_state = mx.zeros(
                (bsize, self.intermediate_size, self.d_conv - 1),
                dtype=hidden_states.dtype,
            )
            ssm_state = mx.zeros(
                (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
                dtype=mx.float32,
            )
        else:
            c = past_states[self.layer_idx]
            assert isinstance(c, PlamoMambaCache)
            conv_state = c.conv_state
            ssm_state = c.ssm_state

        zx = self.in_proj(hidden_states)
        zx = zx.reshape(bsize, length, self.num_heads, -1).astype(mx.float32)
        # z: (bsize, length, num_heads, hidden_size_per_head)
        # x: (bsize, length, num_heads, hidden_size_per_head)
        z, x = mx.split(
            zx,
            [
                self.hidden_size_per_head,
            ],
            axis=-1,
        )

        # conv
        x = x.reshape(bsize, length, -1).transpose(0, 2, 1) # (bsize, intermediate_size, length)
        if bool_mask is not None:
            x = mx.where(bool_mask[:, None, :], x, 0.0)
        if is_update:
            assert conv_state is not None
            x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
        else:
            x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
        x = x.astype(mx.float32)
        x = x.transpose(0, 2, 1) # (bsize, length, intermediate_size)
        x = x.reshape(bsize, length, -1)
        # x: (bsize, length, num_heads, hidden_size_per_head)
        # B: (bsize, length, 1, d_state)
        # C: (bsize, length, 1, d_state)
        # dt: (bsize, length, dt_dim)
        BCdt = self.bcdt_proj(x).astype(mx.float32)
        x = x.reshape(bsize, length, self.num_heads, -1)
        B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
        B = B[:, :, None, :]
        C = C[:, :, None, :]

        A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,)
        dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
        B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
        C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]

        # (bsize, length, num_heads, 1)
        dt = self.dt_proj(dt)[..., None]

        # TODO it may not be required
        B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3]))
        C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3]))

        if bool_mask is not None:
            """
            state will be updates by following:
            ```
            dt = softplus(dt)
            dA = exp(dt * A)
            state_next = state * dA + dB * x
            ```
            To avoid updating state, we set dt to -inf and x to 0
            because `softplus(-inf) = 0` and `exp(0) = 1`
            """
            dt = mx.where(bool_mask[:, :, None, None], dt, float("-inf"))
            x = mx.where(bool_mask[:, :, None, None], x, 0.0)

        # ssm
        self.D = self.D.astype(mx.float32)
        self.dt_bias = self.dt_bias.astype(mx.float32)
        if is_update:
            assert ssm_state is not None
            out, ssm_state = ssd_update_state(
                ssm_state,
                x[:, 0],
                dt[:, 0].reshape(bsize, -1),
                A,
                B[:, 0],
                C[:, 0],
                D=self.D,
                z=z[:, 0],
                dt_bias=self.dt_bias,
                dt_softplus=True,
            )
        else:
            out, ssm_state = ssd_chunk_scan_combined(
                x,
                dt.reshape(bsize, length, -1),
                A,
                B,
                C,
                self.chunk_size,
                D=self.D,
                z=z,
                dt_bias=self.dt_bias,
                dt_softplus=True,
                seq_idx=seq_idx,
                ssm_state=ssm_state,
            )

        y = self.out_proj(out.reshape(bsize, length, -1))

        if past_states is not None:
            assert ssm_state is not None
            assert conv_state is not None
            past_states.update_mamba(conv_state, ssm_state, self.layer_idx)

        return y, past_states


class MLP(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
@@ -1236,6 +1233,18 @@ class PlamoDecoderLayer(nn.Module):
        return outputs # type: ignore


def is_mamba(config: ModelArgs, i: int) -> bool:
    if not config.mamba_enabled:
        return False
    assert config.mamba_step > 1
    assert i < config.num_hidden_layers

    if config.num_hidden_layers <= (config.mamba_step // 2):
        # use attention in last layer
        return i != config.num_hidden_layers - 1
    return (i % config.mamba_step) != (config.mamba_step // 2)


class PlamoDecoder(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
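The relocated `is_mamba` helper decides how attention and Mamba layers interleave. A small worked example with hypothetical config values (mamba_step=4, 16 layers), purely for illustration:

```python
# Mirrors the is_mamba logic for a hypothetical config with
# mamba_enabled=True, mamba_step=4 and num_hidden_layers=16
# (so the "use attention in last layer" branch does not apply).
mamba_step, num_hidden_layers = 4, 16
layer_types = [
    "mamba" if (i % mamba_step) != (mamba_step // 2) else "attention"
    for i in range(num_hidden_layers)
]
# -> attention at layer indices 2, 6, 10, 14; every other layer is Mamba.
```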
@@ -1545,6 +1554,7 @@ class PlamoModel(PlamoPreTrainedModel):
                    self.gradient_checkpointing,
                )
            )

        assert isinstance(out, DecoderOutput)
        hidden_states = out.hidden_states
        all_hidden_states = out.all_hidden_states