This commit is contained in:
Goekdeniz-Guelmez 2024-12-10 18:15:12 +01:00
parent 9f8a6a3509
commit b10afe3662

View File

@@ -62,6 +62,7 @@ class MambaRMSNormGated(nn.Module):
def silu(x):
    """Return ``x`` multiplied by ``mx.sigmoid(x)`` (the SiLU activation)."""
    gate = mx.sigmoid(x)
    return x * gate
def ssd(x, A, B, C, chunk_size):
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
@@ -175,7 +176,7 @@ class Mamba2Block(nn.Module):
# Calculate split indices and slice tensors
z = proj[..., :self.d_inner]
-x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.d_state)]
+x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.n_groups * self.d_state)]
dt = proj[..., -self.n_heads:]
# Process time steps