commit 7900a6c22c
parent 2edcc0355f
Awni Hannun  2025-02-24 09:10:24 -08:00


@@ -283,9 +283,11 @@ class Mamba(nn.Module):
         cache=None,
     ):
         bsize, length, _ = hidden_states.shape
-        is_update = length == 1 and cache[0] is not None
 
-        if not is_update:
+        if cache is not None and cache[0] is not None:
+            conv_state = cache[0]
+            ssm_state = cache[1]
+        else:
             conv_state = mx.zeros(
                 (bsize, self.d_conv - 1, self.intermediate_size),
                 dtype=hidden_states.dtype,
@@ -294,9 +296,6 @@
                 (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
                 dtype=mx.float32,
             )
-        else:
-            conv_state = cache[0]
-            ssm_state = cache[1]
 
         zx = self.in_proj(hidden_states)
         zx = zx.reshape(bsize, length, self.num_heads, -1)
@@ -337,8 +336,9 @@
             ssm_state=ssm_state,
         )
 
-        cache[0] = conv_state
-        cache[1] = ssm_state
+        if cache is not None:
+            cache[0] = conv_state
+            cache[1] = ssm_state
 
         y = self.out_proj(out.reshape(bsize, length, -1))
         return y
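
The Mamba hunks above amount to one pattern: read the convolution and SSM state from the cache when a populated cache list is supplied, fall back to zero-initialized state otherwise, and write the updated state back only when a cache list actually exists. The following standalone Python sketch (not part of the commit) illustrates that pattern; the step helper, its default sizes, and the shapes are illustrative assumptions, with only mx.zeros / mx.ones taken from MLX.

import mlx.core as mx


def step(x, cache=None, d_conv=4, num_heads=2, head_dim=8, d_state=16):
    bsize, length, intermediate_size = x.shape

    # Reuse the cached state when a populated cache list is passed in.
    if cache is not None and cache[0] is not None:
        conv_state, ssm_state = cache[0], cache[1]
    else:
        conv_state = mx.zeros((bsize, d_conv - 1, intermediate_size), dtype=x.dtype)
        ssm_state = mx.zeros((bsize, num_heads, head_dim, d_state), dtype=mx.float32)

    # ... the real layer runs its causal convolution and SSM scan here,
    # producing updated conv_state / ssm_state and an output ...

    # Write the updated state back only if the caller supplied a cache list.
    if cache is not None:
        cache[0] = conv_state
        cache[1] = ssm_state
    return x


# Stateless call vs. incremental decoding with a persistent cache.
cache = [None, None]
y = step(mx.ones((1, 5, 32)))         # no cache: state is rebuilt every call
y = step(mx.ones((1, 1, 32)), cache)  # cache now holds conv_state and ssm_state
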
@@ -540,10 +540,10 @@ class PlamoModel(nn.Module):
         h = self.embed_tokens(inputs)
 
         if mask is None:
-            mask = create_attention_mask(h, [cache[1]])
+            mask = create_attention_mask(h, [cache[1]] if cache is not None else None)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * len(self.layers.layers)
 
         # decoder layers
         out = self.layers(
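
The PlamoModel hunk applies the same idea one level up: never index into cache before checking that it exists, and when no cache is given, size the per-layer list from the decoder wrapper's nested layers.layers list rather than the wrapper itself. A small self-contained sketch follows (not part of the commit); _Inner, _Model, and prepare_cache_and_mask_inputs are hypothetical stand-ins, not the real mlx-lm classes or create_attention_mask.

class _Inner:
    def __init__(self, n):
        # stand-in for the decoder module that holds the real layer list
        self.layers = [object() for _ in range(n)]


class _Model:
    def __init__(self, n):
        # note the extra nesting: the layer list lives at model.layers.layers
        self.layers = _Inner(n)


def prepare_cache_and_mask_inputs(model, cache=None):
    # Guarded mask input: cache may be None on a stateless forward pass,
    # so only index into it when it exists.
    mask_cache = [cache[1]] if cache is not None else None

    if cache is None:
        # The layer count lives one level down, on the wrapper's .layers list.
        cache = [None] * len(model.layers.layers)
    return mask_cache, cache


print(prepare_cache_and_mask_inputs(_Model(3)))
# -> (None, [None, None, None])
print(prepare_cache_and_mask_inputs(_Model(3), ["conv", "ssm"]))
# -> (['ssm'], ['conv', 'ssm'])
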