mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
removing sanitize
This commit is contained in:
parent
dd29e74b89
commit
2462a34194
@ -24,8 +24,6 @@ class ModelArgs(BaseModelArgs):
|
|||||||
use_conv_bias: bool
|
use_conv_bias: bool
|
||||||
initializer_range: float
|
initializer_range: float
|
||||||
residual_in_fp32: bool
|
residual_in_fp32: bool
|
||||||
rescale_prenorm_residual: bool
|
|
||||||
rms_norm: bool
|
|
||||||
chunk_size: int
|
chunk_size: int
|
||||||
tie_word_embeddings: bool
|
tie_word_embeddings: bool
|
||||||
time_step_limit: Tuple[float, float]
|
time_step_limit: Tuple[float, float]
|
||||||
@ -162,7 +160,8 @@ class Mamba2Block(nn.Module):
|
|||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None, None]
|
cache = [None, None]
|
||||||
conv_state, _ = cache
|
else:
|
||||||
|
conv_state, ssm_state = cache
|
||||||
|
|
||||||
zxBCdt = self.in_proj(u)
|
zxBCdt = self.in_proj(u)
|
||||||
|
|
||||||
@ -187,7 +186,7 @@ class Mamba2Block(nn.Module):
|
|||||||
C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
|
C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
|
||||||
|
|
||||||
A = -mx.exp(self.A_log)
|
A = -mx.exp(self.A_log)
|
||||||
y, next_state = ssd_forward_attn(
|
y, next_ssm_state = ssd_forward_attn(
|
||||||
x=x,
|
x=x,
|
||||||
dt=dt,
|
dt=dt,
|
||||||
A=A,
|
A=A,
|
||||||
@ -209,8 +208,7 @@ class Mamba2Block(nn.Module):
|
|||||||
y = self.out_proj(y)
|
y = self.out_proj(y)
|
||||||
|
|
||||||
cache[0] = conv_state
|
cache[0] = conv_state
|
||||||
cache[1] = next_state
|
cache[1] = next_ssm_state
|
||||||
|
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
@ -268,12 +266,6 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def sanitize(self, weights):
|
|
||||||
for k, v in weights.items():
|
|
||||||
if "conv1d.weight" in k and v.shape[-1] != 1:
|
|
||||||
weights[k] = v.moveaxis(2, 1)
|
|
||||||
return weights
|
|
||||||
|
|
||||||
def make_cache(self):
|
def make_cache(self):
|
||||||
return [MambaCache() for _ in range(len(self.layers))]
|
return [MambaCache() for _ in range(len(self.layers))]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user