mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 19:31:20 +08:00
adding multi token input and correct cache handling in ssm step
This commit is contained in:
parent
5326d9373a
commit
758597eaa8
@ -27,10 +27,10 @@ class ModelArgs(BaseModelArgs):
|
|||||||
time_step_max: float
|
time_step_max: float
|
||||||
time_step_floor: float
|
time_step_floor: float
|
||||||
rescale_prenorm_residual: bool
|
rescale_prenorm_residual: bool
|
||||||
use_cache: bool
|
|
||||||
rms_norm: bool
|
rms_norm: bool
|
||||||
chunk_size: int
|
chunk_size: int
|
||||||
tie_word_embeddings: bool
|
tie_word_embeddings: bool
|
||||||
|
use_cache: bool = True
|
||||||
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
|
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
|
||||||
time_step_rank: Union[int, str] = "auto"
|
time_step_rank: Union[int, str] = "auto"
|
||||||
model_type: str = "mamba2"
|
model_type: str = "mamba2"
|
||||||
@ -58,6 +58,29 @@ class MambaRMSNormGated(nn.Module):
|
|||||||
return self.weight * hidden_states
|
return self.weight * hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def silu(x):
|
||||||
|
return x * mx.sigmoid(x)
|
||||||
|
|
||||||
|
def ssd(x, A, B, C, chunk_size):
|
||||||
|
batch, seqlen, nheads, dim = x.shape
|
||||||
|
B = mx.expand_dims(B, axis=2)
|
||||||
|
C = mx.expand_dims(C, axis=2)
|
||||||
|
|
||||||
|
state = mx.zeros((batch, nheads, dim, B.shape[-1]))
|
||||||
|
outputs = []
|
||||||
|
|
||||||
|
for i in range(0, seqlen, chunk_size):
|
||||||
|
chunk = slice(i, min(i + chunk_size, seqlen))
|
||||||
|
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
|
||||||
|
|
||||||
|
dBx = mx.einsum('blhp,bln->bhpn', x[:, chunk], B[:, chunk])
|
||||||
|
state = state * mx.expand_dims(dA, axis=-1) + dBx
|
||||||
|
y = mx.einsum('bhpn,bln->blhp', state, C[:, chunk])
|
||||||
|
outputs.append(y)
|
||||||
|
|
||||||
|
return mx.concatenate(outputs, axis=1), state
|
||||||
|
|
||||||
|
|
||||||
class DepthWiseConv1d(nn.Module):
|
class DepthWiseConv1d(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
|
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -67,127 +90,142 @@ class DepthWiseConv1d(nn.Module):
|
|||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.groups = groups if groups is not None else in_channels
|
self.groups = groups if groups is not None else in_channels
|
||||||
|
|
||||||
# Ensure in_channels and out_channels are the same for depthwise conv
|
assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
|
||||||
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
|
|
||||||
# Ensure groups is equal to in_channels for depthwise conv
|
|
||||||
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
|
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
|
||||||
|
|
||||||
# Initialize weight with shape (out_channels, kernel_size, 1)
|
# Initialize with shape (channels, 1, kernel_size) to match pretrained weights
|
||||||
self.weight = mx.random.normal((out_channels, kernel_size, 1))
|
self.weight = mx.random.normal((in_channels, 1, kernel_size))
|
||||||
self.bias = mx.zeros((out_channels,)) if bias else None
|
self.bias = mx.zeros((out_channels,)) if bias else None
|
||||||
|
|
||||||
def __call__(self, x, cache=None):
|
def __call__(self, x: mx.array, cache=None, cache_idx: int = 0) -> mx.array:
|
||||||
B, L, C = x.shape
|
B, L, C = x.shape
|
||||||
_, K, _ = self.weight.shape
|
K = self.kernel_size
|
||||||
|
|
||||||
|
# Handle padding and caching
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
x = mx.concatenate([cache, x], axis=1)
|
conv_cache = cache[cache_idx]
|
||||||
|
if conv_cache is not None:
|
||||||
|
x = mx.concatenate([conv_cache, x], axis=1)
|
||||||
|
L = x.shape[1] # Update L after concatenation
|
||||||
else:
|
else:
|
||||||
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
|
pad_left = K - 1
|
||||||
|
x = mx.pad(x, [(0, 0), (pad_left, 0), (0, 0)])
|
||||||
|
L = x.shape[1] # Update L after padding
|
||||||
|
|
||||||
y = mx.conv_general(x, self.weight, groups=self.groups)
|
# Implement depthwise convolution manually for each channel
|
||||||
|
outputs = []
|
||||||
|
for c in range(C):
|
||||||
|
# Extract single channel and reshape for 1D convolution
|
||||||
|
x_c = x[:, :, c] # Shape: [B, L]
|
||||||
|
x_c = mx.expand_dims(x_c, axis=1) # Shape: [B, 1, L]
|
||||||
|
|
||||||
|
# Extract and ensure filter is 3D
|
||||||
|
w_c = self.weight[c] # Shape: [1, kernel_size] or [1, 1, kernel_size]
|
||||||
|
if w_c.ndim == 2:
|
||||||
|
w_c = mx.expand_dims(w_c, axis=0) # Shape: [1, 1, kernel_size]
|
||||||
|
elif w_c.ndim == 1:
|
||||||
|
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
|
||||||
|
|
||||||
|
# For inference mode (single token), adjust the input
|
||||||
|
if L < K:
|
||||||
|
# Pad input to match kernel size
|
||||||
|
pad_size = K - L
|
||||||
|
x_c = mx.pad(x_c, [(0, 0), (0, 0), (pad_size, 0)])
|
||||||
|
|
||||||
|
# Apply 1D convolution for this channel
|
||||||
|
y_c = mx.conv_general(
|
||||||
|
x_c,
|
||||||
|
w_c,
|
||||||
|
stride=1,
|
||||||
|
padding=0 # We've already handled padding
|
||||||
|
)
|
||||||
|
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
y = y + self.bias
|
y_c = y_c + self.bias[c]
|
||||||
|
|
||||||
return y, x[:, -K + 1 :, :]
|
outputs.append(mx.squeeze(y_c, axis=1)) # Shape: [B, 1]
|
||||||
|
|
||||||
|
# Stack all channel outputs
|
||||||
|
y = mx.stack(outputs, axis=-1) # Shape: [B, L', C]
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
# Update cache with the most recent K-1 tokens
|
||||||
|
cache[cache_idx] = x[:, -(K-1):, :] if L >= K else x
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
class Mamba2Block(nn.Module):
|
class Mamba2Block(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
self.intermediate_size = args.intermediate_size
|
|
||||||
self.time_step_rank = args.time_step_rank
|
|
||||||
self.conv_kernel_size = args.conv_kernel
|
|
||||||
self.hidden_size = args.hidden_size
|
|
||||||
self.state_size = args.state_size
|
|
||||||
self.num_heads = args.num_heads
|
|
||||||
self.head_dim = args.hidden_size // args.num_heads
|
|
||||||
self.n_groups = args.n_groups
|
|
||||||
|
|
||||||
# projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
|
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
|
||||||
projection_size = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
|
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
|
||||||
self.in_proj = nn.Linear(
|
|
||||||
args.hidden_size,
|
|
||||||
projection_size,
|
|
||||||
bias=args.use_bias
|
|
||||||
)
|
|
||||||
|
|
||||||
# self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
|
conv_dim = args.intermediate_size + 2 * args.state_size
|
||||||
self.conv_dim = args.intermediate_size + 2 * args.state_size
|
|
||||||
self.conv1d = DepthWiseConv1d(
|
self.conv1d = DepthWiseConv1d(
|
||||||
in_channels=self.conv_dim,
|
in_channels=conv_dim,
|
||||||
out_channels=self.conv_dim,
|
out_channels=conv_dim,
|
||||||
kernel_size=args.conv_kernel,
|
kernel_size=args.conv_kernel,
|
||||||
|
groups=conv_dim,
|
||||||
bias=args.use_conv_bias,
|
bias=args.use_conv_bias,
|
||||||
groups=self.conv_dim,
|
|
||||||
padding=args.conv_kernel - 1
|
padding=args.conv_kernel - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
self.A_log = mx.zeros(args.num_heads)
|
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
|
||||||
self.D = mx.ones((args.num_heads,))
|
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
|
||||||
self.dt_bias = mx.zeros(args.num_heads)
|
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
|
||||||
|
|
||||||
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
|
|
||||||
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
|
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
|
||||||
|
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
|
||||||
|
|
||||||
def _ssd(self, x, A, B, C, chunk_size):
|
if args.rescale_prenorm_residual:
|
||||||
batch, seq_len, nheads, head_dim = x.shape
|
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
|
||||||
n_state = B.shape[-1]
|
self.out_proj.weight = self.out_proj.weight * layer_scale
|
||||||
|
|
||||||
h = mx.zeros((batch, nheads, head_dim, n_state))
|
def __call__(self, u: mx.array, cache = None):
|
||||||
ys = []
|
if cache is not None and self.args.use_cache:
|
||||||
|
return self.step(u, cache)
|
||||||
for i in range(0, seq_len, chunk_size):
|
|
||||||
chunk_size_i = min(chunk_size, seq_len - i)
|
|
||||||
xi = x[:, i:i + chunk_size_i]
|
|
||||||
Bi = B[:, i:i + chunk_size_i]
|
|
||||||
Ci = C[:, i:i + chunk_size_i]
|
|
||||||
|
|
||||||
for t in range(chunk_size_i):
|
|
||||||
h = h * mx.exp(A)[:, None, None]
|
|
||||||
h = h + mx.expand_dims(Bi[:, t], -2) * mx.expand_dims(xi[:, t], -1)
|
|
||||||
y = mx.sum(h * mx.expand_dims(Ci[:, t], -2), axis=-1)
|
|
||||||
ys.append(y)
|
|
||||||
|
|
||||||
y = mx.stack(ys, axis=1)
|
|
||||||
return y, h
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array, cache) -> mx.array:
|
|
||||||
if cache is not None:
|
|
||||||
return self.step(x, cache)
|
|
||||||
|
|
||||||
A = -mx.exp(self.A_log)
|
A = -mx.exp(self.A_log)
|
||||||
zxbcdt = self.in_proj(u)
|
zxbcdt = self.in_proj(u)
|
||||||
|
|
||||||
z, xBC, dt = mx.split(
|
splits = [
|
||||||
zxbcdt,
|
self.args.intermediate_size,
|
||||||
[
|
self.args.intermediate_size + 2 * self.args.state_size,
|
||||||
self.args.d_inner,
|
self.args.num_heads,
|
||||||
self.args.d_inner + 2 * self.args.d_state,
|
]
|
||||||
self.args.nheads,
|
|
||||||
],
|
z, xBC, dt = mx.split(zxbcdt, splits, axis=-1)
|
||||||
axis=-1,
|
|
||||||
|
dt = mx.clip(
|
||||||
|
nn.softplus(dt + self.dt_bias),
|
||||||
|
self.args.time_step_min,
|
||||||
|
self.args.time_step_max
|
||||||
)
|
)
|
||||||
|
|
||||||
dt = mx.softplus(dt + self.dt_bias)
|
dt = mx.maximum(dt, self.args.time_step_floor)
|
||||||
|
|
||||||
# Use the custom DepthWiseConv1d with cache
|
xBC = silu(self.conv1d(xBC))
|
||||||
xBC = self.conv1d(xBC, cache, cache_idx=0)
|
|
||||||
xBC = mx.sigmoid(xBC) * xBC # SiLU activation
|
|
||||||
|
|
||||||
x, B, C = mx.split(
|
xBC_parts = mx.split(
|
||||||
xBC,
|
xBC,
|
||||||
[self.args.d_inner, self.args.d_state, self.args.d_state],
|
[self.args.intermediate_size, self.args.state_size, self.args.state_size],
|
||||||
axis=-1
|
axis=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
x = self._reshape_heads(x, True)
|
x = xBC_parts[0]
|
||||||
B = mx.expand_dims(B, axis=2)
|
B = xBC_parts[1]
|
||||||
C = mx.expand_dims(C, axis=2)
|
C = xBC_parts[2]
|
||||||
|
|
||||||
y, ssm_state = self._ssd(
|
# Replace rearrange with reshape and transpose
|
||||||
|
b, l, hp = x.shape
|
||||||
|
h = self.args.num_heads
|
||||||
|
p = hp // h
|
||||||
|
x = mx.reshape(x, (b, l, h, p))
|
||||||
|
|
||||||
|
y, ssm_state = ssd(
|
||||||
x * mx.expand_dims(dt, -1),
|
x * mx.expand_dims(dt, -1),
|
||||||
A * dt,
|
A * dt,
|
||||||
B,
|
B,
|
||||||
@ -196,61 +234,127 @@ class Mamba2Block(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
y = y + x * mx.expand_dims(self.D, -1)
|
y = y + x * mx.expand_dims(self.D, -1)
|
||||||
y = self._reshape_heads(y, False)
|
# Replace rearrange with reshape
|
||||||
y = self.norm(y, z)
|
y = mx.reshape(y, (b, l, h * p))
|
||||||
|
|
||||||
|
y = self.norm(y + z)
|
||||||
y = self.out_proj(y)
|
y = self.out_proj(y)
|
||||||
|
|
||||||
if cache is not None:
|
if cache is not None and self.args.use_cache:
|
||||||
cache[1] = ssm_state
|
cache[1] = ssm_state
|
||||||
|
|
||||||
|
if self.args.residual_in_fp32:
|
||||||
|
y = mx.cast(y, mx.float32)
|
||||||
|
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def step(self, x: mx.array, cache) -> mx.array:
|
def step(self, u: mx.array, cache: MambaCache):
|
||||||
"""Single inference step"""
|
batch_size = u.shape[0]
|
||||||
assert x.shape[1] == 1, "Only one token can be decoded per inference step"
|
seq_len = u.shape[1]
|
||||||
|
outputs = []
|
||||||
|
|
||||||
zxbcdt = self.in_proj(mx.squeeze(x, 1))
|
# Initialize SSM state if needed
|
||||||
z, xBC, dt = mx.split(
|
if cache[1] is None:
|
||||||
zxbcdt,
|
cache[1] = mx.zeros((
|
||||||
[
|
batch_size,
|
||||||
self.args.d_inner,
|
self.args.num_heads,
|
||||||
self.args.d_inner + 2 * self.args.d_state,
|
self.args.head_dim,
|
||||||
self.args.nheads,
|
self.args.state_size
|
||||||
],
|
))
|
||||||
axis=-1,
|
|
||||||
|
for pos in range(seq_len):
|
||||||
|
# Get single token
|
||||||
|
u_t = u[:, pos:pos+1, :]
|
||||||
|
|
||||||
|
# Project input
|
||||||
|
zxbcdt = self.in_proj(u_t)
|
||||||
|
|
||||||
|
# Calculate sizes
|
||||||
|
d_model = self.args.intermediate_size
|
||||||
|
d_state = self.args.state_size
|
||||||
|
n_heads = self.args.num_heads
|
||||||
|
d_head = self.args.head_dim
|
||||||
|
|
||||||
|
# Correct splits for z, xBC, dt
|
||||||
|
splits = [
|
||||||
|
d_model, # z size
|
||||||
|
d_model + 2 * d_state, # xBC size (delta, B, C)
|
||||||
|
n_heads # dt size
|
||||||
|
]
|
||||||
|
|
||||||
|
# Split the projected input
|
||||||
|
z = zxbcdt[:, :, :splits[0]]
|
||||||
|
xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
|
||||||
|
dt = zxbcdt[:, :, -splits[2]:] # Take last n_heads elements
|
||||||
|
|
||||||
|
# Process dt
|
||||||
|
dt = mx.reshape(dt, (batch_size, n_heads))
|
||||||
|
dt = mx.clip(
|
||||||
|
nn.softplus(dt + self.dt_bias),
|
||||||
|
self.args.time_step_min,
|
||||||
|
self.args.time_step_max
|
||||||
)
|
)
|
||||||
|
dt = mx.maximum(dt, self.args.time_step_floor)
|
||||||
|
|
||||||
# Use the custom DepthWiseConv1d with cache
|
# Process convolution
|
||||||
xBC = self.conv1d(xBC, cache, cache_idx=0)
|
xBC = self.conv1d(xBC, cache=cache, cache_idx=0)
|
||||||
xBC = mx.sigmoid(xBC) * xBC # SiLU activation
|
xBC = silu(xBC)
|
||||||
|
|
||||||
x, B, C = mx.split(
|
# Split convolved xBC into x, B, C
|
||||||
xBC,
|
x = xBC[:, :, :d_model]
|
||||||
[self.args.d_inner, self.args.d_state, self.args.d_state],
|
B = xBC[:, :, d_model:d_model + d_state]
|
||||||
axis=-1
|
C = xBC[:, :, -d_state:]
|
||||||
)
|
|
||||||
|
# Reshape x into (batch, heads, dim)
|
||||||
|
x = mx.reshape(x, (batch_size, 1, n_heads, d_head))
|
||||||
|
x = mx.squeeze(x, axis=1) # (batch, heads, dim)
|
||||||
|
|
||||||
|
# Reshape B into (batch, heads, dim, state)
|
||||||
|
B = mx.reshape(B, (batch_size, 1, d_state))
|
||||||
|
B = mx.broadcast_to(B, (batch_size, n_heads, d_state))
|
||||||
|
B = mx.expand_dims(B, axis=2) # (batch, heads, 1, state)
|
||||||
|
|
||||||
|
# Reshape C for later use
|
||||||
|
C = mx.reshape(C, (batch_size, 1, d_state))
|
||||||
|
C = mx.broadcast_to(C, (batch_size, n_heads, d_state))
|
||||||
|
C = mx.expand_dims(C, axis=3) # (batch, heads, state, 1)
|
||||||
|
|
||||||
|
# Compute SSM updates
|
||||||
A = -mx.exp(self.A_log)
|
A = -mx.exp(self.A_log)
|
||||||
|
dA = mx.exp(dt * mx.expand_dims(A, 0))
|
||||||
|
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) # (batch, heads, 1, 1)
|
||||||
|
|
||||||
dt = mx.softplus(dt + self.dt_bias)
|
# Prepare x for Bx computation
|
||||||
dA = mx.exp(dt * A)
|
x = mx.expand_dims(x, axis=3) # (batch, heads, dim, 1)
|
||||||
|
|
||||||
x = mx.reshape(x, (-1, self.args.nheads, self.args.headdim))
|
# Compute dBx with proper broadcasting
|
||||||
|
dBx = mx.matmul(x, B) # (batch, heads, dim, state)
|
||||||
|
|
||||||
ssm_state = cache[1]
|
# Update state
|
||||||
dBx = mx.expand_dims(dt, -1) * mx.expand_dims(B, 1) * mx.expand_dims(x, -1)
|
ssm_state = cache[1] # (batch, heads, dim, state)
|
||||||
ssm_state = ssm_state * mx.expand_dims(mx.expand_dims(dA, -1), -1) + dBx
|
ssm_state = ssm_state * dA + dBx
|
||||||
|
|
||||||
y = mx.sum(ssm_state * mx.expand_dims(mx.expand_dims(C, 1), 1), axis=-1)
|
|
||||||
y = y + mx.expand_dims(self.D, -1) * x
|
|
||||||
y = mx.reshape(y, (-1, self.args.nheads * self.args.headdim))
|
|
||||||
|
|
||||||
y = self.norm(y, z)
|
|
||||||
y = self.out_proj(y)
|
|
||||||
|
|
||||||
# Update SSM state in cache
|
|
||||||
cache[1] = ssm_state
|
cache[1] = ssm_state
|
||||||
|
|
||||||
return mx.expand_dims(y, 1)
|
# Compute output
|
||||||
|
y = mx.matmul(ssm_state, C) # (batch, heads, dim, 1)
|
||||||
|
y = mx.squeeze(y, axis=-1) # (batch, heads, dim)
|
||||||
|
|
||||||
|
# Add skip connection with D
|
||||||
|
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
|
||||||
|
|
||||||
|
# Reshape to original dimensions
|
||||||
|
y = mx.reshape(y, (batch_size, 1, n_heads * d_head))
|
||||||
|
|
||||||
|
# Apply norm and output projection
|
||||||
|
y = self.norm(y + z)
|
||||||
|
y = self.out_proj(y)
|
||||||
|
|
||||||
|
if self.args.residual_in_fp32:
|
||||||
|
y.astype(mx.float32)
|
||||||
|
|
||||||
|
outputs.append(y)
|
||||||
|
|
||||||
|
return mx.concatenate(outputs, axis=1)
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
@ -287,7 +391,6 @@ class Model(nn.Module):
|
|||||||
self.model_type = args.model_type
|
self.model_type = args.model_type
|
||||||
|
|
||||||
self.backbone = Mamba2(args)
|
self.backbone = Mamba2(args)
|
||||||
# self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
|
||||||
|
|
||||||
if not args.tie_word_embeddings:
|
if not args.tie_word_embeddings:
|
||||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||||
@ -302,17 +405,26 @@ class Model(nn.Module):
|
|||||||
else:
|
else:
|
||||||
logits = self.lm_head(x)
|
logits = self.lm_head(x)
|
||||||
|
|
||||||
|
print('ouput')
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def sanitize(self, weights):
|
|
||||||
for k, v in weights.items():
|
|
||||||
if "conv1d.weight" in k and v.ndim == 3:
|
|
||||||
weights[k] = v.moveaxis(2, 1)
|
|
||||||
return weights
|
|
||||||
|
|
||||||
def make_cache(self):
|
def make_cache(self):
|
||||||
return [MambaCache() for _ in range(len(self.layers))]
|
return [MambaCache() for _ in range(len(self.layers))]
|
||||||
|
|
||||||
|
def sanitize(self, weights):
|
||||||
|
sanitized = {}
|
||||||
|
for k, v in weights.items():
|
||||||
|
if "conv1d.weight" in k:
|
||||||
|
# Ensure weights are in correct shape (channels, 1, kernel_size)
|
||||||
|
if v.ndim == 2:
|
||||||
|
v = mx.expand_dims(v, axis=1)
|
||||||
|
elif v.ndim == 1:
|
||||||
|
v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
|
||||||
|
sanitized[k] = v
|
||||||
|
else:
|
||||||
|
sanitized[k] = v
|
||||||
|
return sanitized
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.backbone.layers
|
return self.backbone.layers
|
||||||
|
@ -146,6 +146,8 @@ def linear_to_lora_layers(
|
|||||||
elif model.model_type == "mamba2":
|
elif model.model_type == "mamba2":
|
||||||
keys = set(
|
keys = set(
|
||||||
[
|
[
|
||||||
|
"mixer.in_proj",
|
||||||
|
"mixer.out_proj",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user