Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-28 03:41:17 +08:00)

commit 932b196b48 (parent 61fad00892)

    updates
@@ -42,6 +42,29 @@ class ModelArgs(BaseModelArgs):
         self.time_step_rank = math.ceil(self.hidden_size / 16)
 
 
+class DepthWiseConv1d(nn.Module):
+    def __init__(self, channels, kernel_size, bias=True, padding=0):
+        super().__init__()
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.padding = padding
+        # Depthwise layout: one (kernel_size, 1) filter per channel.
+        self.weight = mx.random.normal((channels, kernel_size, 1))
+        self.bias = mx.zeros((channels,)) if bias else None
+
+    def __call__(self, x, cache=None):
+        B, L, C = x.shape
+        _, K, _ = self.weight.shape
+
+        if cache is not None:
+            # Prepend the last K - 1 frames kept from the previous call.
+            x = mx.concatenate([cache, x], axis=1)
+        else:
+            # First call: left-pad with zeros instead.
+            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
+
+        y = mx.conv_general(x, self.weight, groups=C)
+        if self.bias is not None:
+            y = y + self.bias
+        # Return the output and the new cache: the last K - 1 input frames.
+        return y, x[:, -K + 1:, :]
+
+
 def ssd_forward_attn(
     x: mx.array,
     dt: mx.array,
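Why return x[:, -K + 1:, :] alongside the output? Those last K - 1 input frames are exactly the left context the next chunk needs, so chunked evaluation with the cache threaded through should reproduce a single full-sequence pass. A minimal sketch of that check, assuming mlx is installed and the DepthWiseConv1d above is in scope (batch, length, channel, and chunk sizes are illustrative):

import mlx.core as mx

conv = DepthWiseConv1d(channels=4, kernel_size=3)
x = mx.random.normal((1, 8, 4))

# Full pass: no cache, so the input is left-padded with K - 1 zeros.
y_full, _ = conv(x)

# Chunked pass: prime the cache with K - 1 zero frames, then thread
# the returned cache from one call into the next.
cache = mx.zeros((1, 2, 4))
y_a, cache = conv(x[:, :5, :], cache)
y_b, cache = conv(x[:, 5:, :], cache)
y_chunked = mx.concatenate([y_a, y_b], axis=1)

print(mx.allclose(y_full, y_chunked))  # expected: True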
@@ -144,13 +167,11 @@ class Mamba2Block(nn.Module):
         self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
         self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
 
-        conv_channels = self.d_inner + 2 * self.n_groups * self.d_state
-        self.conv1d = nn.Conv1d(
-            in_channels=conv_channels,
-            out_channels=conv_channels,
+        self.conv1d = DepthWiseConv1d(
+            channels=self.d_inner + 2 * self.n_groups * self.d_state,
             kernel_size=self.d_conv,
-            groups=conv_channels,
-            padding=self.d_conv - 1,
             bias=args.use_conv_bias,
+            padding=self.d_conv - 1
         )
 
         self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
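For orientation, the channel width handed to this convolution is d_inner + 2 * n_groups * d_state because xBC packs the SSM input x together with the grouped B and C projections along the channel axis. A toy calculation with made-up hyperparameters (illustrative numbers, not the model's real config):

# Hypothetical sizes, chosen only to make the arithmetic visible.
d_inner, n_groups, d_state = 8, 2, 4

conv_channels = d_inner + 2 * n_groups * d_state
print(conv_channels)  # 24: 8 channels of x, 8 of B, 8 of C

# The forward pass later splits xBC back apart at these channel offsets:
print([d_inner, d_inner + n_groups * d_state])  # [8, 16]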
@@ -172,35 +193,9 @@ class Mamba2Block(nn.Module):
             axis=-1
         )
 
-        # Handle convolution with caching
-        xBC = mx.swapaxes(xBC, 1, 2)  # [B, L, C] -> [B, C, L]
-
-        if conv_state is not None and seq_len > 0:
-            # Concatenate cached state with current input
-            xBC_with_cache = mx.concatenate([conv_state, xBC], axis=2)
-        elif seq_len > 0:
-            # For the first call, pad with zeros
-            padding = mx.zeros((batch_size, xBC.shape[1], self.d_conv - 1))
-            xBC_with_cache = mx.concatenate([padding, xBC], axis=2)
-        else:
-            xBC_with_cache = conv_state if conv_state is not None else mx.zeros((batch_size, xBC.shape[1], 0))
-
-        # Save state for next iteration
-        if seq_len > 0:
-            next_conv_state = xBC_with_cache[:, :, -(self.d_conv - 1):]
-        else:
-            next_conv_state = conv_state
-
-        # Apply regular convolution using nn.Conv1d
-        if seq_len > 0:
-            # Use the standard Conv1d module for the actual computation
-            xBC_conv = self.conv1d(xBC_with_cache)
-            xBC = xBC_conv[:, :, -seq_len:]  # Take only the relevant output positions
-            xBC = mx.swapaxes(xBC, 1, 2)  # [B, C, L] -> [B, L, C]
-            xBC = xBC * mx.sigmoid(xBC)
-        else:
-            # Handle empty sequence case
-            xBC = mx.swapaxes(xBC, 1, 2)  # [B, C, L] -> [B, L, C]
-            xBC = xBC[:, :seq_len, :]
+        xBC, conv_state = self.conv1d(xBC, conv_state)
+        xBC = xBC * mx.sigmoid(xBC)
 
         x, B, C = mx.split(
             xBC,
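The gating that survives the rewrite, xBC * mx.sigmoid(xBC), is the SiLU (swish) activation written out by hand. A one-line sanity check, assuming mlx.nn exposes silu as in current MLX releases:

import mlx.core as mx
import mlx.nn as nn

t = mx.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(mx.allclose(t * mx.sigmoid(t), nn.silu(t)))  # expected: True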
@@ -212,6 +207,7 @@ class Mamba2Block(nn.Module):
         B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
         C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
 
+        A = -mx.exp(self.A_log)
         y, next_ssm_state = ssd_forward_attn(
             x=x,
             dt=dt,
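Deriving A as -mx.exp(self.A_log) rather than learning A directly keeps the continuous-time decay rate strictly negative, so the discretized factor exp(dt * A) stays inside (0, 1) no matter where the unconstrained A_log drifts; this is the standard Mamba parameterization for stability. A small numeric illustration (values made up):

import mlx.core as mx

A_log = mx.array([-3.0, 0.0, 2.5])  # unconstrained learned parameter
A = -mx.exp(A_log)                  # strictly negative for any A_log
dt = 0.1
print(mx.exp(dt * A))               # each entry lies in (0, 1)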
@@ -234,7 +230,7 @@ class Mamba2Block(nn.Module):
 
         y = self.out_proj(y)
 
-        cache[0] = next_conv_state
+        cache[0] = conv_state
         cache[1] = next_ssm_state
         return y
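The rename from next_conv_state to conv_state works because DepthWiseConv1d already returns the rolled-forward state, and since cache is a list mutated in place, the caller observes the update without the block having to return it. A tiny aliasing demonstration in plain Python (fake_block and its string states are stand-ins, not model code):

def fake_block(cache):
    # In-place slot assignment publishes the new states to the caller,
    # because the list object is shared, not copied.
    cache[0] = "new conv_state"
    cache[1] = "new ssm_state"

layer_cache = [None, None]
fake_block(layer_cache)
print(layer_cache)  # ['new conv_state', 'new ssm_state']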