mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-29 12:51:12 +08:00
small optimization
This commit is contained in:
parent
12e9f34524
commit
a4b716e65d
@@ -44,10 +44,6 @@ class ModelArgs(BaseModelArgs):
|
|||||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||||
|
|
||||||
|
|
||||||
def silu(x):
|
|
||||||
return x * mx.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
class DepthWiseConv1d(nn.Module):
|
class DepthWiseConv1d(nn.Module):
|
||||||
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -166,18 +162,18 @@ class Mamba2Block(nn.Module):
|
|||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None, None]
|
cache = [None, None]
|
||||||
conv_state, _ = cache
|
conv_state, _ = cache
|
||||||
|
|
||||||
zxBCdt = self.in_proj(u)
|
zxBCdt = self.in_proj(u)
|
||||||
|
|
||||||
z, xBC, dt = mx.split(
|
z, xBC, dt = mx.split(
|
||||||
zxBCdt,
|
zxBCdt,
|
||||||
[self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
|
[self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
|
||||||
axis=-1
|
axis=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
xBC, conv_state = self.conv1d(xBC, conv_state)
|
xBC, conv_state = self.conv1d(xBC, conv_state)
|
||||||
xBC = silu(xBC)
|
xBC =xBC * mx.sigmoid(xBC)
|
||||||
xBC = xBC[:, :seq_len, :]
|
xBC = xBC[:, :seq_len, :]
|
||||||
|
|
||||||
x, B, C = mx.split(
|
x, B, C = mx.split(
|
||||||
|
Loading…
Reference in New Issue
Block a user