Mirror of https://github.com/ml-explore/mlx-examples.git
adding correct initialisation of dt, A and D
@@ -24,9 +24,6 @@ class ModelArgs(BaseModelArgs):
     use_conv_bias: bool
     initializer_range: float
     residual_in_fp32: bool
-    time_step_min: float
-    time_step_max: float
-    time_step_floor: float
     rescale_prenorm_residual: bool
     rms_norm: bool
     chunk_size: int
@@ -35,6 +32,11 @@ class ModelArgs(BaseModelArgs):
     intermediate_size: int = None
     time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
     time_step_rank: Union[int, str] = "auto"
+    time_step_min: float = 0.001
+    time_step_max: float = 0.1
+    time_step_floor: float = 1e-4
+    A_init_min: float = 1.0
+    A_init_max: float = 16.0
 
     def __post_init__(self):
         if not hasattr(self, "intermediate_size"):
@@ -93,12 +95,12 @@ class Mamba2Block(nn.Module):
         super().__init__()
         self.args = args
 
-        # Calculate dimensions
+        # Same dimensions as before
         self.d_model = args.hidden_size
         self.d_state = args.state_size
         self.d_conv = args.conv_kernel
         self.expand = args.expand
-        self.d_inner = args.intermediate_size or int(self.expand * self.d_model)
+        self.d_inner = int(self.expand * self.d_model)
         self.n_groups = args.n_groups
         self.n_heads = args.num_heads
         self.d_head = self.d_inner // self.n_heads
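To make the projection split in __call__ below easier to follow, here is a small sketch of the dimension arithmetic. The sizes (hidden_size=768, state_size=128, expand=2, n_groups=1, num_heads=24) are illustrative values, not taken from this diff:

# Hypothetical sizes, used only to illustrate the arithmetic; not from this commit.
hidden_size, state_size, expand = 768, 128, 2
n_groups, num_heads = 1, 24

d_model = hidden_size
d_inner = int(expand * d_model)                                   # 1536
d_head = d_inner // num_heads                                     # 64
d_in_proj = 2 * d_inner + 2 * n_groups * state_size + num_heads   # 3352

# The single in_proj output is later split as [z | xBC | dt]:
#   z   -> d_inner                           = 1536
#   xBC -> d_inner + 2*n_groups*state_size   = 1792
#   dt  -> num_heads                         = 24
assert d_in_proj == d_inner + (d_inner + 2 * n_groups * state_size) + num_heads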
@@ -107,50 +109,66 @@ class Mamba2Block(nn.Module):
         d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
         self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
 
-        # Convolution
-        conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
+        # Improved initialization of dt
+        dt = mx.exp(
+            mx.random.uniform(
+                low=math.log(args.time_step_min),
+                high=math.log(args.time_step_max),
+                shape=(self.n_heads,)
+            )
+        )
+        dt = mx.clip(dt, args.time_step_floor, float('inf'))
+        inv_dt = dt + mx.log(-mx.exp(-dt) + 1) # Inverse softplus
+        self.dt_bias = mx.array(inv_dt)
+
+        # Improved A initialization
+        A = mx.random.uniform(
+            low=args.A_init_min,
+            high=args.A_init_max,
+            shape=(self.n_heads,)
+        )
+        self.A_log = mx.log(A)
+
+        # Same D initialization
+        self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
+
+        # Convolution with proper initialization
         self.conv1d = DepthWiseConv1d(
-            channels=conv_dim,
+            channels=self.d_inner + 2 * self.n_groups * self.d_state,
            kernel_size=self.d_conv,
-            bias=args.use_conv_bias
+            bias=args.use_conv_bias,
+            padding=self.d_conv-1
         )
 
-        # SSM parameters
-        self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
-        self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
-        self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
-
-        # Output projection
+        # Output projections
         self.norm = MambaRMSNormGated(self.d_inner, eps=args.layer_norm_epsilon)
         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
 
         if args.rescale_prenorm_residual:
             layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
             self.out_proj.weight = self.out_proj.weight * layer_scale
 
     def __call__(self, u: mx.array, cache=None):
         batch_size, seq_len, _ = u.shape
 
-        # Project input
-        proj = self.in_proj(u)
+        zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
 
         # Split projections
-        z = proj[..., :self.d_inner]
-        x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.n_groups * self.d_state)]
-        dt = proj[..., -self.n_heads:]
+        z = zxbcdt[..., :self.d_inner]
+        xBC = zxbcdt[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.n_groups * self.d_state)]
+        dt = zxbcdt[..., -self.n_heads:]
 
         # Process time steps - simplified to match PyTorch
-        dt = nn.softplus(dt + self.dt_bias)
+        dt = nn.softplus(dt + self.dt_bias) # (B, L, nheads)
 
-        x_conv, conv_state = self.conv1d(x_conv, cache[0] if cache else None)
+        xBC, conv_state = self.conv1d(xBC, cache[0] if cache else None) # (B, L, self.d_inner + 2 * ngroups * d_state)
         if cache is not None:
             cache[0] = conv_state
-        x_conv = silu(x_conv)
+        xBC = silu(xBC)
 
+        xBC = xBC[:, :seq_len, :]
+
         # Split conv output and reshape
-        x = x_conv[..., :self.d_inner]
-        B = mx.reshape(x_conv[..., self.d_inner:self.d_inner + self.n_groups * self.d_state], (batch_size, seq_len, self.n_groups, -1))
-        C = mx.reshape(x_conv[..., -self.n_groups * self.d_state:], (batch_size, seq_len, self.n_groups, -1))
+        x = xBC[..., :self.d_inner]
+        B = mx.reshape(xBC[..., self.d_inner:self.d_inner + self.n_groups * self.d_state], (batch_size, seq_len, self.n_groups, -1))
+        C = mx.reshape(xBC[..., -self.n_groups * self.d_state:], (batch_size, seq_len, self.n_groups, -1))
 
         # Reshape for SSM processing
         x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
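As a standalone check of the dt initialization in the hunk above (a minimal sketch, not part of the commit; n_heads=24 is an arbitrary example value): dt is sampled log-uniformly in [time_step_min, time_step_max], clipped at time_step_floor, and its inverse softplus is stored as dt_bias, so softplus(dt_bias) in the forward pass recovers a dt inside the configured range.

import math
import mlx.core as mx
import mlx.nn as nn

# Defaults from ModelArgs above; n_heads is an arbitrary example value.
time_step_min, time_step_max, time_step_floor = 0.001, 0.1, 1e-4
n_heads = 24

# Log-uniform sample of dt, clipped at the floor (mirrors the block init above).
dt = mx.exp(
    mx.random.uniform(
        low=math.log(time_step_min),
        high=math.log(time_step_max),
        shape=(n_heads,),
    )
)
dt = mx.clip(dt, time_step_floor, float("inf"))

# Inverse softplus: dt + log(1 - exp(-dt)) == log(exp(dt) - 1), so softplus() inverts it.
dt_bias = dt + mx.log(-mx.exp(-dt) + 1)
print(mx.max(mx.abs(nn.softplus(dt_bias) - dt)))  # ~0, up to floating-point error

With the previous near-zero random-normal dt_bias, softplus(dt_bias) would instead start around log 2 ≈ 0.69 for every head, outside the [0.001, 0.1] range.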
@@ -198,7 +216,7 @@ class Mamba2Block(nn.Module):
             cache[1] = next_state
 
         return mx.concatenate(outputs, axis=1)
 
 
 class ResidualBlock(nn.Module):
     def __init__(self, args: ModelArgs):
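For context on the A range (a sketch assuming the standard Mamba-2 parameterization, where the per-step state decay is exp(dt * -exp(A_log)); the head count and values below are illustrative): sampling A uniformly in [A_init_min, A_init_max] = [1, 16] together with dt in [0.001, 0.1] keeps that decay strictly inside (0, 1) and spreads heads across short and long memory, rather than starting every head near 1.

import mlx.core as mx

# Illustrative values; assumes the usual Mamba-2 convention A = -exp(A_log).
A = mx.random.uniform(low=1.0, high=16.0, shape=(4,))   # A_init_min / A_init_max defaults
A_log = mx.log(A)

dt = mx.array([0.001, 0.01, 0.05, 0.1])                 # time steps within the configured range

# Discretized per-step decay exp(dt * -exp(A_log)) = exp(-dt * A) lies in (0, 1).
decay = mx.exp(dt * -mx.exp(A_log))
print(decay)  # near 1.0 for small dt*A (long memory), smaller for large dt*A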