From 4e94e87f57062995236190102d2f4586869a496e Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 27 Dec 2024 15:41:54 +0100 Subject: [PATCH] nits --- llms/mlx_lm/models/mamba2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index e123be74..5b97f7d9 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -186,8 +186,7 @@ class Mamba2Block(nn.Module): # Update state - matches PyTorch implementation next_state = ( - next_state * mx.expand_dims(dAt, axis=(-1, -2)) + - dBx + next_state * mx.expand_dims(dAt, axis=(-1, -2)) + dBx ) # Compute output