mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 11:21:32 +08:00
nits
This commit is contained in:
parent
9f8a6a3509
commit
b10afe3662
@ -57,11 +57,12 @@ class MambaRMSNormGated(nn.Module):
|
|||||||
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
|
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
|
||||||
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
|
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
|
||||||
return self.weight * hidden_states
|
return self.weight * hidden_states
|
||||||
|
|
||||||
|
|
||||||
def silu(x):
|
def silu(x):
|
||||||
return x * mx.sigmoid(x)
|
return x * mx.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
def ssd(x, A, B, C, chunk_size):
|
def ssd(x, A, B, C, chunk_size):
|
||||||
batch, seqlen, nheads, dim = x.shape
|
batch, seqlen, nheads, dim = x.shape
|
||||||
B = mx.expand_dims(B, axis=2)
|
B = mx.expand_dims(B, axis=2)
|
||||||
@ -87,7 +88,7 @@ def ssd(x, A, B, C, chunk_size):
|
|||||||
outputs.append(y)
|
outputs.append(y)
|
||||||
|
|
||||||
return mx.concatenate(outputs, axis=1), state
|
return mx.concatenate(outputs, axis=1), state
|
||||||
|
|
||||||
|
|
||||||
class DepthWiseConv1d(nn.Module):
|
class DepthWiseConv1d(nn.Module):
|
||||||
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
||||||
@ -175,7 +176,7 @@ class Mamba2Block(nn.Module):
|
|||||||
|
|
||||||
# Calculate split indices and slice tensors
|
# Calculate split indices and slice tensors
|
||||||
z = proj[..., :self.d_inner]
|
z = proj[..., :self.d_inner]
|
||||||
x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.d_state)]
|
x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.n_groups * self.d_state)]
|
||||||
dt = proj[..., -self.n_heads:]
|
dt = proj[..., -self.n_heads:]
|
||||||
|
|
||||||
# Process time steps
|
# Process time steps
|
||||||
|
Loading…
Reference in New Issue
Block a user