mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-28 20:25:22 +08:00
small optimization
This commit is contained in:
parent
12e9f34524
commit
a4b716e65d
@ -44,10 +44,6 @@ class ModelArgs(BaseModelArgs):
|
|||||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||||
|
|
||||||
|
|
||||||
def silu(x):
|
|
||||||
return x * mx.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
class DepthWiseConv1d(nn.Module):
|
class DepthWiseConv1d(nn.Module):
|
||||||
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -177,7 +173,7 @@ class Mamba2Block(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
xBC, conv_state = self.conv1d(xBC, conv_state)
|
xBC, conv_state = self.conv1d(xBC, conv_state)
|
||||||
xBC = silu(xBC)
|
xBC =xBC * mx.sigmoid(xBC)
|
||||||
xBC = xBC[:, :seq_len, :]
|
xBC = xBC[:, :seq_len, :]
|
||||||
|
|
||||||
x, B, C = mx.split(
|
x, B, C = mx.split(
|
||||||
|
Loading…
Reference in New Issue
Block a user