removing unnessesairy lines and cleaning up

This commit is contained in:
Goekdeniz-Guelmez 2025-01-21 23:06:40 +01:00
parent c13de475f6
commit 12e9f34524

View File

@ -134,8 +134,6 @@ class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
self.args = args self.args = args
# Same dimensions as before
self.d_model = args.hidden_size self.d_model = args.hidden_size
self.d_state = args.state_size self.d_state = args.state_size
self.d_conv = args.conv_kernel self.d_conv = args.conv_kernel
@ -146,16 +144,13 @@ class Mamba2Block(nn.Module):
self.d_head = self.d_inner // self.n_heads self.d_head = self.d_inner // self.n_heads
self.chunk_size = args.chunk_size self.chunk_size = args.chunk_size
# Input projection
d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias) self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
# Parameters
self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
self.D = mx.random.normal((self.n_heads,)) * args.initializer_range self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
# Convolution
self.conv1d = DepthWiseConv1d( self.conv1d = DepthWiseConv1d(
channels=self.d_inner + 2 * self.n_groups * self.d_state, channels=self.d_inner + 2 * self.n_groups * self.d_state,
kernel_size=self.d_conv, kernel_size=self.d_conv,
@ -163,46 +158,38 @@ class Mamba2Block(nn.Module):
padding=self.d_conv-1 padding=self.d_conv-1
) )
# Output projections
self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon) self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias) self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
def __call__(self, u: mx.array, cache=None): def __call__(self, u: mx.array, cache=None):
batch_size, seq_len, _ = u.shape batch_size, seq_len, _ = u.shape
# Get or initialize states from cache
if cache is None: if cache is None:
cache = [None, None] # [conv_state, ssm_state] cache = [None, None]
conv_state, _ = cache # We ignore ssm_state as it's not used in the parallel version conv_state, _ = cache
# Project input
zxBCdt = self.in_proj(u) zxBCdt = self.in_proj(u)
# Split projections
z, xBC, dt = mx.split( z, xBC, dt = mx.split(
zxBCdt, zxBCdt,
[self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state], [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
axis=-1 axis=-1
) )
# Process convolution
xBC, conv_state = self.conv1d(xBC, conv_state) xBC, conv_state = self.conv1d(xBC, conv_state)
xBC = silu(xBC) xBC = silu(xBC)
xBC = xBC[:, :seq_len, :] xBC = xBC[:, :seq_len, :]
# Split conv output
x, B, C = mx.split( x, B, C = mx.split(
xBC, xBC,
[self.d_inner, self.d_inner + self.d_state * self.n_groups], [self.d_inner, self.d_inner + self.d_state * self.n_groups],
axis=-1 axis=-1
) )
# Reshape for SSM processing
x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head)) x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1)) B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1)) C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
# Process with parallel attention
A = -mx.exp(self.A_log) A = -mx.exp(self.A_log)
y, next_state = ssd_forward_attn( y, next_state = ssd_forward_attn(
x=x, x=x,
@ -216,7 +203,6 @@ class Mamba2Block(nn.Module):
dt_max=self.args.time_step_max dt_max=self.args.time_step_max
) )
# Apply normalization based on norm_before_gate setting
if self.args.norm_before_gate: if self.args.norm_before_gate:
y = self.norm(y) y = self.norm(y)
y = y * nn.silu(z) y = y * nn.silu(z)
@ -224,10 +210,8 @@ class Mamba2Block(nn.Module):
y = y * nn.silu(z) y = y * nn.silu(z)
y = self.norm(y) y = self.norm(y)
# Final projection
y = self.out_proj(y) y = self.out_proj(y)
# Update cache
cache[0] = conv_state cache[0] = conv_state
cache[1] = next_state cache[1] = next_state
@ -242,8 +226,8 @@ class ResidualBlock(nn.Module):
self.norm = nn.RMSNorm(args.hidden_size) self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache): def __call__(self, x: mx.array, cache):
# if self.residual_in_fp32: if self.residual_in_fp32:
# x = x.astype(mx.float32) x = x.astype(mx.float32)
normed = self.norm(x) normed = self.norm(x)
output = self.mixer(normed, cache) output = self.mixer(normed, cache)
return output + x return output + x