Mirror of https://github.com/ml-explore/mlx-examples.git
fix

commit 7900a6c22c
parent 2edcc0355f
@@ -283,9 +283,11 @@ class Mamba(nn.Module):
         cache=None,
     ):
         bsize, length, _ = hidden_states.shape
         is_update = length == 1 and cache[0] is not None
-
-        if not is_update:
+        if cache is not None and cache[0] is not None:
+            conv_state = cache[0]
+            ssm_state = cache[1]
+        else:
             conv_state = mx.zeros(
                 (bsize, self.d_conv - 1, self.intermediate_size),
                 dtype=hidden_states.dtype,
@@ -294,9 +296,6 @@ class Mamba(nn.Module):
                 (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
                 dtype=mx.float32,
             )
-        else:
-            conv_state = cache[0]
-            ssm_state = cache[1]
 
         zx = self.in_proj(hidden_states)
         zx = zx.reshape(bsize, length, self.num_heads, -1)
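Note on the two hunks above: the state setup no longer keys off is_update. The conv/SSM states are taken from the cache whenever it holds one and zero-initialized otherwise, which also covers a plain forward pass with cache=None. The helper below is a minimal sketch of that pattern, not code from the repository; init_mamba_state is a hypothetical name, and the arguments mirror the fields used in the diff (d_conv, intermediate_size, num_heads, hidden_size_per_head as head_dim, d_state).

import mlx.core as mx


def init_mamba_state(cache, bsize, d_conv, intermediate_size,
                     num_heads, head_dim, d_state, dtype=mx.float16):
    # Reuse the cached states when they are available ...
    if cache is not None and cache[0] is not None:
        return cache[0], cache[1]
    # ... otherwise start from zeros, as in the else-branch of the diff.
    conv_state = mx.zeros((bsize, d_conv - 1, intermediate_size), dtype=dtype)
    ssm_state = mx.zeros((bsize, num_heads, head_dim, d_state), dtype=mx.float32)
    return conv_state, ssm_state


# Both calls work: with a (fake) cached state, and with no cache at all.
conv, ssm = init_mamba_state([mx.zeros((1, 3, 8)), mx.zeros((1, 2, 4, 16))],
                             bsize=1, d_conv=4, intermediate_size=8,
                             num_heads=2, head_dim=4, d_state=16)
conv, ssm = init_mamba_state(None, bsize=1, d_conv=4, intermediate_size=8,
                             num_heads=2, head_dim=4, d_state=16)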
@@ -337,8 +336,9 @@ class Mamba(nn.Module):
             ssm_state=ssm_state,
         )
 
-        cache[0] = conv_state
-        cache[1] = ssm_state
+        if cache is not None:
+            cache[0] = conv_state
+            cache[1] = ssm_state
         y = self.out_proj(out.reshape(bsize, length, -1))
 
         return y
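The write-back hunk applies the same guard: the updated states are stored only when a cache object was passed in, so a cache-less call no longer assigns into None. A small stand-in sketch (store_mamba_state and the plain list are hypothetical; the diff only shows that the real cache supports item assignment):

def store_mamba_state(cache, conv_state, ssm_state):
    # Persist the updated states only when a cache was actually provided.
    if cache is not None:
        cache[0] = conv_state
        cache[1] = ssm_state


layer_cache = [None, None]                       # stand-in for the per-layer cache
store_mamba_state(layer_cache, "conv", "ssm")    # layer_cache -> ["conv", "ssm"]
store_mamba_state(None, "conv", "ssm")           # no-op rather than a TypeError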
@@ -540,10 +540,10 @@ class PlamoModel(nn.Module):
         h = self.embed_tokens(inputs)
 
         if mask is None:
-            mask = create_attention_mask(h, [cache[1]])
+            mask = create_attention_mask(h, [cache[1]] if cache is not None else None)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * len(self.layers.layers)
 
         # decoder layers
         out = self.layers(
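At the model level the last hunk applies the same idea: the mask helper only receives the cache entry when a cache exists, and a missing cache is replaced by per-layer None placeholders, now counted via len(self.layers.layers) (presumably because self.layers is a decoder module that keeps the actual layer list in its own layers attribute). A minimal sketch of just the guards, using the hypothetical name prepare_cache_and_mask_input and without importing the repository's create_attention_mask:

from typing import Optional


def prepare_cache_and_mask_input(cache: Optional[list], num_layers: int):
    # Entry handed to the mask helper: the recurrent cache slot if present,
    # otherwise None, matching the changed create_attention_mask call.
    mask_cache = [cache[1]] if cache is not None else None
    # Per-layer placeholders so the decoder loop can index a cache safely.
    if cache is None:
        cache = [None] * num_layers
    return mask_cache, cache


mask_cache, cache = prepare_cache_and_mask_input(None, num_layers=4)
# mask_cache is None, cache == [None, None, None, None]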