From 758597eaa80c9b07acf9588c6e04c22f06eeb3ee Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 22 Oct 2024 20:44:23 +0200
Subject: [PATCH] Add multi-token input and correct cache handling in the SSM step

---
 llms/mlx_lm/models/mamba2.py | 385 ++++++++++++++++++++++-------------
 llms/mlx_lm/tuner/utils.py   |   2 +
 2 files changed, 250 insertions(+), 137 deletions(-)

diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 9186acfe..3c4b4e7c 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -27,10 +27,10 @@ class ModelArgs(BaseModelArgs):
     time_step_max: float
     time_step_floor: float
     rescale_prenorm_residual: bool
-    use_cache: bool
     rms_norm: bool
     chunk_size: int
    tie_word_embeddings: bool
+    use_cache: bool = True
     time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
     time_step_rank: Union[int, str] = "auto"
     model_type: str = "mamba2"
@@ -58,6 +58,29 @@ class MambaRMSNormGated(nn.Module):
         return self.weight * hidden_states
 
 
+def silu(x):
+    return x * mx.sigmoid(x)
+
+def ssd(x, A, B, C, chunk_size):
+    batch, seqlen, nheads, dim = x.shape
+    B = mx.expand_dims(B, axis=2)
+    C = mx.expand_dims(C, axis=2)
+
+    state = mx.zeros((batch, nheads, dim, B.shape[-1]))
+    outputs = []
+
+    for i in range(0, seqlen, chunk_size):
+        chunk = slice(i, min(i + chunk_size, seqlen))
+        dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
+
+        dBx = mx.einsum('blhp,bln->bhpn', x[:, chunk], B[:, chunk])
+        state = state * mx.expand_dims(dA, axis=-1) + dBx
+        y = mx.einsum('bhpn,bln->blhp', state, C[:, chunk])
+        outputs.append(y)
+
+    return mx.concatenate(outputs, axis=1), state
+
+
 class DepthWiseConv1d(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
         super().__init__()
@@ -66,128 +89,143 @@ class DepthWiseConv1d(nn.Module):
         self.kernel_size = kernel_size
         self.padding = padding
         self.groups = groups if groups is not None else in_channels
-
-        # Ensure in_channels and out_channels are the same for depthwise conv
-        assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
-        # Ensure groups is equal to in_channels for depthwise conv
+
+        assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
         assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
-
-        # Initialize weight with shape (out_channels, kernel_size, 1)
-        self.weight = mx.random.normal((out_channels, kernel_size, 1))
+
+        # Initialize with shape (channels, 1, kernel_size) to match pretrained weights
+        self.weight = mx.random.normal((in_channels, 1, kernel_size))
         self.bias = mx.zeros((out_channels,)) if bias else None
 
-    def __call__(self, x, cache=None):
+    def __call__(self, x: mx.array, cache=None, cache_idx: int = 0) -> mx.array:
         B, L, C = x.shape
-        _, K, _ = self.weight.shape
+        K = self.kernel_size
 
+        # Handle padding and caching
         if cache is not None:
-            x = mx.concatenate([cache, x], axis=1)
+            conv_cache = cache[cache_idx]
+            if conv_cache is not None:
+                x = mx.concatenate([conv_cache, x], axis=1)
+                L = x.shape[1]  # Update L after concatenation
         else:
-            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
+            pad_left = K - 1
+            x = mx.pad(x, [(0, 0), (pad_left, 0), (0, 0)])
+            L = x.shape[1]  # Update L after padding
 
-        y = mx.conv_general(x, self.weight, groups=self.groups)
+        # Implement depthwise convolution manually for each channel
+        outputs = []
+        for c in range(C):
+            # Extract single channel and reshape for 1D convolution
+            x_c = x[:, :, c]  # Shape: [B, L]
+            x_c = mx.expand_dims(x_c, axis=1)  # Shape: [B, 1, L]
+
+            # Extract and ensure filter is 3D
+            w_c = self.weight[c]  # Shape: [1, kernel_size] or [1, 1, kernel_size]
+            if w_c.ndim == 2:
+                w_c = mx.expand_dims(w_c, axis=0)  # Shape: [1, 1, kernel_size]
+            elif w_c.ndim == 1:
+                w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
+
+            # For inference mode (single token), adjust the input
+            if L < K:
+                # Pad input to match kernel size
+                pad_size = K - L
+                x_c = mx.pad(x_c, [(0, 0), (0, 0), (pad_size, 0)])
+
+            # Apply 1D convolution for this channel
+            y_c = mx.conv_general(
+                x_c,
+                w_c,
+                stride=1,
+                padding=0  # We've already handled padding
+            )
+
+            if self.bias is not None:
+                y_c = y_c + self.bias[c]
+
+            outputs.append(mx.squeeze(y_c, axis=1))  # Shape: [B, L']
+
+        # Stack all channel outputs
+        y = mx.stack(outputs, axis=-1)  # Shape: [B, L', C]
+
+        if cache is not None:
+            # Update cache with the most recent K-1 tokens
+            cache[cache_idx] = x[:, -(K-1):, :] if L >= K else x
 
-        if self.bias is not None:
-            y = y + self.bias
-
-        return y, x[:, -K + 1 :, :]
+        return y
 
 
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
-        self.intermediate_size = args.intermediate_size
-        self.time_step_rank = args.time_step_rank
-        self.conv_kernel_size = args.conv_kernel
-        self.hidden_size = args.hidden_size
-        self.state_size = args.state_size
-        self.num_heads = args.num_heads
-        self.head_dim = args.hidden_size // args.num_heads
-        self.n_groups = args.n_groups
+
+        d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
+        self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
 
-        # projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
-        projection_size = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
-        self.in_proj = nn.Linear(
-            args.hidden_size,
-            projection_size,
-            bias=args.use_bias
-        )
-
-        # self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
-        self.conv_dim = args.intermediate_size + 2 * args.state_size
+        conv_dim = args.intermediate_size + 2 * args.state_size
         self.conv1d = DepthWiseConv1d(
-            in_channels=self.conv_dim,
-            out_channels=self.conv_dim,
+            in_channels=conv_dim,
+            out_channels=conv_dim,
             kernel_size=args.conv_kernel,
+            groups=conv_dim,
             bias=args.use_conv_bias,
-            groups=self.conv_dim,
             padding=args.conv_kernel - 1
         )
 
-        self.A_log = mx.zeros(args.num_heads)
-        self.D = mx.ones((args.num_heads,))
-        self.dt_bias = mx.zeros(args.num_heads)
+        self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
+        self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
+        self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
 
-        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
         self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
+        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
 
-    def _ssd(self, x, A, B, C, chunk_size):
-        batch, seq_len, nheads, head_dim = x.shape
-        n_state = B.shape[-1]
-
-        h = mx.zeros((batch, nheads, head_dim, n_state))
-        ys = []
-
-        for i in range(0, seq_len, chunk_size):
-            chunk_size_i = min(chunk_size, seq_len - i)
-            xi = x[:, i:i + chunk_size_i]
-            Bi = B[:, i:i + chunk_size_i]
-            Ci = C[:, i:i + chunk_size_i]
-
-            for t in range(chunk_size_i):
-                h = h * mx.exp(A)[:, None, None]
-                h = h + mx.expand_dims(Bi[:, t], -2) * mx.expand_dims(xi[:, t], -1)
-                y = mx.sum(h * mx.expand_dims(Ci[:, t], -2), axis=-1)
-                ys.append(y)
-
-        y = mx.stack(ys, axis=1)
-        return y, h
+        if args.rescale_prenorm_residual:
+            layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
+            self.out_proj.weight = self.out_proj.weight * layer_scale
 
-    def __call__(self, x: mx.array, cache) -> mx.array:
-        if cache is not None:
-            return self.step(x, cache)
+    def __call__(self, u: mx.array, cache=None):
+        if cache is not None and self.args.use_cache:
+            return self.step(u, cache)
 
         A = -mx.exp(self.A_log)
         zxbcdt = self.in_proj(u)
-        z, xBC, dt = mx.split(
-            zxbcdt,
-            [
-                self.args.d_inner,
-                self.args.d_inner + 2 * self.args.d_state,
-                self.args.nheads,
-            ],
-            axis=-1,
+        splits = [
+            self.args.intermediate_size,
+            self.args.intermediate_size + 2 * self.args.state_size,
+            self.args.num_heads,
+        ]
+
+        z, xBC, dt = mx.split(zxbcdt, splits, axis=-1)
+
+        dt = mx.clip(
+            nn.softplus(dt + self.dt_bias),
+            self.args.time_step_min,
+            self.args.time_step_max
        )
-        dt = mx.softplus(dt + self.dt_bias)
-
-        # Use the custom DepthWiseConv1d with cache
-        xBC = self.conv1d(xBC, cache, cache_idx=0)
-        xBC = mx.sigmoid(xBC) * xBC  # SiLU activation
-
-        x, B, C = mx.split(
+        dt = mx.maximum(dt, self.args.time_step_floor)
+
+        xBC = silu(self.conv1d(xBC))
+
+        xBC_parts = mx.split(
             xBC,
-            [self.args.d_inner, self.args.d_state, self.args.d_state],
+            [self.args.intermediate_size, self.args.state_size, self.args.state_size],
             axis=-1
         )
+
+        x = xBC_parts[0]
+        B = xBC_parts[1]
+        C = xBC_parts[2]
 
-        x = self._reshape_heads(x, True)
-        B = mx.expand_dims(B, axis=2)
-        C = mx.expand_dims(C, axis=2)
+        # Replace rearrange with reshape and transpose
+        b, l, hp = x.shape
+        h = self.args.num_heads
+        p = hp // h
+        x = mx.reshape(x, (b, l, h, p))
 
-        y, ssm_state = self._ssd(
+        y, ssm_state = ssd(
             x * mx.expand_dims(dt, -1),
             A * dt,
             B,
@@ -196,61 +234,127 @@ class Mamba2Block(nn.Module):
         )
 
         y = y + x * mx.expand_dims(self.D, -1)
-        y = self._reshape_heads(y, False)
-        y = self.norm(y, z)
+        # Replace rearrange with reshape
+        y = mx.reshape(y, (b, l, h * p))
+
+        y = self.norm(y + z)
         y = self.out_proj(y)
 
-        if cache is not None:
+        if cache is not None and self.args.use_cache:
             cache[1] = ssm_state
 
+        if self.args.residual_in_fp32:
+            y = y.astype(mx.float32)
+
         return y
 
-    def step(self, x: mx.array, cache) -> mx.array:
-        """Single inference step"""
-        assert x.shape[1] == 1, "Only one token can be decoded per inference step"
-
-        zxbcdt = self.in_proj(mx.squeeze(x, 1))
-        z, xBC, dt = mx.split(
-            zxbcdt,
-            [
-                self.args.d_inner,
-                self.args.d_inner + 2 * self.args.d_state,
-                self.args.nheads,
-            ],
-            axis=-1,
-        )
+    def step(self, u: mx.array, cache: MambaCache):
+        batch_size = u.shape[0]
+        seq_len = u.shape[1]
+        outputs = []
 
-        # Use the custom DepthWiseConv1d with cache
-        xBC = self.conv1d(xBC, cache, cache_idx=0)
-        xBC = mx.sigmoid(xBC) * xBC  # SiLU activation
+        # Initialize SSM state if needed
+        if cache[1] is None:
+            cache[1] = mx.zeros((
+                batch_size,
+                self.args.num_heads,
+                self.args.head_dim,
+                self.args.state_size
+            ))
 
-        x, B, C = mx.split(
-            xBC,
-            [self.args.d_inner, self.args.d_state, self.args.d_state],
-            axis=-1
-        )
-        A = -mx.exp(self.A_log)
+        for pos in range(seq_len):
+            # Get single token
+            u_t = u[:, pos:pos+1, :]
 
-        dt = mx.softplus(dt + self.dt_bias)
-        dA = mx.exp(dt * A)
-
-        x = mx.reshape(x, (-1, self.args.nheads, self.args.headdim))
-
-        ssm_state = cache[1]
-        dBx = mx.expand_dims(dt, -1) * mx.expand_dims(B, 1) * mx.expand_dims(x, -1)
-        ssm_state = ssm_state * mx.expand_dims(mx.expand_dims(dA, -1), -1) + dBx
-
-        y = mx.sum(ssm_state * mx.expand_dims(mx.expand_dims(C, 1), 1), axis=-1)
-        y = y + mx.expand_dims(self.D, -1) * x
-        y = mx.reshape(y, (-1, self.args.nheads * self.args.headdim))
-
-        y = self.norm(y, z)
-        y = self.out_proj(y)
+            # Project input
+            zxbcdt = self.in_proj(u_t)
+
+            # Calculate sizes
+            d_model = self.args.intermediate_size
+            d_state = self.args.state_size
+            n_heads = self.args.num_heads
+            d_head = self.args.head_dim
+
+            # Correct splits for z, xBC, dt
+            splits = [
+                d_model,                # z size
+                d_model + 2 * d_state,  # xBC size (delta, B, C)
+                n_heads                 # dt size
+            ]
+
+            # Split the projected input
+            z = zxbcdt[:, :, :splits[0]]
+            xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
+            dt = zxbcdt[:, :, -splits[2]:]  # Take last n_heads elements
 
-        # Update SSM state in cache
-        cache[1] = ssm_state
+            # Process dt
+            dt = mx.reshape(dt, (batch_size, n_heads))
+            dt = mx.clip(
+                nn.softplus(dt + self.dt_bias),
+                self.args.time_step_min,
+                self.args.time_step_max
+            )
+            dt = mx.maximum(dt, self.args.time_step_floor)
 
-        return mx.expand_dims(y, 1)
+            # Process convolution
+            xBC = self.conv1d(xBC, cache=cache, cache_idx=0)
+            xBC = silu(xBC)
+
+            # Split convolved xBC into x, B, C
+            x = xBC[:, :, :d_model]
+            B = xBC[:, :, d_model:d_model + d_state]
+            C = xBC[:, :, -d_state:]
+
+            # Reshape x into (batch, heads, dim)
+            x = mx.reshape(x, (batch_size, 1, n_heads, d_head))
+            x = mx.squeeze(x, axis=1)  # (batch, heads, dim)
+
+            # Reshape B into (batch, heads, dim, state)
+            B = mx.reshape(B, (batch_size, 1, d_state))
+            B = mx.broadcast_to(B, (batch_size, n_heads, d_state))
+            B = mx.expand_dims(B, axis=2)  # (batch, heads, 1, state)
+
+            # Reshape C for later use
+            C = mx.reshape(C, (batch_size, 1, d_state))
+            C = mx.broadcast_to(C, (batch_size, n_heads, d_state))
+            C = mx.expand_dims(C, axis=3)  # (batch, heads, state, 1)
+
+            # Compute SSM updates
+            A = -mx.exp(self.A_log)
+            dA = mx.exp(dt * mx.expand_dims(A, 0))
+            dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)  # (batch, heads, 1, 1)
+
+            # Prepare x for Bx computation
+            x = mx.expand_dims(x, axis=3)  # (batch, heads, dim, 1)
+
+            # Compute dBx with proper broadcasting
+            dBx = mx.matmul(x, B)  # (batch, heads, dim, state)
+
+            # Update state
+            ssm_state = cache[1]  # (batch, heads, dim, state)
+            ssm_state = ssm_state * dA + dBx
+            cache[1] = ssm_state
+
+            # Compute output
+            y = mx.matmul(ssm_state, C)  # (batch, heads, dim, 1)
+            y = mx.squeeze(y, axis=-1)  # (batch, heads, dim)
+
+            # Add skip connection with D
+            y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
+
+            # Reshape to original dimensions
+            y = mx.reshape(y, (batch_size, 1, n_heads * d_head))
+
+            # Apply norm and output projection
+            y = self.norm(y + z)
+            y = self.out_proj(y)
+
+            if self.args.residual_in_fp32:
+                y = y.astype(mx.float32)
+
+            outputs.append(y)
+
+        return mx.concatenate(outputs, axis=1)
 
 
 class ResidualBlock(nn.Module):
@@ -287,7 +391,6 @@ class Model(nn.Module):
         self.model_type = args.model_type
 
         self.backbone = Mamba2(args)
-        # self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
 
         if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
@@ -302,16 +405,24 @@ class Model(nn.Module):
         else:
             logits = self.lm_head(x)
 
         return logits
-
-    def sanitize(self, weights):
-        for k, v in weights.items():
-            if "conv1d.weight" in k and v.ndim == 3:
-                weights[k] = v.moveaxis(2, 1)
-        return weights
 
     def make_cache(self):
         return [MambaCache() for _ in range(len(self.layers))]
+
+    def sanitize(self, weights):
+        sanitized = {}
+        for k, v in weights.items():
+            if "conv1d.weight" in k:
+                # Ensure weights are in the correct shape (channels, 1, kernel_size)
+                if v.ndim == 2:
+                    v = mx.expand_dims(v, axis=1)
+                elif v.ndim == 1:
+                    v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
+                sanitized[k] = v
+            else:
+                sanitized[k] = v
+        return sanitized
 
     @property
     def layers(self):
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 1d223ce5..a44663fb 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -146,6 +146,8 @@ def linear_to_lora_layers(
     elif model.model_type == "mamba2":
         keys = set(
             [
+                "mixer.in_proj",
+                "mixer.out_proj",
             ]
         )
     else:
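
Note: the per-token recurrence that the new `step` method keeps in `cache[1]` can be illustrated in isolation as below. This is a minimal, self-contained sketch with toy shapes and a hypothetical helper name (`ssm_step`), not code from the patch; it mirrors the decay/update/readout pattern (state decayed by exp(dt * A), an input contribution from B and x, a readout through C plus the D skip) but omits the projections, convolution cache, gating, and any dt scaling of the input term.

import mlx.core as mx

def ssm_step(h, x_t, B_t, C_t, A, D, dt):
    # h:   (batch, heads, head_dim, state)  running SSM state
    # x_t: (batch, heads, head_dim)         current token, split per head
    # B_t, C_t: (batch, state)              input / readout projections
    # A, D: (heads,)                        per-head decay and skip weights
    # dt:  (batch, heads)                   per-head step size
    # Decay the previous state: exp(dt * A), broadcast over (head_dim, state)
    dA = mx.expand_dims(mx.expand_dims(mx.exp(dt * A), -1), -1)
    # Outer product of the input with B gives the state update
    dBx = mx.expand_dims(x_t, -1) * mx.expand_dims(mx.expand_dims(B_t, 1), 2)
    h = h * dA + dBx
    # Read the state out through C and add the D skip connection
    y = mx.sum(h * mx.expand_dims(mx.expand_dims(C_t, 1), 2), axis=-1)
    y = y + mx.expand_dims(D, -1) * x_t
    return y, h

# Toy usage
batch, heads, head_dim, state = 1, 2, 4, 8
h = mx.zeros((batch, heads, head_dim, state))
y, h = ssm_step(
    h,
    mx.random.normal((batch, heads, head_dim)),
    mx.random.normal((batch, state)),
    mx.random.normal((batch, state)),
    -mx.ones((heads,)),
    mx.ones((heads,)),
    0.1 * mx.ones((batch, heads)),
)
print(y.shape, h.shape)  # (1, 2, 4) (1, 2, 4, 8)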