From 1d851069ea2cf463afbf8c0aeb25c8cb60f404d5 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Sun, 10 Nov 2024 17:21:18 +0100 Subject: [PATCH] nits --- llms/mlx_lm/models/mamba2.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 8ea641f4..2d8f4a09 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -62,7 +62,6 @@ def silu(x): return x * mx.sigmoid(x) def ssd(x, A, B, C, chunk_size): - # Replace einsum operations with explicit reshape and matrix multiply batch, seqlen, nheads, dim = x.shape B = mx.expand_dims(B, axis=2) C = mx.expand_dims(C, axis=2) @@ -74,7 +73,6 @@ def ssd(x, A, B, C, chunk_size): chunk = slice(i, min(i + chunk_size, seqlen)) dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) - # Replace einsum with explicit operations x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] B_chunk = B[:, chunk] # [batch, chunk_size, state_size] @@ -82,7 +80,6 @@ def ssd(x, A, B, C, chunk_size): state = state * mx.expand_dims(dA, axis=-1) + dBx - # Replace einsum with explicit operations C_chunk = C[:, chunk] # [batch, chunk_size, state_size] y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]