small optimization

This commit is contained in:
Goekdeniz-Guelmez 2025-01-22 00:15:02 +01:00
parent 12e9f34524
commit a4b716e65d

View File

@ -44,10 +44,6 @@ class ModelArgs(BaseModelArgs):
self.time_step_rank = math.ceil(self.hidden_size / 16) self.time_step_rank = math.ceil(self.hidden_size / 16)
def silu(x):
return x * mx.sigmoid(x)
class DepthWiseConv1d(nn.Module): class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0): def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__() super().__init__()
@ -166,18 +162,18 @@ class Mamba2Block(nn.Module):
if cache is None: if cache is None:
cache = [None, None] cache = [None, None]
conv_state, _ = cache conv_state, _ = cache
zxBCdt = self.in_proj(u) zxBCdt = self.in_proj(u)
z, xBC, dt = mx.split( z, xBC, dt = mx.split(
zxBCdt, zxBCdt,
[self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state], [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
axis=-1 axis=-1
) )
xBC, conv_state = self.conv1d(xBC, conv_state) xBC, conv_state = self.conv1d(xBC, conv_state)
xBC = silu(xBC) xBC =xBC * mx.sigmoid(xBC)
xBC = xBC[:, :seq_len, :] xBC = xBC[:, :seq_len, :]
x, B, C = mx.split( x, B, C = mx.split(