Fix channel-first weights to channel-last for correct use of MLX's conv1d

Shunta Saito 2025-02-14 22:38:13 +09:00
parent 81917d41d5
commit 66dd97ed3d
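MLX's nn.Conv1d stores its weight channel-last as (out_channels, kernel_size, in_channels // groups), whereas PyTorch-style checkpoints store depthwise Conv1d weights channel-first as (out_channels, in_channels // groups, kernel_size). With the weights stored channel-last (the conversion itself happens at load time and is not shown in this excerpt), the forward pass can use self.conv1d.weight directly instead of transposing on every call, and the decode path can index the kernel as weight[:, :, 0]. A minimal sketch of the layout difference; the channel and kernel sizes are made up for illustration, not taken from the model config:

import mlx.core as mx

# Illustrative sizes only (not the real PLaMo-2 configuration).
channels, kernel_size = 8, 4

# PyTorch-style depthwise weight (channel-first):
#   (out_channels, in_channels // groups, kernel_size) == (8, 1, 4)
w_channel_first = mx.zeros((channels, 1, kernel_size))

# MLX nn.Conv1d layout (channel-last):
#   (out_channels, kernel_size, in_channels // groups) == (8, 4, 1)
w_channel_last = w_channel_first.transpose(0, 2, 1)

print(w_channel_first.shape, "->", w_channel_last.shape)  # (8, 1, 4) -> (8, 4, 1)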


@@ -528,6 +528,7 @@ def ssd_update_state(
dt_softplus: bool,
) -> tuple[mx.array, mx.array]:
assert ssm_state.dtype == mx.float32
dtype = x.dtype
hidden_size_per_head = x.shape[-1]
d_state = B.shape[-1]
@@ -535,17 +536,16 @@ def ssd_update_state(
dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head))
dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head))
D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
assert ssm_state.dtype == mx.float32
out, ssm_state = selective_state_update_ref(
ssm_state,
x,
dt,
A,
B,
C,
D,
z,
dt_bias,
x.astype(dtype),
dt.astype(dtype),
A.astype(mx.float32),
B.astype(dtype),
C.astype(dtype),
D.astype(mx.float32),
z.astype(dtype),
dt_bias.astype(mx.float32),
dt_softplus=dt_softplus,
)
return out[:, None], ssm_state
@@ -601,12 +601,14 @@ def ssd_chunk_scan_combined(
z: mx.array,
dt_bias: mx.array,
dt_softplus: bool,
return_final_states: bool,
seq_idx: mx.array | None,
ssm_state: mx.array | None,
) -> tuple[mx.array, mx.array]:
) -> tuple[mx.array, mx.array] | mx.array:
if seq_idx is not None:
assert seq_idx.dtype == mx.int32
assert ssm_state is None
assert not return_final_states
if ssm_state is not None:
assert ssm_state.dtype == mx.float32
assert seq_idx is None
@@ -637,7 +639,10 @@ def ssd_chunk_scan_combined(
seq_idx=seq_idx,
ssm_state=ssm_state,
)
if return_final_states:
return tmp, ssm_state
else:
return tmp
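With return_final_states added, ssd_chunk_scan_combined now returns either (out, final_state) or just out. A self-contained sketch of the same union-return pattern and how a caller unpacks it; toy_scan is a stand-in, not the real kernel, and only the shape of the API mirrors the diff:

import mlx.core as mx

def toy_scan(x: mx.array, return_final_states: bool) -> tuple[mx.array, mx.array] | mx.array:
    # Stand-in for ssd_chunk_scan_combined: cumulative sum as the "scan",
    # its last element as the "final state".
    out = mx.cumsum(x, axis=-1)
    return (out, out[..., -1]) if return_final_states else out

x = mx.arange(6, dtype=mx.float32)
out = toy_scan(x, return_final_states=False)        # no cache: a single array
out, state = toy_scan(x, return_final_states=True)  # with a cache: tuple, the state is kept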
def _causal_conv1d(
@@ -650,7 +655,7 @@ def _causal_conv1d(
if seq_idx is not None:
assert seq_idx.dtype == mx.int32
assert conv_state is None
weight = weight.transpose(0, 2, 1).astype(dtype)
weight = weight.astype(dtype)
x = x.astype(dtype)
return_final_states = conv_state is not None
@@ -664,7 +669,7 @@ def _causal_conv1d(
for i in range(length):
if i != 0 and seq_idx is not None:
conv_state = mx.where(
mx.array(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
seq_idx[:, i - 1][:, None, None] != seq_idx[:, i][:, None, None],
mx.zeros_like(conv_state),
conv_state,
)
@@ -738,24 +743,12 @@ def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array)
x, conv_state = causal_conv1d_update(
x=xBC,
conv_state=conv_state,
weight=weight[:, 0, :],
weight=weight[:, :, 0],
activation="silu",
)
return x, conv_state
def is_mamba(config: ModelArgs, i: int) -> bool:
if not config.mamba_enabled:
return False
assert config.mamba_step > 1
assert i < config.num_hidden_layers
if config.num_hidden_layers <= (config.mamba_step // 2):
# use attention in last layer
return i != config.num_hidden_layers - 1
return (i % config.mamba_step) != (config.mamba_step // 2)
# Based on: https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206
def causal_conv1d(x, weight, bias=None, activation=None):
"""
@@ -808,6 +801,213 @@ def causal_conv1d(x, weight, bias=None, activation=None):
return y
class Mamba(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int) -> None:
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.d_state = config.mamba_d_state
self.d_conv = config.mamba_d_conv
self.chunk_size = config.mamba_chunk_size
self.num_heads = config.mamba_num_heads
# TODO add mamba_hidden_size_per_head config (?)
self.hidden_size_per_head = config.hidden_size_per_head
self.intermediate_size = self.num_heads * self.hidden_size_per_head
self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
self.conv1d = nn.Conv1d(
in_channels=self.intermediate_size,
out_channels=self.intermediate_size,
bias=False, # TODO the original implementation uses bias
kernel_size=self.d_conv,
groups=self.intermediate_size,
padding=0,
)
self.dt_dim = max(64, self.hidden_size // 16)
# Notes:
# Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
# but it may degrade the ability of context-length extrapolation.
self.bcdt_proj = nn.Linear(
self.intermediate_size,
self.dt_dim + 2 * self.d_state,
bias=False,
)
self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)
self.dt_bias = get_initial_dt_bias(self.num_heads)
self.A_log = get_initial_A(self.num_heads)
self.D = mx.ones(self.num_heads, dtype=mx.float32)
# TODO norm weight before gating like Mamba2
self.dt_norm_weight = mx.ones(self.dt_dim)
self.B_norm_weight = mx.ones(self.d_state)
self.C_norm_weight = mx.ones(self.d_state)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def _no_weight_decay_param_names(self) -> set[str]:
return set(["D", "dt_bias", "A_log"])
def __call__(
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
past_states: Optional[PlamoCache] = None,
) -> tuple[mx.array, Optional[PlamoCache]]:
bsize, length, _ = hidden_states.shape
is_update = length == 1 and past_states is not None
bool_mask: mx.array | None = None
seq_idx: mx.array | None = None
if attention_mask is not None:
if len(attention_mask.shape) == 2:
attention_mask = mx.broadcast_to(
attention_mask[None, None],
(bsize, 1, attention_mask.shape[0], attention_mask.shape[1]),
)
assert len(attention_mask.shape) == 4
if past_states is None:
# TODO: support seq_idx with cache
bool_mask_4d = mx.array(attention_mask == 0, dtype=mx.bool_) # type: ignore
is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
seq_idx = mx.cumsum(is_first_token, axis=-1) - 1
seq_idx = seq_idx.astype(mx.int32)
# The `generate` function creates an attention mask that contains past tokens,
# but mamba does not use them
attention_mask = attention_mask[:, 0, -length:, -length:]
bool_mask = mx.array(mx.diagonal(attention_mask, axis1=-2, axis2=-1) == 0)
conv_state: mx.array | None
ssm_state: mx.array | None
if past_states is None:
conv_state = None
ssm_state = None
elif past_states[self.layer_idx] is None:
conv_state = mx.zeros(
(bsize, self.intermediate_size, self.d_conv - 1),
dtype=hidden_states.dtype,
)
ssm_state = mx.zeros(
(bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
dtype=mx.float32,
)
else:
c = past_states[self.layer_idx]
assert isinstance(c, PlamoMambaCache)
conv_state = c.conv_state
ssm_state = c.ssm_state
zx = self.in_proj(hidden_states)
zx = zx.reshape(bsize, length, self.num_heads, -1)
# z: (bsize, length, num_heads, hidden_size_per_head)
# x: (bsize, length, num_heads, hidden_size_per_head)
z, x = mx.split(
zx,
[
self.hidden_size_per_head,
],
axis=-1,
)
# conv
x = x.reshape(bsize, length, -1).transpose(0, 2, 1) # (bsize, intermediate_size, length)
if bool_mask is not None:
x = mx.where(bool_mask[:, None, :], x, 0.0)
if is_update:
assert conv_state is not None
x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
else:
x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
x = x.astype(hidden_states.dtype)
x = x.transpose(0, 2, 1) # (bsize, length, intermediate_size)
x = x.reshape(bsize, length, -1)
# x: (bsize, length, num_heads, hidden_size_per_head)
# B: (bsize, length, 1, d_state)
# C: (bsize, length, 1, d_state)
# dt: (bsize, length, dt_dim)
BCdt = self.bcdt_proj(x)
x = x.reshape(bsize, length, self.num_heads, -1)
B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
B = B[:, :, None, :]
C = C[:, :, None, :]
A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,)
dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
# (bsize, length, num_heads, 1)
dt = self.dt_proj(dt)[..., None]
# TODO it may not be required
B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3]))
C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3]))
if bool_mask is not None:
"""
state will be updated as follows:
```
dt = softplus(dt)
dA = exp(dt * A)
state_next = state * dA + dB * x
```
To avoid updating state, we set dt to -inf and x to 0
because `softplus(-inf) = 0` and `exp(0) = 1`
"""
dt = mx.where(bool_mask[:, :, None, None], dt, float("-inf"))
x = mx.where(bool_mask[:, :, None, None], x, 0.0)
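# Illustrative aside, worked through numerically:
#   softplus(-inf) = log(1 + exp(-inf)) = log(1) = 0
#   dA = exp(0 * A) = 1 and dB * x = 0
#   state_next = state * 1 + 0 = state
# so masked positions leave the recurrent state untouched.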
# ssm
if is_update:
assert ssm_state is not None
out, ssm_state = ssd_update_state(
ssm_state,
x[:, 0],
dt[:, 0].reshape(bsize, -1),
A,
B[:, 0],
C[:, 0],
D=self.D,
z=z[:, 0],
dt_bias=self.dt_bias,
dt_softplus=True,
)
else:
tmp = ssd_chunk_scan_combined(
x,
dt.reshape(bsize, length, -1),
A,
B,
C,
self.chunk_size,
D=self.D,
z=z,
dt_bias=self.dt_bias,
dt_softplus=True,
return_final_states=past_states is not None,
seq_idx=seq_idx,
ssm_state=ssm_state,
)
if past_states is not None:
out, ssm_state = tmp
else:
assert isinstance(tmp, mx.array)
out = tmp
y = self.out_proj(out.reshape(bsize, length, -1))
if past_states is not None:
assert ssm_state is not None
assert conv_state is not None
past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
return y, past_states
def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
max_len = max(q_len, kv_len)
mask = mx.tril(
@@ -950,209 +1150,6 @@ class Attention(nn.Module):
return attn_output, attn_weights, past_states
class Mamba(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int) -> None:
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.d_state = config.mamba_d_state
self.d_conv = config.mamba_d_conv
self.chunk_size = config.mamba_chunk_size
self.num_heads = config.mamba_num_heads
# TODO add mamba_hidden_size_per_head config (?)
self.hidden_size_per_head = config.hidden_size_per_head
self.intermediate_size = self.num_heads * self.hidden_size_per_head
self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
self.conv1d = nn.Conv1d(
in_channels=self.intermediate_size,
out_channels=self.intermediate_size,
bias=False, # TODO the original implementation uses bias
kernel_size=self.d_conv,
groups=self.intermediate_size,
padding=0,
)
self.dt_dim = max(64, self.hidden_size // 16)
# Notes:
# Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
# but it may degrade the ability of context-length extrapolation.
self.bcdt_proj = nn.Linear(
self.intermediate_size,
self.dt_dim + 2 * self.d_state,
bias=False,
)
self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)
self.dt_bias = get_initial_dt_bias(self.num_heads)
self.A_log = get_initial_A(self.num_heads)
self.D = mx.ones(self.num_heads, dtype=mx.float32)
# TODO norm weight before gating like Mamba2
self.dt_norm_weight = mx.ones(self.dt_dim)
self.B_norm_weight = mx.ones(self.d_state)
self.C_norm_weight = mx.ones(self.d_state)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def _no_weight_decay_param_names(self) -> set[str]:
return set(["D", "dt_bias", "A_log"])
def __call__(
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
past_states: Optional[PlamoCache] = None,
) -> tuple[mx.array, Optional[PlamoCache]]:
bsize, length, _ = hidden_states.shape
is_update = length == 1 and past_states is not None
bool_mask: mx.array | None = None
seq_idx: mx.array | None = None
if attention_mask is not None:
if len(attention_mask.shape) == 2:
attention_mask = mx.broadcast_to(
attention_mask[None, None],
(bsize, 1, attention_mask.shape[0], attention_mask.shape[1]),
)
assert len(attention_mask.shape) == 4
if past_states is None:
# TODO: support seq_idx with cache
bool_mask_4d = mx.array(attention_mask == 0, dtype=mx.bool_) # type: ignore
is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
seq_idx = mx.cumsum(is_first_token, axis=-1) - 1
seq_idx = seq_idx.astype(mx.int32)
# The `generate` function creates an attention mask that contains past tokens,
# but mamba does not use them
attention_mask = attention_mask[:, 0, -length:, -length:]
bool_mask = mx.array(mx.diagonal(attention_mask, axis1=-2, axis2=-1) == 0)
conv_state: mx.array | None
ssm_state: mx.array | None
if past_states is None:
conv_state = None
ssm_state = None
elif past_states[self.layer_idx] is None:
conv_state = mx.zeros(
(bsize, self.intermediate_size, self.d_conv - 1),
dtype=hidden_states.dtype,
)
ssm_state = mx.zeros(
(bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
dtype=mx.float32,
)
else:
c = past_states[self.layer_idx]
assert isinstance(c, PlamoMambaCache)
conv_state = c.conv_state
ssm_state = c.ssm_state
zx = self.in_proj(hidden_states)
zx = zx.reshape(bsize, length, self.num_heads, -1).astype(mx.float32)
# z: (bsize, length, num_heads, hidden_size_per_head)
# x: (bsize, length, num_heads, hidden_size_per_head)
z, x = mx.split(
zx,
[
self.hidden_size_per_head,
],
axis=-1,
)
# conv
x = x.reshape(bsize, length, -1).transpose(0, 2, 1) # (bsize, intermediate_size, length)
if bool_mask is not None:
x = mx.where(bool_mask[:, None, :], x, 0.0)
if is_update:
assert conv_state is not None
x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
else:
x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
x = x.astype(mx.float32)
x = x.transpose(0, 2, 1) # (bsize, length, intermediate_size)
x = x.reshape(bsize, length, -1)
# x: (bsize, length, num_heads, hidden_size_per_head)
# B: (bsize, length, 1, d_state)
# C: (bsize, length, 1, d_state)
# dt: (bsize, length, dt_dim)
BCdt = self.bcdt_proj(x).astype(mx.float32)
x = x.reshape(bsize, length, self.num_heads, -1)
B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
B = B[:, :, None, :]
C = C[:, :, None, :]
A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,)
dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
# (bsize, length, num_heads, 1)
dt = self.dt_proj(dt)[..., None]
# TODO it may not be required
B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3]))
C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3]))
if bool_mask is not None:
"""
state will be updated as follows:
```
dt = softplus(dt)
dA = exp(dt * A)
state_next = state * dA + dB * x
```
To avoid updating state, we set dt to -inf and x to 0
because `softplus(-inf) = 0` and `exp(0) = 1`
"""
dt = mx.where(bool_mask[:, :, None, None], dt, float("-inf"))
x = mx.where(bool_mask[:, :, None, None], x, 0.0)
# ssm
self.D = self.D.astype(mx.float32)
self.dt_bias = self.dt_bias.astype(mx.float32)
if is_update:
assert ssm_state is not None
out, ssm_state = ssd_update_state(
ssm_state,
x[:, 0],
dt[:, 0].reshape(bsize, -1),
A,
B[:, 0],
C[:, 0],
D=self.D,
z=z[:, 0],
dt_bias=self.dt_bias,
dt_softplus=True,
)
else:
out, ssm_state = ssd_chunk_scan_combined(
x,
dt.reshape(bsize, length, -1),
A,
B,
C,
self.chunk_size,
D=self.D,
z=z,
dt_bias=self.dt_bias,
dt_softplus=True,
seq_idx=seq_idx,
ssm_state=ssm_state,
)
y = self.out_proj(out.reshape(bsize, length, -1))
if past_states is not None:
assert ssm_state is not None
assert conv_state is not None
past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
return y, past_states
class MLP(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
@@ -1236,6 +1233,18 @@ class PlamoDecoderLayer(nn.Module):
return outputs # type: ignore
def is_mamba(config: ModelArgs, i: int) -> bool:
if not config.mamba_enabled:
return False
assert config.mamba_step > 1
assert i < config.num_hidden_layers
if config.num_hidden_layers <= (config.mamba_step // 2):
# use attention in last layer
return i != config.num_hidden_layers - 1
return (i % config.mamba_step) != (config.mamba_step // 2)
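For intuition, a small sketch of the resulting layer pattern; mamba_step = 2 and num_hidden_layers = 8 are assumed values for illustration, not the real configuration:

mamba_step = 2          # assumed for illustration
num_hidden_layers = 8   # assumed for illustration

def is_mamba_sketch(i: int) -> bool:
    if num_hidden_layers <= (mamba_step // 2):
        return i != num_hidden_layers - 1
    return (i % mamba_step) != (mamba_step // 2)

print([is_mamba_sketch(i) for i in range(num_hidden_layers)])
# [True, False, True, False, True, False, True, False]
# Attention is used wherever i % mamba_step == mamba_step // 2; Mamba everywhere else.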
class PlamoDecoder(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
@@ -1545,6 +1554,7 @@ class PlamoModel(PlamoPreTrainedModel):
self.gradient_checkpointing,
)
)
assert isinstance(out, DecoderOutput)
hidden_states = out.hidden_states
all_hidden_states = out.all_hidden_states