diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
index 5935024b..1d8215dd 100644
--- a/llms/mlx_lm/models/plamo2.py
+++ b/llms/mlx_lm/models/plamo2.py
@@ -283,9 +283,11 @@ class Mamba(nn.Module):
         cache=None,
     ):
         bsize, length, _ = hidden_states.shape
 
-        is_update = length == 1 and cache[0] is not None
-        if not is_update:
+        if cache is not None and cache[0] is not None:
+            conv_state = cache[0]
+            ssm_state = cache[1]
+        else:
             conv_state = mx.zeros(
                 (bsize, self.d_conv - 1, self.intermediate_size),
                 dtype=hidden_states.dtype,
@@ -294,9 +296,6 @@ class Mamba(nn.Module):
                 (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
                 dtype=mx.float32,
             )
-        else:
-            conv_state = cache[0]
-            ssm_state = cache[1]
 
         zx = self.in_proj(hidden_states)
         zx = zx.reshape(bsize, length, self.num_heads, -1)
@@ -337,8 +336,9 @@ class Mamba(nn.Module):
             ssm_state=ssm_state,
         )
 
-        cache[0] = conv_state
-        cache[1] = ssm_state
+        if cache is not None:
+            cache[0] = conv_state
+            cache[1] = ssm_state
 
         y = self.out_proj(out.reshape(bsize, length, -1))
         return y
@@ -540,10 +540,10 @@ class PlamoModel(nn.Module):
         h = self.embed_tokens(inputs)
 
         if mask is None:
-            mask = create_attention_mask(h, [cache[1]])
+            mask = create_attention_mask(h, [cache[1]] if cache is not None else None)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * len(self.layers.layers)
 
         # decoder layers
         out = self.layers(
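
For reference, a minimal sketch of the guard pattern the patch applies: read recurrent state from the cache only when the cache object exists and has been populated, otherwise start from zeros, and write back only when a cache was supplied. The `step` function, shapes, and toy recurrence below are invented for illustration; they are not plamo2's actual code.

import mlx.core as mx

def step(x, cache=None):
    bsize, length, dim = x.shape
    # Reuse state only when a cache object exists and has been filled in.
    if cache is not None and cache[0] is not None:
        state = cache[0]
    else:
        state = mx.zeros((bsize, dim), dtype=x.dtype)
    state = state + x.sum(axis=1)  # stand-in for the real conv/SSM update
    # Persist state only when the caller actually supplied a cache.
    if cache is not None:
        cache[0] = state
    return state

x = mx.ones((1, 4, 8))
step(x)          # stateless call: no cache, no crash
cache = [None]
step(x, cache)   # stateful call: cache[0] now holds the running state

Both call styles are what the PlamoModel changes rely on: mask creation and layer setup no longer assume a cache is present.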