diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index f2414660..70ac70a3 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -138,10 +138,10 @@ class MambaBlock(nn.Module): if self.use_bcdt_rms: delta, B, C = map(self.mixer_norm, (delta, B, C)) delta = nn.softplus(self.dt_proj(delta)) - new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1) + new_state = mx.einsum('bs,bs,sd->bsd', delta, x, B) if state is not None: new_state += state * mx.exp(mx.expand_dims(delta, -1) * A) - y = (new_state @ mx.expand_dims(C, -1)).squeeze(2) + y = mx.einsum('bsd,sd->bs', new_state, C) y = y + D * x return y, new_state