mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 15:08:37 +08:00
fix
This commit is contained in:
parent
2edcc0355f
commit
7900a6c22c
@ -283,9 +283,11 @@ class Mamba(nn.Module):
|
|||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
bsize, length, _ = hidden_states.shape
|
bsize, length, _ = hidden_states.shape
|
||||||
is_update = length == 1 and cache[0] is not None
|
|
||||||
|
|
||||||
if not is_update:
|
if cache is not None and cache[0] is not None:
|
||||||
|
conv_state = cache[0]
|
||||||
|
ssm_state = cache[1]
|
||||||
|
else:
|
||||||
conv_state = mx.zeros(
|
conv_state = mx.zeros(
|
||||||
(bsize, self.d_conv - 1, self.intermediate_size),
|
(bsize, self.d_conv - 1, self.intermediate_size),
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
@ -294,9 +296,6 @@ class Mamba(nn.Module):
|
|||||||
(bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
|
(bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
|
||||||
dtype=mx.float32,
|
dtype=mx.float32,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
conv_state = cache[0]
|
|
||||||
ssm_state = cache[1]
|
|
||||||
|
|
||||||
zx = self.in_proj(hidden_states)
|
zx = self.in_proj(hidden_states)
|
||||||
zx = zx.reshape(bsize, length, self.num_heads, -1)
|
zx = zx.reshape(bsize, length, self.num_heads, -1)
|
||||||
@ -337,8 +336,9 @@ class Mamba(nn.Module):
|
|||||||
ssm_state=ssm_state,
|
ssm_state=ssm_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
cache[0] = conv_state
|
if cache is not None:
|
||||||
cache[1] = ssm_state
|
cache[0] = conv_state
|
||||||
|
cache[1] = ssm_state
|
||||||
y = self.out_proj(out.reshape(bsize, length, -1))
|
y = self.out_proj(out.reshape(bsize, length, -1))
|
||||||
|
|
||||||
return y
|
return y
|
||||||
@ -540,10 +540,10 @@ class PlamoModel(nn.Module):
|
|||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = create_attention_mask(h, [cache[1]])
|
mask = create_attention_mask(h, [cache[1]] if cache is not None else None)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers.layers)
|
||||||
|
|
||||||
# decoder layers
|
# decoder layers
|
||||||
out = self.layers(
|
out = self.layers(
|
||||||
|
Loading…
Reference in New Issue
Block a user