diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 70ac70a3..b7eff756 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -127,14 +127,10 @@ class MambaBlock(nn.Module): A = -mx.exp(self.A_log) D = self.D deltaBC = self.x_proj(x) - delta, B, C = mx.split( - deltaBC, - indices_or_sections=[ - self.time_step_rank, - self.time_step_rank + self.ssm_state_size, - ], - axis=-1, - ) + delta, B, C = map(self.mixer_norm if self.use_bcdt_rms else lambda x: x, + mx.split(deltaBC, [self.time_step_rank, + self.time_step_rank + self.ssm_state_size], + axis=-1)) if self.use_bcdt_rms: delta, B, C = map(self.mixer_norm, (delta, B, C)) delta = nn.softplus(self.dt_proj(delta))