Fix model loading

This commit is contained in:
Goekdeniz-Guelmez 2024-10-11 20:53:29 +02:00
parent 264ba43707
commit 4e1236cbf6
4 changed files with 644 additions and 97 deletions

llms/mamba2-130m-hf Submodule

@@ -0,0 +1 @@
Subproject commit 05e8773fc4ac1cd067e8a18a5c45372ce5178405

View File

@@ -0,0 +1,256 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "mamba2"
num_heads: int = 128
head_dim: int = 64
vocab_size: int = 32768
hidden_size: int = 4096
state_size: int = 128
num_hidden_layers: int = 64
layer_norm_epsilon: float = 1e-5
pad_token_id: int = 1
bos_token_id: int = 0
eos_token_id: int = 2
expand: int = 2
conv_kernel: int = 4
n_groups: int = 8
use_bias: bool = False
use_conv_bias: bool = True
hidden_act: str = "silu"
initializer_range: float = 0.1
residual_in_fp32: bool = True
time_step_rank: Union[int, str] = "auto"
time_step_min: float = 0.001
time_step_max: float = 0.1
time_step_floor: float = 1e-4
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
rescale_prenorm_residual: bool = False
use_cache: bool = True
rms_norm: bool = True
chunk_size: int = 256
tie_word_embeddings: bool = False
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class Mamba2Cache:
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, groups=1, padding=0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups
self.weight = mx.random.normal((self.channels, kernel_size, 1))
self.bias = mx.zeros((channels,)) if bias else None
def __call__(self, x, cache=None):
B, L, C = x.shape
_, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=self.groups)
if self.bias is not None:
y = y + self.bias
return y, x[:, -K + 1 :, :]
class Mamba2Mixer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.intermediate_size = args.intermediate_size
self.time_step_rank = args.time_step_rank
self.conv_kernel_size = args.conv_kernel
self.hidden_size = args.hidden_size
self.state_size = args.state_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.n_groups = args.n_groups
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
self.conv1d = DepthWiseConv1d(
channels=self.conv_dim,
kernel_size=self.conv_kernel_size,
bias=self.args.use_conv_bias,
groups=self.conv_dim,
padding=self.conv_kernel_size - 1,
)
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(
self.hidden_size,
projection_size,
bias=args.use_bias
)
self.act = nn.SiLU()
self.dt_bias = mx.ones((self.num_heads,))
self.A_log = mx.log(mx.arange(1, self.num_heads + 1))
self.D = mx.ones((self.num_heads,))
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
def ssm_step(self, x, state=None):
A = -mx.exp(self.A_log)
D = self.D
deltaBC = self.x_proj(x)
delta, B, C = mx.split(
deltaBC,
indices_or_sections=[
self.time_step_rank,
self.time_step_rank + self.ssm_state_size,
],
axis=-1,
)
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None:
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
y = y + D * x
return y, new_state
def __call__(self, x, cache):
B, T, D = x.shape
if cache is None:
cache = [None, None]
outputs = []
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1])
z_t = nn.silu(z_t)
output_t = y_t * z_t
output_t = self.out_proj(output_t)
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
return output
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Mixer(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [Mamba2Block(args) for idx in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
cache=None
):
hidden_states = self.embeddings(inputs)
if cache is None:
            cache = [Mamba2Cache() for _ in range(len(self.layers))]  # one cache per layer, matching make_cache()
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, cache[i])
hidden_states = self.norm_f(hidden_states)
return hidden_states
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self, batch_size: int = 1):
return [Mamba2Cache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers
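
The per-layer cache returned by make_cache is what makes incremental decoding work: the first call runs the whole prompt, later calls feed one token at a time and reuse the conv and SSM state stored in each Mamba2Cache. Below is a minimal, hypothetical driver loop for that interface (greedy_generate and its arguments are not part of this file; note also that ssm_step above still refers to self.x_proj / self.dt_proj, which __init__ does not define, so this is a sketch of the intended call pattern rather than something the module runs as-is):

import mlx.core as mx

def greedy_generate(model, prompt_ids, max_new_tokens=16):
    # One cache entry per layer, matching Model.make_cache() above.
    cache = model.make_cache()
    tokens = list(prompt_ids)
    y = mx.array([tokens])                       # (1, T): full prompt on the first pass
    for _ in range(max_new_tokens):
        logits = model(y, cache)                 # (1, T, vocab_size)
        next_id = int(mx.argmax(logits[0, -1]))  # greedy pick from the last position
        tokens.append(next_id)
        y = mx.array([[next_id]])                # afterwards feed one token per step
    return tokens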

View File

@@ -0,0 +1,275 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "mamba2"
num_heads: int = 128
head_dim: int = 64
vocab_size: int = 32768
hidden_size: int = 4096
state_size: int = 128
num_hidden_layers: int = 64
layer_norm_epsilon: float = 1e-5
expand: int = 2
conv_kernel: int = 4
n_groups: int = 8
use_bias: bool = False
use_conv_bias: bool = True
initializer_range: float = 0.1
residual_in_fp32: bool = True
time_step_rank: Union[int, str] = "auto"
time_step_min: float = 0.001
time_step_max: float = 0.1
time_step_floor: float = 1e-4
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
rescale_prenorm_residual: bool = False
use_cache: bool = True
rms_norm: bool = True
chunk_size: int = 256
tie_word_embeddings: bool = False
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class Mamba2Cache:
def __init__(self, num_layers):
self.cache = [[None, None] for _ in range(num_layers)]
def __getitem__(self, idx):
return self.cache[idx]
def __setitem__(self, idx, value):
self.cache[idx] = value
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class Mamba2Mixer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.intermediate_size = args.intermediate_size
self.time_step_rank = args.time_step_rank
self.conv_kernel_size = args.conv_kernel
self.hidden_size = args.hidden_size
self.state_size = args.state_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.n_groups = args.n_groups
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=args.use_conv_bias,
kernel_size=args.conv_kernel,
groups=self.conv_dim,
padding=args.conv_kernel - 1
)
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(
self.hidden_size,
projection_size,
bias=args.use_bias
)
self.act = nn.SiLU()
self.dt_bias = mx.ones((self.num_heads,))
self.A_log = mx.log(mx.arange(1, self.num_heads + 1))
self.D = mx.ones((self.num_heads,))
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
def ssm_step(self, x, dt, state):
B, L, C = x.shape
print(f"x shape: {x.shape}")
projected_states = self.in_proj(x)
print(f"deltaBC shape: {projected_states.shape}")
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size - self.num_heads) // 2
gate = projected_states[:, :, 2*d_mlp:2*d_mlp+self.intermediate_size]
conv_state = projected_states[:, :, 2*d_mlp+self.intermediate_size:2*d_mlp+self.intermediate_size+self.conv_dim]
time_step = projected_states[:, :, -self.num_heads:]
print(f"conv_state shape before reshape: {conv_state.shape}")
print(f"self.conv_dim: {self.conv_dim}")
# Reshape and handle the case where L=1
conv_state = conv_state.reshape(B, self.conv_dim, L)
if L == 1:
# If sequence length is 1, we need to pad to apply convolution
conv_state = mx.pad(conv_state, ((0, 0), (0, 0), (0, self.conv_kernel_size - 1)))
conv_out = self.conv1d(conv_state)
# If we padded, we need to remove the padding
if L == 1:
conv_out = conv_out[:, :, :L]
# Reshape back to (B, L, C)
conv_out = conv_out.transpose(0, 2, 1)
x_and_conv_out, B, C = mx.split(
conv_out,
[self.intermediate_size, self.n_groups * self.state_size],
axis=-1
)
dt = nn.softplus(time_step + self.dt_bias)
dt = mx.clip(dt, self.args.time_step_min, self.args.time_step_max)
B = B.reshape(-1, self.num_heads, self.head_dim, self.state_size)
C = C.reshape(-1, self.num_heads, self.head_dim, self.state_size)
        A = -mx.exp(self.A_log)  # recover A from its log parameterization
        dA = mx.exp(dt[:, :, None, None] * A[None, :, None, None])
dB = dt[:, :, None, None] * B
new_state = state * dA + x_and_conv_out[:, :, None, None] * dB
y = mx.sum(new_state * C, axis=-1)
y = y + C[None, :, None] * x_and_conv_out
y = self.norm(y.reshape(-1, self.intermediate_size), gate)
output = self.out_proj(y)
return output, new_state
def __call__(
self,
x: mx.array,
cache = None
):
B, L, _ = x.shape
if cache[0] is not None: # Using cached state
conv_state, ssm_state = cache
x = x[:, -1:]
output, new_ssm_state = self.ssm_step(x, None, ssm_state)
cache[1] = new_ssm_state # Update SSM state in cache
else:
conv_state, ssm_state = None, None
outputs = []
for t in range(L):
            xt = x[:, t:t+1]  # slice the current step without overwriting the full sequence
            output, ssm_state = self.ssm_step(xt, None, ssm_state)
outputs.append(output)
output = mx.concatenate(outputs, axis=1)
cache[1] = ssm_state # Store final SSM state in cache
# Update conv state in cache
new_conv_state = x[:, -self.conv_kernel_size:]
cache[0] = new_conv_state
return output
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.residual_in_fp32 = args.residual_in_fp32
self.norm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mixer = Mamba2Mixer(args)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.mixer(self.norm(inputs), cache=cache)
r = inputs + h
return r
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [Mamba2Block(args) for idx in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
cache=None
):
hidden_states = self.embeddings(inputs)
if cache is None:
cache = Mamba2Cache(len(self.layers))
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, cache[i])
hidden_states = self.norm_f(hidden_states)
return hidden_states
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None
):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self, batch_size: int = 1):
return Mamba2Cache(len(self.backbone.layers))
@property
def layers(self):
return self.backbone.layers
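
The slicing of projected_states in ssm_step is pure bookkeeping over the in_proj width (intermediate_size + conv_dim + num_heads). With the default ModelArgs the d_mlp term is zero and the three slices are contiguous; the sketch below just redoes that arithmetic with plain integers (all names are local to the sketch):

hidden_size, expand = 4096, 2
n_groups, state_size, num_heads = 8, 128, 128

intermediate_size = expand * hidden_size                      # 8192
conv_dim = intermediate_size + 2 * n_groups * state_size      # 8192 + 2048 = 10240
projection_size = intermediate_size + conv_dim + num_heads    # 18560, the in_proj output width

d_mlp = (projection_size - 2 * intermediate_size
         - 2 * n_groups * state_size - num_heads) // 2        # (18560 - 16384 - 2048 - 128) // 2 = 0
assert d_mlp == 0
# So the slices in ssm_step reduce to:
#   gate       -> projected_states[:, :, :intermediate_size]
#   conv_state -> projected_states[:, :, intermediate_size:intermediate_size + conv_dim]
#   time_step  -> projected_states[:, :, -num_heads:]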

View File

@@ -2,7 +2,7 @@
 import math
 from dataclasses import dataclass, field
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
@@ -20,15 +20,11 @@ class ModelArgs(BaseModelArgs):
     state_size: int = 128
     num_hidden_layers: int = 64
     layer_norm_epsilon: float = 1e-5
-    pad_token_id: int = 1
-    bos_token_id: int = 0
-    eos_token_id: int = 2
     expand: int = 2
     conv_kernel: int = 4
     n_groups: int = 8
     use_bias: bool = False
     use_conv_bias: bool = True
-    hidden_act: str = "silu"
     initializer_range: float = 0.1
     residual_in_fp32: bool = True
     time_step_rank: Union[int, str] = "auto"
@@ -52,14 +48,18 @@ class ModelArgs(BaseModelArgs):
 class Mamba2Cache:
-    def __init__(self, num_layers):
-        self.cache = [[None, None] for _ in range(num_layers)]
+    def __init__(self):
+        self.cache = [None, None]
+    def __setitem__(self, idx, value):
+        self.cache[idx] = value
     def __getitem__(self, idx):
         return self.cache[idx]
-    def __setitem__(self, idx, value):
-        self.cache[idx] = value
+    @property
+    def state(self):
+        return self.cache
 class MambaRMSNormGated(nn.Module):
@@ -75,66 +75,53 @@ class MambaRMSNormGated(nn.Module):
         hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states
 class DepthWiseConv1d(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
         super().__init__()
-        assert in_channels == out_channels, "For depthwise conv, in_channels must equal out_channels"
+        self.in_channels = in_channels
-        self.channels = in_channels
+        self.out_channels = out_channels
         self.kernel_size = kernel_size
         self.padding = padding
+        self.groups = groups if groups is not None else in_channels
-        # For depthwise conv, we use groups equal to the number of channels
+        # Ensure in_channels and out_channels are the same for depthwise conv
-        self.groups = self.channels if groups is None else groups
+        assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
-        assert self.groups == self.channels, "For depthwise conv, groups must equal the number of channels"
+        # Ensure groups is equal to in_channels for depthwise conv
+        assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
-        # Weight shape: (channels, 1, kernel_size) for depthwise conv
+        # Initialize weight with shape (out_channels, kernel_size, 1)
-        self.weight = mx.random.normal((self.channels, 1, kernel_size))
+        self.weight = mx.random.normal((out_channels, kernel_size, 1))
-        self.bias = mx.zeros((self.channels,)) if bias else None
+        self.bias = mx.zeros((out_channels,)) if bias else None
     def __call__(self, x, cache=None):
         B, L, C = x.shape
-        K = self.kernel_size
+        _, K, _ = self.weight.shape
         if cache is not None:
             x = mx.concatenate([cache, x], axis=1)
         else:
-            x = mx.pad(x, [(0, 0), (self.padding, 0), (0, 0)])
+            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
-        # Reshape for depthwise convolution
+        y = mx.conv_general(x, self.weight, groups=self.groups)
-        x = x.transpose(0, 2, 1)  # (B, C, L)
-        # Perform depthwise convolution
-        y = mx.conv(x, self.weight, groups=self.groups)
-        # Reshape back
-        y = y.transpose(0, 2, 1)  # (B, L, C)
         if self.bias is not None:
             y = y + self.bias
-        return y, x.transpose(0, 2, 1)[:, -K:, :]
+        return y, x[:, -K + 1 :, :]
 class Mamba2Mixer(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
-        self.hidden_size = args.hidden_size
         self.intermediate_size = args.intermediate_size
+        self.time_step_rank = args.time_step_rank
         self.conv_kernel_size = args.conv_kernel
+        self.hidden_size = args.hidden_size
         self.state_size = args.state_size
         self.num_heads = args.num_heads
         self.head_dim = args.head_dim
         self.n_groups = args.n_groups
-        self.time_step_rank = args.time_step_rank
-        projection_size = self.intermediate_size + self.intermediate_size + 2 * self.n_groups * self.state_size + self.num_heads
-        self.in_proj = nn.Linear(
-            self.hidden_size,
-            projection_size,
-            bias=args.use_bias
-        )
         self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
         self.conv1d = DepthWiseConv1d(
@@ -143,32 +130,74 @@ class Mamba2Mixer(nn.Module):
             bias=args.use_conv_bias,
             kernel_size=args.conv_kernel,
             groups=self.conv_dim,
-            padding=args.conv_kernel - 1,
+            padding=args.conv_kernel - 1
+        )
+        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
+        self.in_proj = nn.Linear(
+            self.hidden_size,
+            projection_size,
+            bias=args.use_bias
         )
         self.act = nn.SiLU()
         self.dt_bias = mx.ones((self.num_heads,))
-        self.A_log = mx.log(mx.arange(1, self.num_heads + 1, dtype=mx.float32))
+        self.A_log = mx.log(mx.arange(1, self.num_heads + 1))
         self.D = mx.ones((self.num_heads,))
-        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
         self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
+        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
+    # def ssm_step(self, x, state=None):
+    #     A = -mx.exp(self.A_log)
+    #     D = self.D
+    #     deltaBC = self.x_proj(x)
+    #     delta, B, C = mx.split(
+    #         deltaBC,
+    #         indices_or_sections=[
+    #             self.time_step_rank,
+    #             self.time_step_rank + self.ssm_state_size,
+    #         ],
+    #         axis=-1,
+    #     )
+    #     delta = nn.softplus(self.dt_proj(delta))
+    #     new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
+    #     if state is not None:
+    #         new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
+    #     y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
+    #     y = y + D * x
+    #     return y, new_state
     def ssm_step(self, x, dt, state):
-        A = -mx.exp(self.A_log)
+        B, L, C = x.shape
-        D = self.D
+        print(f"x shape: {x.shape}")
+        projected_states = self.in_proj(x)
+        print(f"deltaBC shape: {projected_states.shape}")
-        deltaBC = self.in_proj(x)
+        d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size - self.num_heads) // 2
-        gate, conv_state, time_step = mx.split(
-            deltaBC,
+        gate = projected_states[:, :, 2*d_mlp:2*d_mlp+self.intermediate_size]
-            [self.intermediate_size, self.intermediate_size + 2 * self.n_groups * self.state_size],
+        conv_state = projected_states[:, :, 2*d_mlp+self.intermediate_size:2*d_mlp+self.intermediate_size+self.conv_dim]
-            axis=-1
+        time_step = projected_states[:, :, -self.num_heads:]
-        )
+        print(f"conv_state shape before reshape: {conv_state.shape}")
+        print(f"self.conv_dim: {self.conv_dim}")
+        # Reshape and handle the case where L=1
+        conv_state = conv_state.reshape(B, self.conv_dim, L)
+        if L == 1:
+            # If sequence length is 1, we need to pad to apply convolution
+            conv_state = mx.pad(conv_state, ((0, 0), (0, 0), (0, self.conv_kernel_size - 1)))
-        conv_state = conv_state.transpose(0, 2, 1)
         conv_out = self.conv1d(conv_state)
+        # If we padded, we need to remove the padding
+        if L == 1:
+            conv_out = conv_out[:, :, :L]
+        # Reshape back to (B, L, C)
         conv_out = conv_out.transpose(0, 2, 1)
-        conv_out = self.act(conv_out)
         x_and_conv_out, B, C = mx.split(
             conv_out,
@@ -187,58 +216,47 @@ class Mamba2Mixer(nn.Module):
         new_state = state * dA + x_and_conv_out[:, :, None, None] * dB
         y = mx.sum(new_state * C, axis=-1)
-        y = y + D[None, :, None] * x_and_conv_out
+        y = y + C[None, :, None] * x_and_conv_out
         y = self.norm(y.reshape(-1, self.intermediate_size), gate)
         output = self.out_proj(y)
         return output, new_state
-    def __call__(
+    def __call__(self, x, cache):
-        self,
+        B, T, D = x.shape
-        x: mx.array,
+        if cache is None:
-        cache = None
+            cache = [None, None]
-    ):
-        B, L, _ = x.shape
-        if cache[0] is not None:  # Using cached state
-            conv_state, ssm_state = cache
-            x = x[:, -1:]
-            output, new_ssm_state = self.ssm_step(x, None, ssm_state)
-            cache[1] = new_ssm_state  # Update SSM state in cache
-        else:
-            conv_state, ssm_state = None, None
         outputs = []
-        for t in range(L):
+        for t in range(T):
-            x = x[:, t:t+1]
+            xt = x[:, t, :]
-            output, ssm_state = self.ssm_step(x, None, ssm_state)
+            xz = self.in_proj(xt)
-            outputs.append(output)
+            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
-        output = mx.concatenate(outputs, axis=1)
-        cache[1] = ssm_state  # Store final SSM state in cache
-        # Update conv state in cache
+            if x_t.shape[-1] != self.conv_dim:
-        new_conv_state = x[:, -self.conv_kernel_size:]
+                raise ValueError(f"Expected conv input dim {self.conv_dim}, got {x_t.shape[-1]}")
-        cache[0] = new_conv_state
+            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
+            x_t = conv_out.squeeze(1)
+            x_t = nn.silu(x_t)
+            y_t, cache[1] = self.ssm_step(x_t, cache[1])
+            z_t = nn.silu(z_t)
+            output_t = y_t * z_t
+            output_t = self.out_proj(output_t)
+            outputs.append(output_t)
+        output = mx.stack(outputs, axis=1)
         return output
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        self.args = args
-        self.residual_in_fp32 = args.residual_in_fp32
-        self.norm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
         self.mixer = Mamba2Mixer(args)
+        self.norm = nn.RMSNorm(args.hidden_size)
-    def __call__(
+    def __call__(self, x: mx.array, cache):
-        self,
+        return self.mixer(self.norm(x), cache) + x
-        inputs: mx.array,
-        cache=None,
-    ):
-        h = self.mixer(self.norm(inputs), cache_params=cache)
-        r = inputs + h
-        return r
 class Mamba2(nn.Module):
@@ -275,11 +293,7 @@ class Model(nn.Module):
         if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-    def __call__(
+    def __call__(self, inputs: mx.array, cache=None):
-        self,
-        inputs: mx.array,
-        cache=None
-    ):
         B, T = inputs.shape
         x = self.backbone(inputs, cache)
@@ -288,16 +302,17 @@ class Model(nn.Module):
             logits = self.backbone.embeddings.as_linear(x)
         else:
             logits = self.lm_head(x)
         return logits
-    def sanitize_mabey(self, weights):
+    def sanitize(self, weights):
         for k, v in weights.items():
             if "conv1d.weight" in k and v.ndim == 3:
                 weights[k] = v.moveaxis(2, 1)
         return weights
     def make_cache(self, batch_size: int = 1):
-        return Mamba2Cache(len(self.backbone.layers))
+        return [Mamba2Cache() for _ in range(len(self.layers))]
     @property
     def layers(self):
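
The reworked DepthWiseConv1d above returns its own rolling state: each call hands back the last kernel_size - 1 input frames, and the next call prepends them so the convolution always sees a full window without re-padding. A toy round trip, assuming the updated class is in scope (the sizes here are illustrative, not the model's):

import mlx.core as mx

B, C, K = 1, 6, 4
conv = DepthWiseConv1d(in_channels=C, out_channels=C, kernel_size=K, groups=C)

# Prompt pass: no cache, so the input is left-padded with K - 1 zeros internally.
x0 = mx.random.normal((B, 5, C))
y0, state = conv(x0)                 # y0: (1, 5, 6); state keeps the last K - 1 frames
assert state.shape[1] == K - 1

# Incremental pass: one new frame, the cached frames complete the window.
x1 = mx.random.normal((B, 1, C))
y1, state = conv(x1, cache=state)    # y1: (1, 1, 6); state rolls forward
assert y1.shape[1] == 1 and state.shape[1] == K - 1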