mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 09:51:19 +08:00
removing sanitize
This commit is contained in:
parent
dd29e74b89
commit
2462a34194
@ -24,8 +24,6 @@ class ModelArgs(BaseModelArgs):
|
||||
use_conv_bias: bool
|
||||
initializer_range: float
|
||||
residual_in_fp32: bool
|
||||
rescale_prenorm_residual: bool
|
||||
rms_norm: bool
|
||||
chunk_size: int
|
||||
tie_word_embeddings: bool
|
||||
time_step_limit: Tuple[float, float]
|
||||
@ -162,7 +160,8 @@ class Mamba2Block(nn.Module):
|
||||
|
||||
if cache is None:
|
||||
cache = [None, None]
|
||||
conv_state, _ = cache
|
||||
else:
|
||||
conv_state, ssm_state = cache
|
||||
|
||||
zxBCdt = self.in_proj(u)
|
||||
|
||||
@ -173,7 +172,7 @@ class Mamba2Block(nn.Module):
|
||||
)
|
||||
|
||||
xBC, conv_state = self.conv1d(xBC, conv_state)
|
||||
xBC =xBC * mx.sigmoid(xBC)
|
||||
xBC = xBC * mx.sigmoid(xBC)
|
||||
xBC = xBC[:, :seq_len, :]
|
||||
|
||||
x, B, C = mx.split(
|
||||
@ -187,7 +186,7 @@ class Mamba2Block(nn.Module):
|
||||
C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
|
||||
|
||||
A = -mx.exp(self.A_log)
|
||||
y, next_state = ssd_forward_attn(
|
||||
y, next_ssm_state = ssd_forward_attn(
|
||||
x=x,
|
||||
dt=dt,
|
||||
A=A,
|
||||
@ -209,8 +208,7 @@ class Mamba2Block(nn.Module):
|
||||
y = self.out_proj(y)
|
||||
|
||||
cache[0] = conv_state
|
||||
cache[1] = next_state
|
||||
|
||||
cache[1] = next_ssm_state
|
||||
return y
|
||||
|
||||
|
||||
@ -267,12 +265,6 @@ class Model(nn.Module):
|
||||
logits = self.lm_head(hidden)
|
||||
|
||||
return logits
|
||||
|
||||
def sanitize(self, weights):
|
||||
for k, v in weights.items():
|
||||
if "conv1d.weight" in k and v.shape[-1] != 1:
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
return weights
|
||||
|
||||
def make_cache(self):
|
||||
return [MambaCache() for _ in range(len(self.layers))]
|
||||
|
Loading…
Reference in New Issue
Block a user