From e43ac7c90e349c340cfbc9e2b2bc07be5527c5be Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 20 Jan 2025 18:37:58 +0100
Subject: [PATCH] added mx.einsum() operations: before: 41.293 tokens-per-sec, after: 57.822 tokens-per-sec

---
 llms/mlx_lm/models/mamba.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index f2414660..70ac70a3 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -138,10 +138,10 @@ class MambaBlock(nn.Module):
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
         delta = nn.softplus(self.dt_proj(delta))
-        new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
+        new_state = mx.einsum('bs,bs,sd->bsd', delta, x, B)
         if state is not None:
             new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
-        y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
+        y = mx.einsum('bsd,sd->bs', new_state, C)
         y = y + D * x
         return y, new_state