From 66dd97ed3d52d51f2437074a489d3ed817935fff Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Fri, 14 Feb 2025 22:38:13 +0900 Subject: [PATCH] Fix channel first weights to channel last for right use of MLX's conv1d --- llms/mlx_lm/models/plamo2.py | 470 ++++++++++++++++++----------------- 1 file changed, 240 insertions(+), 230 deletions(-) diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py index c7624811..7edfef07 100644 --- a/llms/mlx_lm/models/plamo2.py +++ b/llms/mlx_lm/models/plamo2.py @@ -285,7 +285,7 @@ class PlamoCache(nn.Module): assert layer_idx < len(self.cache) layer_cache = self.cache[layer_idx] return layer_cache # type: ignore - + @property def state(self): return self.cache @@ -528,6 +528,7 @@ def ssd_update_state( dt_softplus: bool, ) -> tuple[mx.array, mx.array]: assert ssm_state.dtype == mx.float32 + dtype = x.dtype hidden_size_per_head = x.shape[-1] d_state = B.shape[-1] @@ -535,17 +536,16 @@ def ssd_update_state( dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)) dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)) D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head)) - assert ssm_state.dtype == mx.float32 out, ssm_state = selective_state_update_ref( ssm_state, - x, - dt, - A, - B, - C, - D, - z, - dt_bias, + x.astype(dtype), + dt.astype(dtype), + A.astype(mx.float32), + B.astype(dtype), + C.astype(dtype), + D.astype(mx.float32), + z.astype(dtype), + dt_bias.astype(mx.float32), dt_softplus=dt_softplus, ) return out[:, None], ssm_state @@ -601,12 +601,14 @@ def ssd_chunk_scan_combined( z: mx.array, dt_bias: mx.array, dt_softplus: bool, + return_final_states: bool, seq_idx: mx.array | None, ssm_state: mx.array | None, -) -> tuple[mx.array, mx.array]: +) -> tuple[mx.array, mx.array] | mx.array: if seq_idx is not None: assert seq_idx.dtype == mx.int32 assert ssm_state is None + assert not return_final_states if ssm_state is not None: assert ssm_state.dtype == mx.float32 assert seq_idx is None @@ -637,7 +639,10 @@ def ssd_chunk_scan_combined( seq_idx=seq_idx, ssm_state=ssm_state, ) - return tmp, ssm_state + if return_final_states: + return tmp, ssm_state + else: + return tmp def _causal_conv1d( @@ -650,7 +655,7 @@ def _causal_conv1d( if seq_idx is not None: assert seq_idx.dtype == mx.int32 assert conv_state is None - weight = weight.transpose(0, 2, 1).astype(dtype) + weight = weight.astype(dtype) x = x.astype(dtype) return_final_states = conv_state is not None @@ -664,7 +669,7 @@ def _causal_conv1d( for i in range(length): if i != 0 and seq_idx is not None: conv_state = mx.where( - mx.array(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None], + seq_idx[:, i - 1][:, None, None] != seq_idx[:, i][:, None, None], mx.zeros_like(conv_state), conv_state, ) @@ -738,24 +743,12 @@ def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array) x, conv_state = causal_conv1d_update( x=xBC, conv_state=conv_state, - weight=weight[:, 0, :], + weight=weight[:, :, 0], activation="silu", ) return x, conv_state -def is_mamba(config: ModelArgs, i: int) -> bool: - if not config.mamba_enabled: - return False - assert config.mamba_step > 1 - assert i < config.num_hidden_layers - - if config.num_hidden_layers <= (config.mamba_step // 2): - # use attention in last layer - return i != config.num_hidden_layers - 1 - return (i % config.mamba_step) != (config.mamba_step // 2) - - # Based on: 
https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206 def causal_conv1d(x, weight, bias=None, activation=None): """ @@ -808,6 +801,213 @@ def causal_conv1d(x, weight, bias=None, activation=None): return y +class Mamba(nn.Module): + def __init__(self, config: ModelArgs, layer_idx: int) -> None: + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.d_state = config.mamba_d_state + self.d_conv = config.mamba_d_conv + self.chunk_size = config.mamba_chunk_size + self.num_heads = config.mamba_num_heads + # TODO add mamba_hidden_size_per_head config (?) + self.hidden_size_per_head = config.hidden_size_per_head + + self.intermediate_size = self.num_heads * self.hidden_size_per_head + + self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False) + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=False, # TODO the original implementation uses bias + kernel_size=self.d_conv, + groups=self.intermediate_size, + padding=0, + ) + self.dt_dim = max(64, self.hidden_size // 16) + # Notes: + # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper), + # but it may degrade the ability of content-length extrapolation. + self.bcdt_proj = nn.Linear( + self.intermediate_size, + self.dt_dim + 2 * self.d_state, + bias=False, + ) + self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False) + + self.dt_bias = get_initial_dt_bias(self.num_heads) + self.A_log = get_initial_A(self.num_heads) + self.D = mx.ones(self.num_heads, dtype=mx.float32) + + # TODO norm weight before gating like Mamba2 + self.dt_norm_weight = mx.ones(self.dt_dim) + self.B_norm_weight = mx.ones(self.d_state) + self.C_norm_weight = mx.ones(self.d_state) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def _no_weight_decay_param_names(self) -> set[str]: + return set(["D", "dt_bias", "A_log"]) + + def __call__( + self, + hidden_states: mx.array, + attention_mask: Optional[mx.array] = None, + past_states: Optional[PlamoCache] = None, + ) -> tuple[mx.array, Optional[PlamoCache]]: + bsize, length, _ = hidden_states.shape + is_update = length == 1 and past_states is not None + + bool_mask: mx.array | None = None + seq_idx: mx.array | None = None + if attention_mask is not None: + if len(attention_mask.shape) == 2: + attention_mask = mx.broadcast_to( + attention_mask[None, None], + (bsize, 1, attention_mask.shape[0], attention_mask.shape[1]), + ) + assert len(attention_mask.shape) == 4 + + if past_states is None: + # TODO: support seq_idx with cache + bool_mask_4d = mx.array(attention_mask == 0, dtype=mx.bool_) # type: ignore + is_first_token = _is_first_token(bool_mask_4d)[:, 0, :] + seq_idx = mx.cumsum(is_first_token, axis=-1) - 1 + seq_idx = seq_idx.astype(mx.int32) + + # `generate` function creates attention mask that contains past tokens, + # but mamba does not use them + attention_mask = attention_mask[:, 0, -length:, -length:] + bool_mask = mx.array(mx.diagonal(attention_mask, axis1=-2, axis2=-1) == 0) + + conv_state: mx.array | None + ssm_state: mx.array | None + if past_states is None: + conv_state = None + ssm_state = None + elif past_states[self.layer_idx] is None: + conv_state = mx.zeros( + (bsize, self.intermediate_size, self.d_conv - 1), + dtype=hidden_states.dtype, + ) + ssm_state = mx.zeros( + (bsize, self.num_heads, self.hidden_size_per_head, 
self.d_state), + dtype=mx.float32, + ) + else: + c = past_states[self.layer_idx] + assert isinstance(c, PlamoMambaCache) + conv_state = c.conv_state + ssm_state = c.ssm_state + + zx = self.in_proj(hidden_states) + zx = zx.reshape(bsize, length, self.num_heads, -1) + # z: (bsize, length, num_heads, hidden_size_per_head) + # x: (bsize, length, num_heads, hidden_size_per_head) + z, x = mx.split( + zx, + [ + self.hidden_size_per_head, + ], + axis=-1, + ) + + # conv + x = x.reshape(bsize, length, -1).transpose(0, 2, 1) # (bsize, intermediate_size, length) + if bool_mask is not None: + x = mx.where(bool_mask[:, None, :], x, 0.0) + if is_update: + assert conv_state is not None + x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x) + else: + x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx) + x = x.astype(hidden_states.dtype) + x = x.transpose(0, 2, 1) # (bsize, length, intermediate_size) + x = x.reshape(bsize, length, -1) + # x: (bsize, length, num_heads, hidden_size_per_head) + # B: (bsize, length, 1, d_state) + # C: (bsize, length, 1, d_state) + # dt: (bsize, length, dt_dim) + BCdt = self.bcdt_proj(x) + x = x.reshape(bsize, length, self.num_heads, -1) + B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1) + B = B[:, :, None, :] + C = C[:, :, None, :] + + A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,) + dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :] + B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :] + C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :] + + # (bsize, length, num_heads, 1) + dt = self.dt_proj(dt)[..., None] + + # TODO it may not be required + B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3])) + C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3])) + + if bool_mask is not None: + """ + state will be updates by following: + ``` + dt = softplus(dt) + dA = exp(dt * A) + state_next = state * dA + dB * x + ``` + To avoid updating state, we set dt to -inf and x to 0 + because `softplus(-inf) = 0` and `exp(0) = 1` + """ + dt = mx.where(bool_mask[:, :, None, None], dt, float("-inf")) + x = mx.where(bool_mask[:, :, None, None], x, 0.0) + + # ssm + if is_update: + assert ssm_state is not None + out, ssm_state = ssd_update_state( + ssm_state, + x[:, 0], + dt[:, 0].reshape(bsize, -1), + A, + B[:, 0], + C[:, 0], + D=self.D, + z=z[:, 0], + dt_bias=self.dt_bias, + dt_softplus=True, + ) + else: + tmp = ssd_chunk_scan_combined( + x, + dt.reshape(bsize, length, -1), + A, + B, + C, + self.chunk_size, + D=self.D, + z=z, + dt_bias=self.dt_bias, + dt_softplus=True, + return_final_states=past_states is not None, + seq_idx=seq_idx, + ssm_state=ssm_state, + ) + if past_states is not None: + out, ssm_state = tmp + else: + assert isinstance(tmp, mx.array) + out = tmp + + y = self.out_proj(out.reshape(bsize, length, -1)) + + if past_states is not None: + assert ssm_state is not None + assert conv_state is not None + past_states.update_mamba(conv_state, ssm_state, self.layer_idx) + + return y, past_states + + def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array: max_len = max(q_len, kv_len) mask = mx.tril( @@ -950,209 +1150,6 @@ class Attention(nn.Module): return attn_output, attn_weights, past_states -class Mamba(nn.Module): - def __init__(self, config: ModelArgs, layer_idx: int) -> None: - super().__init__() - self.config = config - self.layer_idx 
= layer_idx - self.hidden_size = config.hidden_size - self.d_state = config.mamba_d_state - self.d_conv = config.mamba_d_conv - self.chunk_size = config.mamba_chunk_size - self.num_heads = config.mamba_num_heads - # TODO add mamba_hidden_size_per_head config (?) - self.hidden_size_per_head = config.hidden_size_per_head - - self.intermediate_size = self.num_heads * self.hidden_size_per_head - - self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False) - self.conv1d = nn.Conv1d( - in_channels=self.intermediate_size, - out_channels=self.intermediate_size, - bias=False, # TODO the original implementation uses bias - kernel_size=self.d_conv, - groups=self.intermediate_size, - padding=0, - ) - self.dt_dim = max(64, self.hidden_size // 16) - # Notes: - # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper), - # but it may degrade the ability of content-length extrapolation. - self.bcdt_proj = nn.Linear( - self.intermediate_size, - self.dt_dim + 2 * self.d_state, - bias=False, - ) - self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False) - - self.dt_bias = get_initial_dt_bias(self.num_heads) - self.A_log = get_initial_A(self.num_heads) - self.D = mx.ones(self.num_heads, dtype=mx.float32) - - # TODO norm weight before gating like Mamba2 - self.dt_norm_weight = mx.ones(self.dt_dim) - self.B_norm_weight = mx.ones(self.d_state) - self.C_norm_weight = mx.ones(self.d_state) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - - def _no_weight_decay_param_names(self) -> set[str]: - return set(["D", "dt_bias", "A_log"]) - - def __call__( - self, - hidden_states: mx.array, - attention_mask: Optional[mx.array] = None, - past_states: Optional[PlamoCache] = None, - ) -> tuple[mx.array, Optional[PlamoCache]]: - bsize, length, _ = hidden_states.shape - is_update = length == 1 and past_states is not None - - bool_mask: mx.array | None = None - seq_idx: mx.array | None = None - if attention_mask is not None: - if len(attention_mask.shape) == 2: - attention_mask = mx.broadcast_to( - attention_mask[None, None], - (bsize, 1, attention_mask.shape[0], attention_mask.shape[1]), - ) - assert len(attention_mask.shape) == 4 - - if past_states is None: - # TODO: support seq_idx with cache - bool_mask_4d = mx.array(attention_mask == 0, dtype=mx.bool_) # type: ignore - is_first_token = _is_first_token(bool_mask_4d)[:, 0, :] - seq_idx = mx.cumsum(is_first_token, axis=-1) - 1 - seq_idx = seq_idx.astype(mx.int32) - - # `generate` function creates attention mask that contains past tokens, - # but mamba does not use them - attention_mask = attention_mask[:, 0, -length:, -length:] - bool_mask = mx.array(mx.diagonal(attention_mask, axis1=-2, axis2=-1) == 0) - - conv_state: mx.array | None - ssm_state: mx.array | None - if past_states is None: - conv_state = None - ssm_state = None - elif past_states[self.layer_idx] is None: - conv_state = mx.zeros( - (bsize, self.intermediate_size, self.d_conv - 1), - dtype=hidden_states.dtype, - ) - ssm_state = mx.zeros( - (bsize, self.num_heads, self.hidden_size_per_head, self.d_state), - dtype=mx.float32, - ) - else: - c = past_states[self.layer_idx] - assert isinstance(c, PlamoMambaCache) - conv_state = c.conv_state - ssm_state = c.ssm_state - - zx = self.in_proj(hidden_states) - zx = zx.reshape(bsize, length, self.num_heads, -1).astype(mx.float32) - # z: (bsize, length, num_heads, hidden_size_per_head) - # x: (bsize, length, num_heads, hidden_size_per_head) - z, x = mx.split( - zx, - [ - 
self.hidden_size_per_head, - ], - axis=-1, - ) - - # conv - x = x.reshape(bsize, length, -1).transpose(0, 2, 1) # (bsize, intermediate_size, length) - if bool_mask is not None: - x = mx.where(bool_mask[:, None, :], x, 0.0) - if is_update: - assert conv_state is not None - x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x) - else: - x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx) - x = x.astype(mx.float32) - x = x.transpose(0, 2, 1) # (bsize, length, intermediate_size) - x = x.reshape(bsize, length, -1) - # x: (bsize, length, num_heads, hidden_size_per_head) - # B: (bsize, length, 1, d_state) - # C: (bsize, length, 1, d_state) - # dt: (bsize, length, dt_dim) - BCdt = self.bcdt_proj(x).astype(mx.float32) - x = x.reshape(bsize, length, self.num_heads, -1) - B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1) - B = B[:, :, None, :] - C = C[:, :, None, :] - - A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,) - dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :] - B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :] - C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :] - - # (bsize, length, num_heads, 1) - dt = self.dt_proj(dt)[..., None] - - # TODO it may not be required - B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3])) - C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3])) - - if bool_mask is not None: - """ - state will be updates by following: - ``` - dt = softplus(dt) - dA = exp(dt * A) - state_next = state * dA + dB * x - ``` - To avoid updating state, we set dt to -inf and x to 0 - because `softplus(-inf) = 0` and `exp(0) = 1` - """ - dt = mx.where(bool_mask[:, :, None, None], dt, float("-inf")) - x = mx.where(bool_mask[:, :, None, None], x, 0.0) - - # ssm - self.D = self.D.astype(mx.float32) - self.dt_bias = self.dt_bias.astype(mx.float32) - if is_update: - assert ssm_state is not None - out, ssm_state = ssd_update_state( - ssm_state, - x[:, 0], - dt[:, 0].reshape(bsize, -1), - A, - B[:, 0], - C[:, 0], - D=self.D, - z=z[:, 0], - dt_bias=self.dt_bias, - dt_softplus=True, - ) - else: - out, ssm_state = ssd_chunk_scan_combined( - x, - dt.reshape(bsize, length, -1), - A, - B, - C, - self.chunk_size, - D=self.D, - z=z, - dt_bias=self.dt_bias, - dt_softplus=True, - seq_idx=seq_idx, - ssm_state=ssm_state, - ) - - y = self.out_proj(out.reshape(bsize, length, -1)) - - if past_states is not None: - assert ssm_state is not None - assert conv_state is not None - past_states.update_mamba(conv_state, ssm_state, self.layer_idx) - - return y, past_states - - class MLP(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() @@ -1236,6 +1233,18 @@ class PlamoDecoderLayer(nn.Module): return outputs # type: ignore +def is_mamba(config: ModelArgs, i: int) -> bool: + if not config.mamba_enabled: + return False + assert config.mamba_step > 1 + assert i < config.num_hidden_layers + + if config.num_hidden_layers <= (config.mamba_step // 2): + # use attention in last layer + return i != config.num_hidden_layers - 1 + return (i % config.mamba_step) != (config.mamba_step // 2) + + class PlamoDecoder(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() @@ -1545,6 +1554,7 @@ class PlamoModel(PlamoPreTrainedModel): self.gradient_checkpointing, ) ) + assert isinstance(out, DecoderOutput) hidden_states = out.hidden_states 
all_hidden_states = out.all_hidden_states
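
Note on the weight layout this patch relies on (a minimal sketch, not part of the diff): mlx.nn.Conv1d stores its weight channel-last as (out_channels, kernel_size, in_channels // groups), whereas PyTorch-style checkpoints keep conv1d weights channel-first as (out_channels, in_channels // groups, kernel_size). Dropping the transpose in _causal_conv1d and indexing weight[:, :, 0] in _causal_conv1d_update therefore assume the checkpoint weights have already been converted to channel-last when loaded. The helper below is illustrative only; its name (to_channel_last_conv1d) and the place where the conversion would happen (e.g. a weight-sanitizing step) are assumptions and do not appear in this diff.

import mlx.core as mx

def to_channel_last_conv1d(weight: mx.array) -> mx.array:
    # PyTorch layout: (out_channels, in_channels // groups, kernel_size)
    # MLX layout:     (out_channels, kernel_size, in_channels // groups)
    assert weight.ndim == 3
    return weight.transpose(0, 2, 1)

# Example with the depthwise conv used by Mamba: intermediate_size == 8, d_conv == 4,
# groups == out_channels, so in_channels // groups == 1.
w_channel_first = mx.zeros((8, 1, 4))
w_channel_last = to_channel_last_conv1d(w_channel_first)
print(w_channel_last.shape)  # (8, 4, 1)
# With this layout, weight[:, :, 0] yields the (out_channels, kernel_size) depthwise
# filters that _causal_conv1d_update consumes after this patch.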