mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 09:51:19 +08:00
removing unnessesairy lines and cleaning up
This commit is contained in:
parent
c13de475f6
commit
12e9f34524
@ -134,8 +134,6 @@ class Mamba2Block(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
# Same dimensions as before
|
||||
self.d_model = args.hidden_size
|
||||
self.d_state = args.state_size
|
||||
self.d_conv = args.conv_kernel
|
||||
@ -146,16 +144,13 @@ class Mamba2Block(nn.Module):
|
||||
self.d_head = self.d_inner // self.n_heads
|
||||
self.chunk_size = args.chunk_size
|
||||
|
||||
# Input projection
|
||||
d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
|
||||
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
|
||||
|
||||
# Parameters
|
||||
self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
|
||||
self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
|
||||
self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
|
||||
|
||||
# Convolution
|
||||
self.conv1d = DepthWiseConv1d(
|
||||
channels=self.d_inner + 2 * self.n_groups * self.d_state,
|
||||
kernel_size=self.d_conv,
|
||||
@ -163,46 +158,38 @@ class Mamba2Block(nn.Module):
|
||||
padding=self.d_conv-1
|
||||
)
|
||||
|
||||
# Output projections
|
||||
self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
|
||||
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
|
||||
|
||||
def __call__(self, u: mx.array, cache=None):
|
||||
batch_size, seq_len, _ = u.shape
|
||||
|
||||
# Get or initialize states from cache
|
||||
if cache is None:
|
||||
cache = [None, None] # [conv_state, ssm_state]
|
||||
conv_state, _ = cache # We ignore ssm_state as it's not used in the parallel version
|
||||
cache = [None, None]
|
||||
conv_state, _ = cache
|
||||
|
||||
# Project input
|
||||
zxBCdt = self.in_proj(u)
|
||||
|
||||
# Split projections
|
||||
z, xBC, dt = mx.split(
|
||||
zxBCdt,
|
||||
[self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
|
||||
axis=-1
|
||||
)
|
||||
|
||||
# Process convolution
|
||||
xBC, conv_state = self.conv1d(xBC, conv_state)
|
||||
xBC = silu(xBC)
|
||||
xBC = xBC[:, :seq_len, :]
|
||||
|
||||
# Split conv output
|
||||
x, B, C = mx.split(
|
||||
xBC,
|
||||
[self.d_inner, self.d_inner + self.d_state * self.n_groups],
|
||||
axis=-1
|
||||
)
|
||||
|
||||
# Reshape for SSM processing
|
||||
x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
|
||||
B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
|
||||
C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
|
||||
|
||||
# Process with parallel attention
|
||||
A = -mx.exp(self.A_log)
|
||||
y, next_state = ssd_forward_attn(
|
||||
x=x,
|
||||
@ -216,7 +203,6 @@ class Mamba2Block(nn.Module):
|
||||
dt_max=self.args.time_step_max
|
||||
)
|
||||
|
||||
# Apply normalization based on norm_before_gate setting
|
||||
if self.args.norm_before_gate:
|
||||
y = self.norm(y)
|
||||
y = y * nn.silu(z)
|
||||
@ -224,10 +210,8 @@ class Mamba2Block(nn.Module):
|
||||
y = y * nn.silu(z)
|
||||
y = self.norm(y)
|
||||
|
||||
# Final projection
|
||||
y = self.out_proj(y)
|
||||
|
||||
# Update cache
|
||||
cache[0] = conv_state
|
||||
cache[1] = next_state
|
||||
|
||||
@ -242,8 +226,8 @@ class ResidualBlock(nn.Module):
|
||||
self.norm = nn.RMSNorm(args.hidden_size)
|
||||
|
||||
def __call__(self, x: mx.array, cache):
|
||||
# if self.residual_in_fp32:
|
||||
# x = x.astype(mx.float32)
|
||||
if self.residual_in_fp32:
|
||||
x = x.astype(mx.float32)
|
||||
normed = self.norm(x)
|
||||
output = self.mixer(normed, cache)
|
||||
return output + x
|
||||
|
Loading…
Reference in New Issue
Block a user