diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index df6784b5..37b414da 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -421,24 +421,24 @@ class RotatingKVCache(_BaseCache):
 
 class MambaCache:
     def __init__(self):
-        # cache[0] is conv state, cache[1] is ssm state
+        # [conv_state, ssm_state]
         self.cache = [None, None]
-        self.offset = 0
-
+        self.offset = 0 # Sliding window caching
+
     def __setitem__(self, idx, value):
         self.cache[idx] = value
 
     def __getitem__(self, idx):
         return self.cache[idx]
-
+
     @property
     def state(self):
         return self.cache
-
+
     @state.setter
     def state(self, v):
         self.cache = v
-
+
     @property
     def conv_states(self):
         return [self.cache[0]]
@@ -446,13 +446,7 @@ class MambaCache:
     @property
     def ssm_states(self):
         return [self.cache[1]]
-
-
-class Mamba2Cache:
-    def __init__(self):
-        self.conv_states = [None] # Initialize as None, will be set on first use
-        self.ssm_states = [None] # Initialize as None, will be set on first use
-
-    @property
-    def state(self):
-        return [self.conv_states[0], self.ssm_states[0]]
\ No newline at end of file
+
+    def reset(self):
+        self.cache = [None, None]
+        self.offset = 0
diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 51820221..1cf493dc 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .cache import Mamba2Cache
+from .cache import MambaCache
 
 @dataclass
 class ModelArgs(BaseModelArgs):
@@ -91,6 +91,168 @@ def ssd(x, A, B, C, chunk_size):
     return mx.concatenate(outputs, axis=1), state
 
 
+# class DepthWiseConv1d(nn.Module):
+#     def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
+#         super().__init__()
+#         self.in_channels = in_channels
+#         self.out_channels = out_channels
+#         self.kernel_size = kernel_size
+#         self.padding = padding
+#         self.groups = groups if groups is not None else in_channels
+
+#         assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
+#         assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
+
+#         self.weight = mx.random.normal((in_channels, 1, kernel_size))
+#         self.bias = mx.zeros((out_channels,)) if bias else None
+
+#     def __call__(self, x: mx.array, cache=None) -> mx.array:
+#         B, L, C = x.shape
+#         K = self.kernel_size
+
+#         assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
+
+#         if cache is not None:
+#             if isinstance(cache.conv_states[0], type(None)):
+#                 cache.conv_states[0] = mx.zeros((B, K-1, C))
+
+#             x = mx.concatenate([cache.conv_states[0], x], axis=1)
+
+#         outputs = []
+#         for c in range(C):
+#             # Input prep debug
+#             x_c = x[:, :, c]
+#             x_c = mx.expand_dims(x_c, axis=1)
+
+#             # Weight prep debug
+#             w_c = self.weight[c]
+#             if w_c.ndim == 2:
+#                 w_c = mx.expand_dims(w_c, axis=0)
+#             elif w_c.ndim == 1:
+#                 w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
+
+#             y_c = mx.conv_general(
+#                 x_c,
+#                 w_c,
+#                 stride=1,
+#                 padding=0
+#             )
+#             if self.bias is not None:
+#                 y_c = y_c + self.bias[c]
+
+#             y_c = mx.squeeze(y_c, axis=1)
+#             outputs.append(y_c)
+
+#         # Output statistics
+#         y = mx.stack(outputs, axis=-1)
+
+#         # Cache update debug
+#         if cache is not None:
+#             cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
+
+#         return y
+
+
+# class Mamba2Block(nn.Module):
+#     def __init__(self, args: ModelArgs):
+#         super().__init__()
+#         self.args = args
+
+#         d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
+#         self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
+
+#         conv_dim = args.intermediate_size + 2 * args.state_size
+#         self.conv1d = DepthWiseConv1d(
+#             in_channels=conv_dim,
+#             out_channels=conv_dim,
+#             kernel_size=args.conv_kernel,
+#             groups=conv_dim,
+#             bias=args.use_conv_bias,
+#             padding=args.conv_kernel - 1
+#         )
+
+#         self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
+#         self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
+#         self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
+
+#         self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
+#         self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
+
+#         if args.rescale_prenorm_residual:
+#             layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
+#             self.out_proj.weight = self.out_proj.weight * layer_scale
+
+#     def __call__(self, u: mx.array, cache=None):
+#         # Expect input to be shape [batch_size, 1, dim]
+#         batch_size, seq_len, dimension = u.shape
+#         assert seq_len == 1, "Input should be a single token"
+
+#         # Initialize cache if needed
+#         if cache.conv_states[0] is None:
+#             conv_dim = self.args.intermediate_size + 2 * self.args.state_size
+#             cache.conv_states[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
+
+#         if cache.ssm_states[0] is None:
+#             cache.ssm_states[0] = mx.zeros((
+#                 batch_size,
+#                 self.args.num_heads,
+#                 self.args.head_dim,
+#                 self.args.state_size
+#             ))
+
+#         # Project input
+#         zxbcdt = self.in_proj(u)
+
+#         # Split projections
+#         n_heads = self.args.num_heads
+#         z = zxbcdt[:, :, :self.args.intermediate_size]
+#         xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
+#         dt = zxbcdt[:, :, -(n_heads):]
+
+#         # Time steps
+#         dt = mx.reshape(dt, (batch_size, n_heads))
+#         dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
+#         dt = mx.maximum(dt, self.args.time_step_floor)
+
+#         # Convolution
+#         xBC = self.conv1d(xBC, cache=cache)
+#         xBC = silu(xBC)
+
+#         # Split states
+#         x = xBC[:, :, :self.args.intermediate_size]
+#         B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
+#         C = xBC[:, :, -self.args.state_size:]
+
+#         # Reshape for SSM
+#         x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
+#         x = mx.squeeze(x, axis=1)
+#         B = mx.reshape(B, (batch_size, 1, self.args.state_size))
+#         B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
+#         B = mx.expand_dims(B, axis=2)
+#         C = mx.reshape(C, (batch_size, 1, self.args.state_size))
+#         C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
+#         C = mx.expand_dims(C, axis=3)
+
+#         # SSM updates
+#         A = -mx.exp(self.A_log)
+#         dA = mx.exp(dt * mx.expand_dims(A, 0))
+#         dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
+
+#         # Update state
+#         x = mx.expand_dims(x, axis=3)
+#         dBx = mx.matmul(x, B)
+#         cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
+
+#         # Compute output
+#         y = mx.matmul(cache.ssm_states[0], C)
+#         y = mx.squeeze(y, axis=-1)
+#         y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
+#         y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
+#         y = self.norm(y + z)
+
+#         return self.out_proj(y)
+
+
 class DepthWiseConv1d(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
         super().__init__()
@@ -113,18 +275,17 @@ class DepthWiseConv1d(nn.Module):
         assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
 
         if cache is not None:
-            if isinstance(cache.conv_states[0], type(None)):
-                cache.conv_states[0] = mx.zeros((B, K-1, C))
+            # Access conv_state directly from cache[0]
+            if cache[0] is None:
+                cache[0] = mx.zeros((B, K-1, C))
 
-            x = mx.concatenate([cache.conv_states[0], x], axis=1)
+            x = mx.concatenate([cache[0], x], axis=1)
 
         outputs = []
         for c in range(C):
-            # Input prep debug
            x_c = x[:, :, c]
            x_c = mx.expand_dims(x_c, axis=1)
 
-            # Weight prep debug
            w_c = self.weight[c]
            if w_c.ndim == 2:
                w_c = mx.expand_dims(w_c, axis=0)
@@ -143,12 +304,11 @@ class DepthWiseConv1d(nn.Module):
             y_c = mx.squeeze(y_c, axis=1)
             outputs.append(y_c)
 
-        # Output statistics
         y = mx.stack(outputs, axis=-1)
 
-        # Cache update debug
+        # Update cache directly using cache[0]
         if cache is not None:
-            cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
+            cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
 
         return y
 
@@ -171,11 +331,9 @@ class Mamba2Block(nn.Module):
             padding=args.conv_kernel - 1
         )
 
-        # Initialize parameters
-        self.dt_bias = mx.ones(args.num_heads)
-        A = mx.arange(1, args.num_heads + 1)
-        self.A_log = mx.log(A)
-        self.D = mx.ones(args.num_heads)
+        self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
+        self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
+        self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
 
         self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
         self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
@@ -185,47 +343,40 @@ class Mamba2Block(nn.Module):
             self.out_proj.weight = self.out_proj.weight * layer_scale
 
     def __call__(self, u: mx.array, cache=None):
-        # Expect input to be shape [batch_size, 1, dim]
         batch_size, seq_len, dimension = u.shape
         assert seq_len == 1, "Input should be a single token"
 
-        # Initialize cache if needed
-        if cache.conv_states[0] is None:
+        # Initialize cache states directly using indices
+        if cache[0] is None: # conv state
             conv_dim = self.args.intermediate_size + 2 * self.args.state_size
-            cache.conv_states[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
+            cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
 
-        if cache.ssm_states[0] is None:
-            cache.ssm_states[0] = mx.zeros((
+        if cache[1] is None: # ssm state
+            cache[1] = mx.zeros((
                 batch_size,
                 self.args.num_heads,
                 self.args.head_dim,
                 self.args.state_size
            ))
 
-        # Project input
         zxbcdt = self.in_proj(u)
 
-        # Split projections
         n_heads = self.args.num_heads
         z = zxbcdt[:, :, :self.args.intermediate_size]
         xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
         dt = zxbcdt[:, :, -(n_heads):]
 
-        # Time steps
         dt = mx.reshape(dt, (batch_size, n_heads))
         dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
         dt = mx.maximum(dt, self.args.time_step_floor)
 
-        # Convolution
         xBC = self.conv1d(xBC, cache=cache)
         xBC = silu(xBC)
 
-        # Split states
         x = xBC[:, :, :self.args.intermediate_size]
         B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
         C = xBC[:, :, -self.args.state_size:]
 
-        # Reshape for SSM
         x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
         x = mx.squeeze(x, axis=1)
         B = mx.reshape(B, (batch_size, 1, self.args.state_size))
@@ -235,24 +386,23 @@ class Mamba2Block(nn.Module):
         C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
         C = mx.expand_dims(C, axis=3)
 
-        # SSM updates
         A = -mx.exp(self.A_log)
         dA = mx.exp(dt * mx.expand_dims(A, 0))
         dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
 
-        # Update state
         x = mx.expand_dims(x, axis=3)
         dBx = mx.matmul(x, B)
-        cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
+        # Update ssm state directly using cache[1]
+        cache[1] = cache[1] * dA + dBx
 
-        # Compute output
-        y = mx.matmul(cache.ssm_states[0], C)
+        y = mx.matmul(cache[1], C)
         y = mx.squeeze(y, axis=-1)
         y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
         y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
         y = self.norm(y + z)
 
         return self.out_proj(y)
+
 
 class ResidualBlock(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -308,7 +458,7 @@ class Model(nn.Module):
         return logits
 
     def make_cache(self):
-        return [Mamba2Cache() for _ in range(len(self.layers))]
+        return [MambaCache() for _ in range(len(self.layers))]
 
     @property
     def layers(self):
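Reviewer note: with Mamba2Cache removed, mamba2.py now shares MambaCache and addresses its two slots by index (cache[0] = rolling convolution input window, cache[1] = SSM state), with reset() clearing both. The sketch below is a minimal, self-contained illustration of that contract, not code from this patch: _ToyMambaCache and the toy sizes are hypothetical stand-ins, and the final state update is only a placeholder for the real cache[1] * dA + dBx recurrence in Mamba2Block.

import mlx.core as mx


class _ToyMambaCache:  # hypothetical stand-in mirroring MambaCache's [conv_state, ssm_state] contract
    def __init__(self):
        self.cache = [None, None]
        self.offset = 0

    def __getitem__(self, idx):
        return self.cache[idx]

    def __setitem__(self, idx, value):
        self.cache[idx] = value

    def reset(self):
        self.cache = [None, None]
        self.offset = 0


# Toy dimensions for illustration only, not model defaults.
batch, conv_kernel, conv_dim = 1, 4, 32
n_heads, head_dim, state_size = 2, 8, 16

cache = _ToyMambaCache()

# Lazy initialization on the first decoded token, as in Mamba2Block.__call__.
if cache[0] is None:
    cache[0] = mx.zeros((batch, conv_kernel - 1, conv_dim))
if cache[1] is None:
    cache[1] = mx.zeros((batch, n_heads, head_dim, state_size))

# Per-token update pattern: shift the conv window, decay/accumulate the SSM state.
x_t = mx.random.normal((batch, 1, conv_dim))
cache[0] = mx.concatenate([cache[0], x_t], axis=1)[:, -(conv_kernel - 1):, :]
cache[1] = cache[1] * 0.9 + 0.1  # placeholder for cache[1] * dA + dBx

cache.reset()  # both states back to None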