From a4b716e65db590d8b8346208b3cf9977c3937929 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 22 Jan 2025 00:15:02 +0100 Subject: [PATCH] mamba2: inline silu as x * mx.sigmoid(x) and remove the helper --- llms/mlx_lm/models/mamba2.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 747db9e2..888fb4fa 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -44,10 +44,6 @@ class ModelArgs(BaseModelArgs): self.time_step_rank = math.ceil(self.hidden_size / 16) -def silu(x): - return x * mx.sigmoid(x) - - class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() @@ -166,18 +162,18 @@ class Mamba2Block(nn.Module): if cache is None: cache = [None, None] - conv_state, _ = cache + conv_state, _ = cache zxBCdt = self.in_proj(u) z, xBC, dt = mx.split( - zxBCdt, + zxBCdt, [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state], axis=-1 ) xBC, conv_state = self.conv1d(xBC, conv_state) - xBC = silu(xBC) + xBC =xBC * mx.sigmoid(xBC) xBC = xBC[:, :seq_len, :] x, B, C = mx.split(