diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index e123be74..5b97f7d9 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -186,8 +186,7 @@ class Mamba2Block(nn.Module): # Update state - matches PyTorch implementation next_state = ( - next_state * mx.expand_dims(dAt, axis=(-1, -2)) + - dBx + next_state * mx.expand_dims(dAt, axis=(-1, -2)) + dBx ) # Compute output