mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-28 03:41:17 +08:00
small optimization
This commit is contained in:
parent
12e9f34524
commit
a4b716e65d
@ -44,10 +44,6 @@ class ModelArgs(BaseModelArgs):
|
||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||
|
||||
|
||||
def silu(x):
|
||||
return x * mx.sigmoid(x)
|
||||
|
||||
|
||||
class DepthWiseConv1d(nn.Module):
|
||||
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
||||
super().__init__()
|
||||
@ -166,18 +162,18 @@ class Mamba2Block(nn.Module):
|
||||
|
||||
if cache is None:
|
||||
cache = [None, None]
|
||||
conv_state, _ = cache
|
||||
conv_state, _ = cache
|
||||
|
||||
zxBCdt = self.in_proj(u)
|
||||
|
||||
z, xBC, dt = mx.split(
|
||||
zxBCdt,
|
||||
zxBCdt,
|
||||
[self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
|
||||
axis=-1
|
||||
)
|
||||
|
||||
xBC, conv_state = self.conv1d(xBC, conv_state)
|
||||
xBC = silu(xBC)
|
||||
xBC =xBC * mx.sigmoid(xBC)
|
||||
xBC = xBC[:, :seq_len, :]
|
||||
|
||||
x, B, C = mx.split(
|
||||
|
Loading…
Reference in New Issue
Block a user