diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index 37fa2092..93cc616e 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -127,60 +127,57 @@ class MambaBlock(nn.Module):
         D = self.D
         deltaBC = self.x_proj(x)
         delta, B, C = map(
-            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
             mx.split(
                 deltaBC,
-                [
-                    self.time_step_rank,
-                    self.time_step_rank + self.ssm_state_size
-                ],
-                axis=-1
-            )
+                [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
+                axis=-1,
+            ),
         )
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
         delta = nn.softplus(self.dt_proj(delta))
-        new_state = mx.einsum('bs,bs,sd->bsd', delta, x, B)
+        new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
         if state is not None:
             new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
-        y = mx.einsum('bsd,sd->bs', new_state, C)
+        y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
         y = y + D * x
         return y, new_state
 
     def _process_sequence(self, x, conv_cache, state_cache):
         B, T, D = x.shape
-        xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1)
-        x_t, z_t = xz.split(indices_or_sections=2, axis=-1)
-
-        conv_out, new_conv_cache = self.conv1d(x_t, conv_cache)
-        x_t = nn.silu(conv_out)
-
+        xz = self.in_proj(x)
+        x, z = xz.split(indices_or_sections=2, axis=-1)
+
+        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
+        x = nn.silu(conv_out)
+
         A = -mx.exp(self.A_log)
-
+
         outputs = []
         current_state = state_cache
+        y = []
         for t in range(T):
-            y_t, current_state = self.ssm_step(x_t[:, t], A, current_state)
-            z_curr = nn.silu(z_t[:, t])
-            output_t = self.out_proj(y_t * z_curr)
-            outputs.append(output_t)
-
-        return mx.stack(outputs, axis=1), (new_conv_cache, current_state)
+            y_t, current_state = self.ssm_step(x[:, t], A, current_state)
+            y.append(y_t)
+        y = mx.stack(y, axis=1)
+        z = self.out_proj(nn.silu(z) * y)
+        return z, (new_conv_cache, current_state)
 
     def __call__(self, x, cache):
-        if cache is None or isinstance(cache, list):
-            conv_cache, state_cache = cache if cache is not None else (None, None)
+        if cache is None:
+            conv_cache, state_cache = None, None
         else:
-            conv_cache, state_cache = cache.state
-
+            conv_cache, state_cache = cache[0], cache[1]
+
         output, (new_conv_cache, new_state_cache) = self._process_sequence(
             x, conv_cache, state_cache
         )
-
+
         if isinstance(cache, MambaCache):
             cache[0] = new_conv_cache
             cache[1] = new_state_cache
-
+
         return output
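
For reviewers: a minimal standalone sketch (not part of the diff) checking that the broadcasted forms introduced in `ssm_step` match einsum references. It assumes `delta` and the convolved input have shape `(batch, intermediate)` and that `B`, `C` have shape `(batch, ssm_state_size)`; the subscripts below are chosen for those assumed shapes and are not the ones from the removed lines. It requires a recent MLX build that provides `mx.einsum`.

```python
import mlx.core as mx

batch, intermediate, state_size = 2, 8, 4
delta = mx.random.normal((batch, intermediate))
x = mx.random.normal((batch, intermediate))
B = mx.random.normal((batch, state_size))
C = mx.random.normal((batch, state_size))

# Broadcasted state update from the diff: (b, i, 1) * (b, 1, s) -> (b, i, s)
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
# Einsum reference with subscripts matching the assumed shapes
ref_state = mx.einsum("bi,bi,bs->bis", delta, x, B)
assert mx.allclose(new_state, ref_state)

# Output projection from the diff: (b, i, s) @ (b, s, 1) -> (b, i)
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
ref_y = mx.einsum("bis,bs->bi", new_state, C)
assert mx.allclose(y, ref_y)
```

The batched matmul plus `expand_dims` avoids relying on einsum support while computing the same contraction per step.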