adding correct initialisation of dt, A and D

This commit is contained in:
Goekdeniz-Guelmez
2025-01-13 21:28:43 +01:00
parent 5509ef8e52
commit dd4957f3da
2 changed files with 319 additions and 30 deletions

View File

@@ -24,9 +24,6 @@ class ModelArgs(BaseModelArgs):
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
@@ -35,6 +32,11 @@ class ModelArgs(BaseModelArgs):
intermediate_size: int = None
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
time_step_min: float = 0.001
time_step_max: float = 0.1
time_step_floor: float = 1e-4
A_init_min: float = 1.0
A_init_max: float = 16.0
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
@@ -93,12 +95,12 @@ class Mamba2Block(nn.Module):
super().__init__()
self.args = args
# Calculate dimensions
# Same dimensions as before
self.d_model = args.hidden_size
self.d_state = args.state_size
self.d_conv = args.conv_kernel
self.expand = args.expand
self.d_inner = args.intermediate_size or int(self.expand * self.d_model)
self.d_inner = int(self.expand * self.d_model)
self.n_groups = args.n_groups
self.n_heads = args.num_heads
self.d_head = self.d_inner // self.n_heads
@@ -107,50 +109,66 @@ class Mamba2Block(nn.Module):
d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
# Convolution
conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
# Improved initialization of dt
dt = mx.exp(
mx.random.uniform(
low=math.log(args.time_step_min),
high=math.log(args.time_step_max),
shape=(self.n_heads,)
)
)
dt = mx.clip(dt, args.time_step_floor, float('inf'))
inv_dt = dt + mx.log(-mx.exp(-dt) + 1) # Inverse softplus
self.dt_bias = mx.array(inv_dt)
# Improved A initialization
A = mx.random.uniform(
low=args.A_init_min,
high=args.A_init_max,
shape=(self.n_heads,)
)
self.A_log = mx.log(A)
# Same D initialization
self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
# Convolution with proper initialization
self.conv1d = DepthWiseConv1d(
channels=conv_dim,
channels=self.d_inner + 2 * self.n_groups * self.d_state,
kernel_size=self.d_conv,
bias=args.use_conv_bias
bias=args.use_conv_bias,
padding=self.d_conv-1
)
# SSM parameters
self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
# Output projection
# Output projections
self.norm = MambaRMSNormGated(self.d_inner, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache=None):
batch_size, seq_len, _ = u.shape
# Project input
proj = self.in_proj(u)
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
# Split projections
z = proj[..., :self.d_inner]
x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.n_groups * self.d_state)]
dt = proj[..., -self.n_heads:]
z = zxbcdt[..., :self.d_inner]
xBC = zxbcdt[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.n_groups * self.d_state)]
dt = zxbcdt[..., -self.n_heads:]
# Process time steps - simplified to match PyTorch
dt = nn.softplus(dt + self.dt_bias)
dt = nn.softplus(dt + self.dt_bias) # (B, L, nheads)
x_conv, conv_state = self.conv1d(x_conv, cache[0] if cache else None)
xBC, conv_state = self.conv1d(xBC, cache[0] if cache else None) # (B, L, self.d_inner + 2 * ngroups * d_state)
if cache is not None:
cache[0] = conv_state
x_conv = silu(x_conv)
xBC = silu(xBC)
xBC = xBC[:, :seq_len, :]
# Split conv output and reshape
x = x_conv[..., :self.d_inner]
B = mx.reshape(x_conv[..., self.d_inner:self.d_inner + self.n_groups * self.d_state], (batch_size, seq_len, self.n_groups, -1))
C = mx.reshape(x_conv[..., -self.n_groups * self.d_state:], (batch_size, seq_len, self.n_groups, -1))
x = xBC[..., :self.d_inner]
B = mx.reshape(xBC[..., self.d_inner:self.d_inner + self.n_groups * self.d_state], (batch_size, seq_len, self.n_groups, -1))
C = mx.reshape(xBC[..., -self.n_groups * self.d_state:], (batch_size, seq_len, self.n_groups, -1))
# Reshape for SSM processing
x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
@@ -198,7 +216,7 @@ class Mamba2Block(nn.Module):
cache[1] = next_state
return mx.concatenate(outputs, axis=1)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):