diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index a017d8de..7182fb69 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -157,11 +157,12 @@ class Mamba2Mixer(nn.Module): B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1) print(f"ssm_step split shapes - B: {B.shape}, C: {C.shape}") - B = B.reshape(-1, self.n_groups, self.state_size) - C = C.reshape(-1, self.n_groups, self.state_size) + batch_size = B.shape[0] + B = B.reshape(batch_size, self.n_groups, self.state_size) + C = C.reshape(batch_size, -1, self.state_size) print(f"After reshape - B: {B.shape}, C: {C.shape}") - delta = delta.reshape(-1, self.num_heads, 1) + delta = delta.reshape(batch_size, self.num_heads, 1) A = A.reshape(1, self.num_heads, 1) if state is None: @@ -170,7 +171,7 @@ class Mamba2Mixer(nn.Module): new_state = delta * (B + state * mx.exp(delta * A)) print(f"Before final computation - new_state: {new_state.shape}, C: {C.shape}") - y = mx.sum(new_state * C, axis=-1) + y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2)) y = y + D * x[:, :self.num_heads] print(f"ssm_step output shape - y: {y.shape}") return y, new_state