From 9494a275ac1c7dc8f1eda9c1db23e98976da88a4 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Mon, 20 Jan 2025 18:39:22 +0100 Subject: [PATCH] Fused Operations in delta, B, C = ... :. Before: 57.822 tokens-per-sec, after: 83.890 tokens-per-sec --- llms/mlx_lm/models/mamba.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 70ac70a3..b7eff756 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -127,14 +127,10 @@ class MambaBlock(nn.Module): A = -mx.exp(self.A_log) D = self.D deltaBC = self.x_proj(x) - delta, B, C = mx.split( - deltaBC, - indices_or_sections=[ - self.time_step_rank, - self.time_step_rank + self.ssm_state_size, - ], - axis=-1, - ) + delta, B, C = map(self.mixer_norm if self.use_bcdt_rms else lambda x: x, + mx.split(deltaBC, [self.time_step_rank, + self.time_step_rank + self.ssm_state_size], + axis=-1)) if self.use_bcdt_rms: delta, B, C = map(self.mixer_norm, (delta, B, C)) delta = nn.softplus(self.dt_proj(delta))