diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index b7eff756..5c09c999 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -123,14 +123,20 @@ class MambaBlock(nn.Module):
             self.intermediate_size, self.hidden_size, bias=args.use_bias
         )
 
-    def ssm_step(self, x, state=None):
-        A = -mx.exp(self.A_log)
+    def ssm_step(self, x, A, state=None):
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = map(self.mixer_norm if self.use_bcdt_rms else lambda x: x,
-                          mx.split(deltaBC, [self.time_step_rank,
-                                             self.time_step_rank + self.ssm_state_size],
-                                   axis=-1))
+        delta, B, C = map(
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            mx.split(
+                deltaBC,
+                [
+                    self.time_step_rank,
+                    self.time_step_rank + self.ssm_state_size
+                ],
+                axis=-1
+            )
+        )
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
         delta = nn.softplus(self.dt_proj(delta))
@@ -143,6 +149,9 @@ class MambaBlock(nn.Module):
 
     def __call__(self, x, cache):
         B, T, D = x.shape
+
+        A = -mx.exp(self.A_log)
+
         if cache is None:
             cache = [None, None]
 
@@ -154,7 +163,7 @@ class MambaBlock(nn.Module):
             conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
             x_t = conv_out.squeeze(1)
             x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
+            y_t, cache[1] = self.ssm_step(x_t, A, cache[1])
             z_t = nn.silu(z_t)
            output_t = y_t * z_t
             output_t = self.out_proj(output_t)
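
Note (not part of the diff above): the change hoists A = -mx.exp(self.A_log) out of ssm_step and into __call__, so the value is computed once per forward pass instead of once per timestep, then passed to ssm_step as an argument. Below is a minimal standalone sketch of that hoisting pattern, assuming mlx is installed; run_sequence, A_log, xs, and step are illustrative stand-ins, not the repository's code.

    import mlx.core as mx

    def run_sequence(A_log, xs, step):
        # Hoisted: this value is constant across timesteps, so compute it once per call.
        A = -mx.exp(A_log)
        state = None
        ys = []
        for x_t in xs:
            # Each per-step call reuses the precomputed A instead of recomputing it.
            y_t, state = step(x_t, A, state)
            ys.append(y_t)
        return ys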