removing sanitize

This commit is contained in:
Goekdeniz-Guelmez 2025-01-22 22:30:15 +01:00
parent dd29e74b89
commit 2462a34194

View File

@ -24,8 +24,6 @@ class ModelArgs(BaseModelArgs):
use_conv_bias: bool use_conv_bias: bool
initializer_range: float initializer_range: float
residual_in_fp32: bool residual_in_fp32: bool
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int chunk_size: int
tie_word_embeddings: bool tie_word_embeddings: bool
time_step_limit: Tuple[float, float] time_step_limit: Tuple[float, float]
@ -162,7 +160,8 @@ class Mamba2Block(nn.Module):
if cache is None: if cache is None:
cache = [None, None] cache = [None, None]
conv_state, _ = cache else:
conv_state, ssm_state = cache
zxBCdt = self.in_proj(u) zxBCdt = self.in_proj(u)
@ -173,7 +172,7 @@ class Mamba2Block(nn.Module):
) )
xBC, conv_state = self.conv1d(xBC, conv_state) xBC, conv_state = self.conv1d(xBC, conv_state)
xBC =xBC * mx.sigmoid(xBC) xBC = xBC * mx.sigmoid(xBC)
xBC = xBC[:, :seq_len, :] xBC = xBC[:, :seq_len, :]
x, B, C = mx.split( x, B, C = mx.split(
@ -187,7 +186,7 @@ class Mamba2Block(nn.Module):
C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1)) C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
A = -mx.exp(self.A_log) A = -mx.exp(self.A_log)
y, next_state = ssd_forward_attn( y, next_ssm_state = ssd_forward_attn(
x=x, x=x,
dt=dt, dt=dt,
A=A, A=A,
@ -209,8 +208,7 @@ class Mamba2Block(nn.Module):
y = self.out_proj(y) y = self.out_proj(y)
cache[0] = conv_state cache[0] = conv_state
cache[1] = next_state cache[1] = next_ssm_state
return y return y
@ -268,12 +266,6 @@ class Model(nn.Module):
return logits return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self): def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))] return [MambaCache() for _ in range(len(self.layers))]