quick save

This commit is contained in:
Goekdeniz-Guelmez 2024-10-20 16:11:39 +02:00
parent cd036ccfb5
commit 4ab5139c05
3 changed files with 266 additions and 170 deletions

View File

@@ -338,3 +338,30 @@ class MambaCache(_BaseCache):
     @state.setter
     def state(self, v):
         self.cache = v
+
+
+class Mamba2Cache:
+    def __init__(self, num_layers):
+        self.conv_states = [None] * num_layers
+        self.ssm_states = [None] * num_layers
+        self.seqlen_offset = 0
+
+    def __getitem__(self, idx):
+        return (self.conv_states[idx], self.ssm_states[idx])
+
+    def __setitem__(self, idx, value):
+        self.conv_states[idx], self.ssm_states[idx] = value
+
+    @property
+    def state(self):
+        return {
+            'conv_states': self.conv_states,
+            'ssm_states': self.ssm_states,
+            'seqlen_offset': self.seqlen_offset
+        }
+
+    @state.setter
+    def state(self, v):
+        self.conv_states = v['conv_states']
+        self.ssm_states = v['ssm_states']
+        self.seqlen_offset = v['seqlen_offset']
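For orientation, a minimal sketch of how a per-layer cache like this can be driven (illustrative only; the layer count and the `new_conv_state`/`new_ssm_state` names below are assumptions, not part of this commit):

    cache = Mamba2Cache(num_layers=24)        # 24 is an arbitrary example layer count

    # Each layer reads and writes its own slot.
    conv_state, ssm_state = cache[0]          # both None before the first step
    # cache[0] = (new_conv_state, new_ssm_state)

    # The state property round-trips the whole cache, e.g. for saving/restoring.
    snapshot = cache.state
    cache.state = snapshot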

View File

@@ -6,41 +6,37 @@ from typing import Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs

 # python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."

 @dataclass
 class ModelArgs(BaseModelArgs):
-    model_type: str = "mamba2"
-    num_heads: int = 128
-    head_dim: int = 64
-    vocab_size: int = 32768
-    hidden_size: int = 4096
-    state_size: int = 128
-    num_hidden_layers: int = 64
-    layer_norm_epsilon: float = 1e-5
-    pad_token_id: int = 1
-    bos_token_id: int = 0
-    eos_token_id: int = 2
-    expand: int = 2
-    conv_kernel: int = 4
-    n_groups: int = 8
-    use_bias: bool = False
-    use_conv_bias: bool = True
-    hidden_act: str = "silu"
-    initializer_range: float = 0.1
-    residual_in_fp32: bool = True
-    time_step_rank: Union[int, str] = "auto"
-    time_step_min: float = 0.001
-    time_step_max: float = 0.1
-    time_step_floor: float = 1e-4
+    num_heads: int
+    head_dim: int
+    vocab_size: int
+    hidden_size: int
+    state_size: int
+    num_hidden_layers: int
+    layer_norm_epsilon: float
+    expand: int
+    conv_kernel: int
+    n_groups: int
+    use_bias: bool
+    use_conv_bias: bool
+    initializer_range: float
+    residual_in_fp32: bool
+    time_step_min: float
+    time_step_max: float
+    time_step_floor: float
+    rescale_prenorm_residual: bool
+    use_cache: bool
+    rms_norm: bool
+    chunk_size: int
+    tie_word_embeddings: bool
     time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
-    rescale_prenorm_residual: bool = False
-    use_cache: bool = True
-    rms_norm: bool = True
-    chunk_size: int = 256
-    tie_word_embeddings: bool = False
+    time_step_rank: Union[int, str] = "auto"
+    model_type: str = "mamba2"

     def __post_init__(self):
         if not hasattr(self, "intermediate_size"):
@@ -79,15 +75,24 @@ class MambaRMSNormGated(nn.Module):
         hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states


 class DepthWiseConv1d(nn.Module):
-    def __init__(self, channels, kernel_size, bias=True, groups=1, padding=0):
+    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
         super().__init__()
-        self.channels = channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
         self.kernel_size = kernel_size
         self.padding = padding
-        self.groups = groups
-        self.weight = mx.random.normal((self.channels, kernel_size, 1))
-        self.bias = mx.zeros((channels,)) if bias else None
+        self.groups = groups if groups is not None else in_channels
+
+        # Ensure in_channels and out_channels are the same for depthwise conv
+        assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
+        # Ensure groups is equal to in_channels for depthwise conv
+        assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
+
+        # Initialize weight with shape (out_channels, kernel_size, 1)
+        self.weight = mx.random.normal((out_channels, kernel_size, 1))
+        self.bias = mx.zeros((out_channels,)) if bias else None

     def __call__(self, x, cache=None):
         B, L, C = x.shape
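For reference, a cached depthwise convolution step can be written roughly as below. This is a sketch under assumed shapes (weight of shape (C, K, 1), a cache holding the previous K - 1 inputs per channel), not the exact implementation in this file:

    import mlx.core as mx

    def depthwise_step(x_t, weight, bias, conv_cache):
        # x_t: (B, C) current input, weight: (C, K, 1), conv_cache: (B, K - 1, C)
        window = mx.concatenate([conv_cache, x_t[:, None, :]], axis=1)   # (B, K, C)
        taps = mx.transpose(weight[:, :, 0])[None, :, :]                 # (1, K, C): one K-tap filter per channel
        y_t = mx.sum(window * taps, axis=1)                              # (B, C)
        if bias is not None:
            y_t = y_t + bias
        return y_t, window[:, 1:, :]                                     # new cache keeps the last K - 1 inputs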
@@ -116,16 +121,17 @@ class Mamba2Mixer(nn.Module):
         self.hidden_size = args.hidden_size
         self.state_size = args.state_size
         self.num_heads = args.num_heads
-        self.head_dim = args.head_dim
+        self.head_dim = args.hidden_size // args.num_heads
         self.n_groups = args.n_groups

         self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
         self.conv1d = DepthWiseConv1d(
-            channels=self.conv_dim,
-            kernel_size=self.conv_kernel_size,
-            bias=self.args.use_conv_bias,
+            in_channels=self.conv_dim,
+            out_channels=self.conv_dim,
+            bias=args.use_conv_bias,
+            kernel_size=args.conv_kernel,
             groups=self.conv_dim,
-            padding=self.conv_kernel_size - 1,
+            padding=args.conv_kernel - 1
         )

         projection_size = self.intermediate_size + self.conv_dim + self.num_heads
@@ -135,33 +141,35 @@ class Mamba2Mixer(nn.Module):
             bias=args.use_bias
         )

-        self.act = nn.SiLU()
-        self.dt_bias = mx.ones((self.num_heads,))
-        self.A_log = mx.log(mx.arange(1, self.num_heads + 1))
-        self.D = mx.ones((self.num_heads,))
+        self.A_log = mx.zeros(self.num_heads)
+        self.D = mx.ones(self.num_heads)
+        self.dt_bias = mx.zeros(self.num_heads)

         self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)

-    def ssm_step(self, x, state=None):
+    def ssm_step(self, x, state, dt_proj):
         A = -mx.exp(self.A_log)
         D = self.D
-        deltaBC = self.x_proj(x)
-        delta, B, C = mx.split(
-            deltaBC,
-            indices_or_sections=[
-                self.time_step_rank,
-                self.time_step_rank + self.ssm_state_size,
-            ],
-            axis=-1,
-        )
-        delta = nn.softplus(self.dt_proj(delta))
-        new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
-        if state is not None:
-            new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
-        y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
-        y = y + D * x
+        delta = nn.softplus(dt_proj + self.dt_bias)
+
+        B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
+
+        batch_size = B.shape[0]
+        B = B.reshape(batch_size, self.n_groups, self.state_size)
+        C = C.reshape(batch_size, -1, self.state_size)
+
+        delta = delta.reshape(batch_size, self.num_heads, 1)
+        A = A.reshape(1, self.num_heads, 1)
+
+        if state is None:
+            new_state = delta * B
+        else:
+            new_state = delta * (B + state * mx.exp(delta * A))
+
+        y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
+        y = y + D * x[:, :self.num_heads]
         return y, new_state

     def __call__(self, x, cache):
@@ -173,15 +181,28 @@ class Mamba2Mixer(nn.Module):
         for t in range(T):
             xt = x[:, t, :]
             xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
+
+            x_t, z_t, dt_proj = mx.split(
+                xz,
+                indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
+                axis=-1
+            )

             conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
             x_t = conv_out.squeeze(1)
             x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
+            y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
             z_t = nn.silu(z_t)
-            output_t = y_t * z_t
+
+            # Element-wise multiplication
+            output_t = y_t[:, :, None] * z_t[:, None, :]
+            # Sum across the second dimension to match the intermediate_size
+            output_t = output_t.sum(axis=1)
+
             output_t = self.out_proj(output_t)
             outputs.append(output_t)

         output = mx.stack(outputs, axis=1)
         return output
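For reference, the per-token update that a selective SSM step of this kind is based on is usually written, in the Mamba/Mamba-2 papers' notation, as

    h_t = \exp(\Delta_t A)\,h_{t-1} + \Delta_t B_t\,x_t, \qquad y_t = C_t h_t + D\,x_t

with \Delta_t produced by a softplus. The loop above wraps this update with the causal convolution on x_t and the SiLU gate z_t.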
@@ -240,6 +261,9 @@ class Model(nn.Module):
         else:
             logits = self.lm_head(x)

+        print(logits)
+        print(logits.shape)
+
         return logits

     def sanitize(self, weights):

View File

@@ -2,11 +2,13 @@
 import math
 from dataclasses import dataclass, field
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional

-import mlx.core as mx
 import mlx.nn as nn
+import mlx.core as mx

 from .base import BaseModelArgs
+from .cache import Mamba2Cache

 # python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."
@@ -46,22 +48,6 @@ class ModelArgs(BaseModelArgs):
         if self.time_step_rank == "auto":
             self.time_step_rank = math.ceil(self.hidden_size / 16)

-class Mamba2Cache:
-    def __init__(self):
-        self.cache = [None, None]
-
-    def __setitem__(self, idx, value):
-        self.cache[idx] = value
-
-    def __getitem__(self, idx):
-        return self.cache[idx]
-
-    @property
-    def state(self):
-        return self.cache
-

 class MambaRMSNormGated(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
@@ -75,6 +61,7 @@ class MambaRMSNormGated(nn.Module):
         hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states


 class DepthWiseConv1d(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
         super().__init__()
@@ -111,27 +98,22 @@ class DepthWiseConv1d(nn.Module):
 class Mamba2Mixer(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args, layer_idx):
         super().__init__()
-        self.args = args
-        self.intermediate_size = args.intermediate_size
-        self.time_step_rank = args.time_step_rank
-        self.conv_kernel_size = args.conv_kernel
+        self.layer_idx = layer_idx
         self.hidden_size = args.hidden_size
-        self.state_size = args.state_size
+        self.intermediate_size = args.intermediate_size
         self.num_heads = args.num_heads
-        self.head_dim = args.hidden_size // args.num_heads
+        self.head_dim = args.head_dim
+        self.ssm_state_size = args.state_size
         self.n_groups = args.n_groups
+        self.conv_kernel_size = args.conv_kernel
+        self.use_conv_bias = args.use_conv_bias
+        self.use_bias = args.use_bias
+        self.time_step_min = args.time_step_min
+        self.time_step_max = args.time_step_max
+        self.chunk_size = args.chunk_size

-        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
-        self.conv1d = DepthWiseConv1d(
-            in_channels=self.conv_dim,
-            out_channels=self.conv_dim,
-            bias=args.use_conv_bias,
-            kernel_size=args.conv_kernel,
-            groups=self.conv_dim,
-            padding=args.conv_kernel - 1
-        )
+        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size

         projection_size = self.intermediate_size + self.conv_dim + self.num_heads
         self.in_proj = nn.Linear(
@@ -139,91 +121,151 @@ class Mamba2Mixer(nn.Module):
             projection_size,
             bias=args.use_bias
         )

-        self.dt_bias = mx.ones((self.num_heads,))
-        self.A_log = mx.log(mx.arange(1, self.num_heads + 1))
-        self.D = mx.ones((self.num_heads,))
+        self.conv1d = nn.Conv1d(
+            self.conv_dim,
+            self.conv_dim,
+            self.conv_kernel_size,
+            groups=self.conv_dim,
+            bias=self.use_conv_bias
+        )
+
+        self.act = nn.SiLU()

         self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
-        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
-
-    def ssm_step(self, x, state, dt_proj):
-        print(f"ssm_step input shapes - x: {x.shape}, dt_proj: {dt_proj.shape}")
-        A = -mx.exp(self.A_log)
-        D = self.D
-        delta = nn.softplus(dt_proj + self.dt_bias)
-
-        B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
-        print(f"ssm_step split shapes - B: {B.shape}, C: {C.shape}")
-
-        batch_size = B.shape[0]
-        B = B.reshape(batch_size, self.n_groups, self.state_size)
-        C = C.reshape(batch_size, -1, self.state_size)
-        print(f"After reshape - B: {B.shape}, C: {C.shape}")
-
-        delta = delta.reshape(batch_size, self.num_heads, 1)
-        A = A.reshape(1, self.num_heads, 1)
-
-        if state is None:
-            new_state = delta * B
+        self.out_proj = nn.Linear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=self.use_bias
+        )
+
+        self.A_log = mx.zeros(self.num_heads)
+        self.D = mx.ones(self.num_heads)
+        self.dt_bias = mx.zeros(self.num_heads)
+
+    def __call__(self, input_states, cache):
+        batch_size, seq_len, _ = input_states.shape
+        dtype = input_states.dtype
+
+        projected_states = self.in_proj(input_states)
+
+        # Calculate the sizes of each split
+        total_size = projected_states.shape[-1]
+        remaining_size = total_size - self.intermediate_size - self.conv_dim - self.num_heads
+        d_mlp = remaining_size // 2
+        sizes = [
+            d_mlp,
+            d_mlp,
+            self.intermediate_size,
+            self.conv_dim,
+            self.num_heads
+        ]
+
+        # Perform the split operation
+        split_result = mx.split(projected_states, sizes, axis=-1)
+
+        # Print debug information
+        print(f"Number of split parts: {len(split_result)}")
+        print(f"Shapes of split parts: {[part.shape for part in split_result]}")
+
+        # Flexibly handle the split result
+        _, _, _, gate, hidden_states, dt = split_result
+
+        if cache is not None:
+            conv_state = cache.conv_states[self.layer_idx]
+            if conv_state is None:
+                # Initialize conv_state if it's None
+                conv_state = mx.zeros((batch_size, 1, self.conv_kernel_size, hidden_states.shape[-1]))
+            conv_state = mx.roll(conv_state, -1, -2)  # Roll along the kernel dimension
+            # Reshape hidden_states to match conv_state dimensions
+            hidden_states_reshaped = hidden_states[:, None, None, :]
+            conv_state = mx.concat([conv_state[:, :, :-1, :], hidden_states_reshaped], axis=-2)
+            cache.conv_states[self.layer_idx] = conv_state
+            # Adjust the convolution operation
+            hidden_states = mx.sum(conv_state * self.conv1d.weight[:, :, None, :], axis=(-2, -1))
+            if self.use_conv_bias:
+                hidden_states += self.conv1d.bias
+            hidden_states = self.act(hidden_states)[:, None, :]
         else:
-            new_state = delta * (B + state * mx.exp(delta * A))
-
-        print(f"Before final computation - new_state: {new_state.shape}, C: {C.shape}")
-        y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
-        y = y + D * x[:, :self.num_heads]
-        print(f"ssm_step output shape - y: {y.shape}")
-        return y, new_state
-
-    def __call__(self, x, cache):
-        B, T, D = x.shape
-        print(f"__call__ input shape - x: {x.shape}")
-        if cache is None:
-            cache = [None, None]
-
-        outputs = []
-        for t in range(T):
-            xt = x[:, t, :]
-            xz = self.in_proj(xt)
-            print(f"After in_proj shape - xz: {xz.shape}")
-
-            x_t, z_t, dt_proj = mx.split(
-                xz,
-                indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
-                axis=-1
-            )
-            print(f"After split shapes - x_t: {x_t.shape}, z_t: {z_t.shape}, dt_proj: {dt_proj.shape}")
-
-            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-            x_t = conv_out.squeeze(1)
-            x_t = nn.silu(x_t)
-            print(f"Before ssm_step shape - x_t: {x_t.shape}")
-            y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
-            z_t = nn.silu(z_t)
-            print(f"After ssm_step shapes - y_t: {y_t.shape}, z_t: {z_t.shape}")
-
-            # Element-wise multiplication
-            output_t = y_t[:, :, None] * z_t[:, None, :]
-            print(f"After multiplication shape - output_t: {output_t.shape}")
-
-            # Sum across the second dimension to match the intermediate_size
-            output_t = output_t.sum(axis=1)
-            print(f"After sum shape - output_t: {output_t.shape}")
-
-            output_t = self.out_proj(output_t)
-            print(f"After out_proj shape - output_t: {output_t.shape}")
-            outputs.append(output_t)
-
-        output = mx.stack(outputs, axis=1)
-        print(f"Final output shape: {output.shape}")
-        return output
+            hidden_states = hidden_states.transpose(0, 2, 1)
+            hidden_states = self.act(self.conv1d(hidden_states)).transpose(0, 2, 1)
+
+        hidden_states, B, C = mx.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], axis=-1)
+
+        A = -mx.exp(self.A_log.astype(mx.float32))
+        dt = nn.softplus(dt + self.dt_bias)
+        dt = mx.clip(dt, self.time_step_min, self.time_step_max)
+
+        hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).astype(mx.float32)
+        B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
+        C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
+
+        B = mx.repeat(B, repeats=self.num_heads // self.n_groups, axis=2)
+        C = mx.repeat(C, repeats=self.num_heads // self.n_groups, axis=2)
+
+        if cache is not None and cache.seqlen_offset > 0:
+            ssm_state = cache.ssm_states[self.layer_idx]
+            dA = mx.exp(dt[:, None, :, None] * A[None, :, None, None])
+            dB = dt[:, None, :, None] * B
+            dBx = dB * hidden_states[:, :, :, None]
+            ssm_state = ssm_state * dA + dBx
+            cache.ssm_states[self.layer_idx] = ssm_state
+
+            y = mx.sum(ssm_state * C[:, None, :, :], axis=-1)
+            D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
+            y = y + hidden_states * D
+
+            y = y.reshape(batch_size, -1)[:, None, :]
+        else:
+            # Implement chunked computation here (simplified version)
+            pad_size = self.chunk_size - (seq_len % self.chunk_size)
+            hidden_states_padded = mx.pad(hidden_states, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
+            B_padded = mx.pad(B, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
+            C_padded = mx.pad(C, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
+
+            chunks = seq_len // self.chunk_size + (1 if pad_size > 0 else 0)
+            y_list = []
+            ssm_state = mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
+
+            for i in range(chunks):
+                chunk_start = i * self.chunk_size
+                chunk_end = (i + 1) * self.chunk_size
+                chunk_h = hidden_states_padded[:, chunk_start:chunk_end]
+                chunk_B = B_padded[:, chunk_start:chunk_end]
+                chunk_C = C_padded[:, chunk_start:chunk_end]
+                chunk_dt = dt[:, chunk_start:chunk_end]
+
+                dA = mx.exp(chunk_dt[:, :, None, None] * A[None, None, :, None])
+                dB = chunk_dt[:, :, None, None] * chunk_B
+                dBx = dB * chunk_h[:, :, :, None]
+
+                chunk_y = mx.zeros_like(chunk_h)
+                for j in range(self.chunk_size):
+                    ssm_state = ssm_state * dA[:, j] + dBx[:, j]
+                    chunk_y[:, j] = mx.sum(ssm_state * chunk_C[:, j], axis=-1)
+
+                y_list.append(chunk_y)
+
+            y = mx.concat(y_list, axis=1)
+            if pad_size > 0:
+                y = y[:, :seq_len]
+
+            D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
+            y = y + hidden_states * D
+            y = y.reshape(batch_size, seq_len, -1)
+
+        y = self.norm(y, gate)
+        contextualized_states = self.out_proj(y.astype(dtype))
+        return contextualized_states
 class Mamba2Block(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs, layer_idx: int):
         super().__init__()
-        self.mixer = Mamba2Mixer(args)
+        self.mixer = Mamba2Mixer(args, layer_idx)
         self.norm = nn.RMSNorm(args.hidden_size)

     def __call__(self, x: mx.array, cache):
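The chunked else-branch in Mamba2Mixer.__call__ above is, semantically, a left-to-right scan over the sequence. A minimal reference version of that scan (illustrative only, with simplified shapes where dA, dBx, and C already broadcast against the state) looks like this:

    import mlx.core as mx

    def reference_scan(dA, dBx, C, ssm_state):
        # dA, dBx, C: (batch, seq_len, heads, state); ssm_state: (batch, heads, state)
        ys = []
        for t in range(dA.shape[1]):
            ssm_state = ssm_state * dA[:, t] + dBx[:, t]        # decay previous state, add new input
            ys.append(mx.sum(ssm_state * C[:, t], axis=-1))     # readout for this position
        return mx.stack(ys, axis=1), ssm_state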
@@ -235,7 +277,7 @@ class Mamba2(nn.Module):
         super().__init__()
         self.args = args
         self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
-        self.layers = [Mamba2Block(args) for idx in range(args.num_hidden_layers)]
+        self.layers = [Mamba2Block(args, idx) for idx in range(args.num_hidden_layers)]
         self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

     def __call__(
@@ -274,6 +316,9 @@ class Model(nn.Module):
         else:
             logits = self.lm_head(x)

+        print(logits)
+        print(logits.shape)
+
         return logits

     def sanitize(self, weights):
@@ -282,8 +327,8 @@ class Model(nn.Module):
                 weights[k] = v.moveaxis(2, 1)
         return weights

-    def make_cache(self, batch_size: int = 1):
-        return [Mamba2Cache() for _ in range(len(self.layers))]
+    def make_cache(self):
+        return [Mamba2Cache(self.args.num_hidden_layers) for _ in range(len(self.layers))]

     @property
     def layers(self):
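For reference, a sketch of how the updated make_cache output is laid out (illustrative; how the surrounding Model instance is constructed and called is outside this diff):

    # make_cache returns one Mamba2Cache per layer; as written above, each cache is
    # itself allocated with num_hidden_layers conv/ssm slots and is indexed by
    # layer_idx inside the mixer.
    caches = model.make_cache()                  # assumes a constructed `model` instance
    conv_state, ssm_state = caches[0][0]         # layer 0's slot: both None before the first pass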