From ab4cf1d1cf554bf39954714d5692397886c42b23 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sun, 20 Oct 2024 18:04:34 +0200
Subject: [PATCH] generation works but outputs gibberish

---
 llms/mlx_lm/models/cache.py  |  27 ----
 llms/mlx_lm/models/mamba2.py | 283 +++++++++++++----------------------
 2 files changed, 102 insertions(+), 208 deletions(-)

diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 4d04dac0..a6a56e0a 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -338,30 +338,3 @@ class MambaCache(_BaseCache):
     @state.setter
     def state(self, v):
         self.cache = v
-
-
-class Mamba2Cache:
-    def __init__(self, num_layers):
-        self.conv_states = [None] * num_layers
-        self.ssm_states = [None] * num_layers
-        self.seqlen_offset = 0
-
-    def __getitem__(self, idx):
-        return (self.conv_states[idx], self.ssm_states[idx])
-
-    def __setitem__(self, idx, value):
-        self.conv_states[idx], self.ssm_states[idx] = value
-
-    @property
-    def state(self):
-        return {
-            'conv_states': self.conv_states,
-            'ssm_states': self.ssm_states,
-            'seqlen_offset': self.seqlen_offset
-        }
-
-    @state.setter
-    def state(self, v):
-        self.conv_states = v['conv_states']
-        self.ssm_states = v['ssm_states']
-        self.seqlen_offset = v['seqlen_offset']
\ No newline at end of file
diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 7ac6ecc8..433f9716 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -1,16 +1,11 @@
-# Copyright © 2024 Apple Inc.
-
 import math
 from dataclasses import dataclass, field
-from typing import Tuple, Union, Optional
-
-import mlx.nn as nn
+from typing import Tuple, Union
 
 import mlx.core as mx
+import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .cache import Mamba2Cache
-
-# python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."
+from .cache import MambaCache
 
 
 @dataclass
 class ModelArgs(BaseModelArgs):
@@ -26,7 +21,7 @@ class ModelArgs(BaseModelArgs):
     n_groups: int
     use_bias: bool
     use_conv_bias: bool
-    initializer_range: float 
+    initializer_range: float
     residual_in_fp32: bool
     time_step_min: float
     time_step_max: float
@@ -48,6 +43,7 @@ class ModelArgs(BaseModelArgs):
         if self.time_step_rank == "auto":
             self.time_step_rank = math.ceil(self.hidden_size / 16)
 
+
 class MambaRMSNormGated(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
@@ -60,7 +56,7 @@ class MambaRMSNormGated(nn.Module):
         variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
         hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states
-    
+
 
 class DepthWiseConv1d(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
@@ -95,177 +91,110 @@ class DepthWiseConv1d(nn.Module):
             y = y + self.bias
         return y, x[:, -K + 1 :, :]
+    
 
-
-class Mamba2Mixer(nn.Module):
-    def __init__(self, args, layer_idx):
+class Mamba2Block(nn.Module):
+    def __init__(self, args: ModelArgs):
         super().__init__()
-        self.layer_idx = layer_idx
-        self.hidden_size = args.hidden_size
+        self.args = args
         self.intermediate_size = args.intermediate_size
-        self.num_heads = args.num_heads
-        self.head_dim = args.head_dim
-        self.ssm_state_size = args.state_size
-        self.n_groups = args.n_groups
+        self.time_step_rank = args.time_step_rank
         self.conv_kernel_size = args.conv_kernel
-        self.use_conv_bias = args.use_conv_bias
-        self.use_bias = args.use_bias
-        self.time_step_min = args.time_step_min
-        self.time_step_max = args.time_step_max
-        self.chunk_size = args.chunk_size
-        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
+        self.hidden_size = args.hidden_size
+        self.state_size = args.state_size
+        self.num_heads = args.num_heads
+        self.head_dim = args.hidden_size // args.num_heads
+        self.n_groups = args.n_groups
 
-        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
+        self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
+        self.conv1d = DepthWiseConv1d(
+            in_channels=self.conv_dim,
+            out_channels=self.conv_dim,
+            kernel_size=args.conv_kernel,
+            bias=args.use_conv_bias,
+            groups=self.conv_dim,
+            padding=args.conv_kernel - 1
+        )
+
+        projection_size = args.intermediate_size + self.conv_dim + args.num_heads
         self.in_proj = nn.Linear(
-            self.hidden_size,
+            args.hidden_size,
             projection_size,
             bias=args.use_bias
         )
-        self.conv1d = nn.Conv1d(
-            self.conv_dim,
-            self.conv_dim,
-            self.conv_kernel_size,
-            groups=self.conv_dim,
-            bias=self.use_conv_bias
-        )
+
         self.act = nn.SiLU()
-        self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
-        self.out_proj = nn.Linear(
-            self.intermediate_size,
-            self.hidden_size,
-            bias=self.use_bias
-        )
-        self.A_log = mx.zeros(self.num_heads)
-        self.D = mx.ones(self.num_heads)
-        self.dt_bias = mx.zeros(self.num_heads)
-
-    def __call__(self, input_states, cache):
-        batch_size, seq_len, _ = input_states.shape
-        dtype = input_states.dtype
+        self.A_log = mx.zeros(args.num_heads)
+        self.D = mx.ones((args.num_heads,))
+        self.dt_bias = mx.zeros(args.num_heads)
 
-        projected_states = self.in_proj(input_states)
-
-        # Calculate the sizes of each split
-        total_size = projected_states.shape[-1]
-        remaining_size = total_size - self.intermediate_size - self.conv_dim - self.num_heads
-        d_mlp = remaining_size // 2
-        sizes = [
-            d_mlp,
-            d_mlp,
-            self.intermediate_size,
-            self.conv_dim,
-            self.num_heads
-        ]
-
-        # Perform the split operation
-        split_result = mx.split(projected_states, sizes, axis=-1)
-
-        # Print debug information
-        print(f"Number of split parts: {len(split_result)}")
-        print(f"Shapes of split parts: {[part.shape for part in split_result]}")
-
-        # Flexibly handle the split result
-        _, _, _, gate, hidden_states, dt = split_result
+        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
+        self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
 
-        if cache is not None:
-            conv_state = cache.conv_states[self.layer_idx]
-            if conv_state is None:
-                # Initialize conv_state if it's None
-                conv_state = mx.zeros((batch_size, 1, self.conv_kernel_size, hidden_states.shape[-1]))
-
-            conv_state = mx.roll(conv_state, -1, -2) # Roll along the kernel dimension
-
-            # Reshape hidden_states to match conv_state dimensions
-            hidden_states_reshaped = hidden_states[:, None, None, :]
-
-            conv_state = mx.concat([conv_state[:, :, :-1, :], hidden_states_reshaped], axis=-2)
-            cache.conv_states[self.layer_idx] = conv_state
-
-            # Adjust the convolution operation
-            hidden_states = mx.sum(conv_state * self.conv1d.weight[:, :, None, :], axis=(-2, -1))
-
-            if self.use_conv_bias:
-                hidden_states += self.conv1d.bias
-            hidden_states = self.act(hidden_states)[:, None, :]
+    def ssm_step(self, x, state, dt_proj):
+        A = -mx.exp(self.A_log)
+        D = self.D
+        delta = nn.softplus(dt_proj + self.dt_bias)
+
+        B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
+
+        batch_size = B.shape[0]
+        B = B.reshape(batch_size, self.n_groups, self.state_size)
+        C = C.reshape(batch_size, -1, self.state_size)
+
+        delta = delta.reshape(batch_size, self.num_heads, 1)
+        A = A.reshape(1, self.num_heads, 1)
+
+        if state is None:
+            new_state = delta * B
         else:
-            hidden_states = hidden_states.transpose(0, 2, 1)
-            hidden_states = self.act(self.conv1d(hidden_states)).transpose(0, 2, 1)
+            new_state = delta * (B + state * mx.exp(delta * A))
+
+        y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
+        y = y + D * x[:, :self.num_heads]
+        return y, new_state
 
-        hidden_states, B, C = mx.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], axis=-1)
+    def __call__(self, x, cache):
+        B, T, D = x.shape
+        if cache is None:
+            cache = [None, None]
 
-        A = -mx.exp(self.A_log.astype(mx.float32))
-        dt = nn.softplus(dt + self.dt_bias)
-        dt = mx.clip(dt, self.time_step_min, self.time_step_max)
+        outputs = []
+        for t in range(T):
+            xt = x[:, t, :]
+            xz = self.in_proj(xt)
+
+            x_t, z_t, dt_proj = mx.split(
+                xz,
+                indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
+                axis=-1
+            )
 
-        hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).astype(mx.float32)
-        B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
-        C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
-
-        B = mx.repeat(B, repeats=self.num_heads // self.n_groups, axis=2)
-        C = mx.repeat(C, repeats=self.num_heads // self.n_groups, axis=2)
-
-        if cache is not None and cache.seqlen_offset > 0:
-            ssm_state = cache.ssm_states[self.layer_idx]
-            dA = mx.exp(dt[:, None, :, None] * A[None, :, None, None])
-            dB = dt[:, None, :, None] * B
-            dBx = dB * hidden_states[:, :, :, None]
-            ssm_state = ssm_state * dA + dBx
-            cache.ssm_states[self.layer_idx] = ssm_state
-
-            y = mx.sum(ssm_state * C[:, None, :, :], axis=-1)
-            D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
-            y = y + hidden_states * D
-
-            y = y.reshape(batch_size, -1)[:, None, :]
-        else:
-            # Implement chunked computation here (simplified version)
-            pad_size = self.chunk_size - (seq_len % self.chunk_size)
-            hidden_states_padded = mx.pad(hidden_states, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
-            B_padded = mx.pad(B, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
-            C_padded = mx.pad(C, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
-
-            chunks = seq_len // self.chunk_size + (1 if pad_size > 0 else 0)
-            y_list = []
-            ssm_state = mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
-
-            for i in range(chunks):
-                chunk_start = i * self.chunk_size
-                chunk_end = (i + 1) * self.chunk_size
-                chunk_h = hidden_states_padded[:, chunk_start:chunk_end]
-                chunk_B = B_padded[:, chunk_start:chunk_end]
-                chunk_C = C_padded[:, chunk_start:chunk_end]
-
-                chunk_dt = dt[:, chunk_start:chunk_end]
-                dA = mx.exp(chunk_dt[:, :, None, None] * A[None, None, :, None])
-                dB = chunk_dt[:, :, None, None] * chunk_B
-                dBx = dB * chunk_h[:, :, :, None]
-
-                chunk_y = mx.zeros_like(chunk_h)
-                for j in range(self.chunk_size):
-                    ssm_state = ssm_state * dA[:, j] + dBx[:, j]
-                    chunk_y[:, j] = mx.sum(ssm_state * chunk_C[:, j], axis=-1)
-
-                y_list.append(chunk_y)
-
-            y = mx.concat(y_list, axis=1)
-            if pad_size > 0:
-                y = y[:, :seq_len]
-
-            D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
-            y = y + hidden_states * D
-            y = y.reshape(batch_size, seq_len, -1)
-
-        y = self.norm(y, gate)
-        contextualized_states = self.out_proj(y.astype(dtype))
-
-        return contextualized_states
+            # Use the new DepthWiseConv1d with caching
+            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
+            x_t = conv_out.squeeze(1)
+            x_t = nn.silu(x_t)
+            y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
+            z_t = nn.silu(z_t)
+
+            # Element-wise multiplication
+            output_t = y_t[:, :, None] * z_t[:, None, :]
+
+            # Sum across the second dimension to match the intermediate_size
+            output_t = output_t.sum(axis=1)
+
+            output_t = self.out_proj(output_t)
+            outputs.append(output_t)
+
+        output = mx.stack(outputs, axis=1)
+        return output
 
 
-class Mamba2Block(nn.Module):
-    def __init__(self, args: ModelArgs, layer_idx: int):
+class ResidualBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
         super().__init__()
-        self.mixer = Mamba2Mixer(args, layer_idx)
+        self.mixer = Mamba2Block(args)
         self.norm = nn.RMSNorm(args.hidden_size)
 
     def __call__(self, x: mx.array, cache):
@@ -277,24 +206,16 @@ class Mamba2(nn.Module):
         super().__init__()
         self.args = args
         self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
-        self.layers = [Mamba2Block(args, idx) for idx in range(args.num_hidden_layers)]
+        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
         self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
 
-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None
-    ):
-        hidden_states = self.embeddings(inputs)
-
+    def __call__(self, x: mx.array, cache):
+        x = self.embeddings(x)
         if cache is None:
-            cache = Mamba2Cache(len(self.layers))
-
-        for i, layer in enumerate(self.layers):
-            hidden_states = layer(hidden_states, cache[i])
-
-        hidden_states = self.norm_f(hidden_states)
-        return hidden_states
+            cache = [None] * len(self.layers)
+        for layer, c in zip(self.layers, cache):
+            x = layer(x, c)
+        return self.norm_f(x)
 
 
 class Model(nn.Module):
@@ -302,7 +223,10 @@ class Model(nn.Module):
         super().__init__()
         self.args = args
         self.model_type = args.model_type
+
         self.backbone = Mamba2(args)
+        # self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+
 
         if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
@@ -316,11 +240,8 @@ class Model(nn.Module):
         else:
             logits = self.lm_head(x)
 
-        print(logits)
-        print(logits.shape)
-
         return logits
-    
+
     def sanitize(self, weights):
         for k, v in weights.items():
             if "conv1d.weight" in k and v.ndim == 3:
@@ -328,7 +249,7 @@ class Model(nn.Module):
         return weights
 
     def make_cache(self):
-        return [Mamba2Cache(self.args.num_hidden_layers) for _ in range(len(self.layers))]
+        return [MambaCache() for _ in range(len(self.layers))]
 
     @property
     def layers(self):
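
For reference, a minimal smoke test of this patch, mirroring the CLI invocation in the comment
that the diff removes above (python -m mlx_lm.generate --model rokyang/mamba2-130m-hf
--prompt "hello how are you."). It is a sketch rather than part of the diff: it assumes the
mlx_lm package is installed and that the rokyang/mamba2-130m-hf checkpoint is reachable, and it
can be used to reproduce the gibberish output described in the commit subject.

    # smoke_test_mamba2.py -- load the Mamba2 checkpoint and generate a few tokens
    from mlx_lm import load, generate

    # Checkpoint name taken from the comment removed in the diff above.
    model, tokenizer = load("rokyang/mamba2-130m-hf")

    # Equivalent to the removed CLI invocation; verbose=True prints the decoded text
    # as it is produced, which makes the gibberish output easy to inspect.
    text = generate(model, tokenizer, prompt="hello how are you.", max_tokens=64, verbose=True)
    print(text)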