From cd036ccfb535e349c1e176e9bb497b7887a95ce7 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 16 Oct 2024 21:13:36 +0200 Subject: [PATCH] fix generation works too (almost) --- llms/mlx_lm/models/mamba2.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index a017d8de..7182fb69 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -157,11 +157,12 @@ class Mamba2Mixer(nn.Module): B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1) print(f"ssm_step split shapes - B: {B.shape}, C: {C.shape}") - B = B.reshape(-1, self.n_groups, self.state_size) - C = C.reshape(-1, self.n_groups, self.state_size) + batch_size = B.shape[0] + B = B.reshape(batch_size, self.n_groups, self.state_size) + C = C.reshape(batch_size, -1, self.state_size) print(f"After reshape - B: {B.shape}, C: {C.shape}") - delta = delta.reshape(-1, self.num_heads, 1) + delta = delta.reshape(batch_size, self.num_heads, 1) A = A.reshape(1, self.num_heads, 1) if state is None: @@ -170,7 +171,7 @@ class Mamba2Mixer(nn.Module): new_state = delta * (B + state * mx.exp(delta * A)) print(f"Before final computation - new_state: {new_state.shape}, C: {C.shape}") - y = mx.sum(new_state * C, axis=-1) + y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2)) y = y + D * x[:, :self.num_heads] print(f"ssm_step output shape - y: {y.shape}") return y, new_state