small optimization

This commit is contained in:
Goekdeniz-Guelmez 2025-01-22 00:15:02 +01:00
parent 12e9f34524
commit a4b716e65d

View File

@ -44,10 +44,6 @@ class ModelArgs(BaseModelArgs):
self.time_step_rank = math.ceil(self.hidden_size / 16)
def silu(x):
return x * mx.sigmoid(x)
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
@ -166,18 +162,18 @@ class Mamba2Block(nn.Module):
if cache is None:
cache = [None, None]
conv_state, _ = cache
conv_state, _ = cache
zxBCdt = self.in_proj(u)
z, xBC, dt = mx.split(
zxBCdt,
zxBCdt,
[self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
axis=-1
)
xBC, conv_state = self.conv1d(xBC, conv_state)
xBC = silu(xBC)
xBC =xBC * mx.sigmoid(xBC)
xBC = xBC[:, :seq_len, :]
x, B, C = mx.split(