From 38e5801edb8073b5eaf78da1b8b19838d8ee172a Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sun, 24 Nov 2024 16:26:45 +0100
Subject: [PATCH] loading codestral works but no inference yet

---
 llms/mlx_lm/models/mamba2.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 141ffeee..a8e8e891 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -30,6 +30,7 @@ class ModelArgs(BaseModelArgs):
     rms_norm: bool
     chunk_size: int
     tie_word_embeddings: bool
+    intermediate_size: int = None
     use_cache: bool = True
     time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
     time_step_rank: Union[int, str] = "auto"
@@ -93,20 +94,21 @@ class DepthWiseConv1d(nn.Module):
         super().__init__()
         self.channels = channels
         self.kernel_size = kernel_size
+        self.groups = channels
         self.padding = padding
         self.weight = mx.random.normal((self.channels, kernel_size, 1))
         self.bias = mx.zeros((channels,)) if bias else None
 
     def __call__(self, x, cache=None):
         B, L, C = x.shape
-        groups, K, _ = self.weight.shape
+        _, K, _ = self.weight.shape
 
         if cache is not None:
             x = mx.concatenate([cache, x], axis=1)
         else:
             x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
 
-        y = mx.conv_general(x, self.weight, groups=groups)
+        y = mx.conv_general(x, self.weight, groups=self.groups)
 
         if self.bias is not None:
             y = y + self.bias
@@ -124,16 +126,20 @@ class Mamba2Block(nn.Module):
         self.d_state = args.state_size
         self.d_conv = args.conv_kernel
         self.expand = args.expand
-        self.d_inner = int(self.expand * self.d_model)
+        if args.intermediate_size is None:
+            self.d_inner = int(self.expand * self.d_model)
+        else:
+            self.d_inner = args.intermediate_size
+        self.n_groups = args.n_groups
         self.n_heads = args.num_heads
         self.d_head = self.d_inner // self.n_heads
 
         # Input projection
-        d_in_proj = self.d_inner * 2 + self.d_state * 2 + self.n_heads
+        d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
         self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
 
         # Convolution
-        conv_dim = self.d_inner + 2 * self.d_state
+        conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
         self.conv1d = DepthWiseConv1d(
             channels=conv_dim,
             kernel_size=self.d_conv,
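
Note (reviewer sketch, not part of the patch): the DepthWiseConv1d fix relies
on mx.conv_general performing a depthwise convolution when groups equals the
channel count, i.e. one (kernel_size, 1) filter per channel applied to input
in (batch, length, channels) layout. A minimal shape check with made-up sizes:

    import mlx.core as mx

    channels, K, L = 4, 3, 10
    x = mx.random.normal((1, L, channels))   # (batch, length, channels)
    w = mx.random.normal((channels, K, 1))   # one K-tap filter per channel

    # causal left-padding, as DepthWiseConv1d.__call__ does when there is no cache
    x_padded = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
    y = mx.conv_general(x_padded, w, groups=channels)
    print(y.shape)  # (1, 10, 4): output length matches input, thanks to the padding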
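
A second sketch of why d_in_proj and conv_dim both grew: in Mamba-2 the input
projection packs the gate z, the convolution input xBC (x plus B and C for each
of the n_groups state groups), and a per-head time step dt into one matrix, and
conv_dim is exactly the width of the xBC slice. The split order below follows
the reference Mamba-2 implementations; the sizes are made up, not Codestral's:

    import mlx.core as mx

    d_inner, n_groups, d_state, n_heads = 128, 2, 16, 8

    d_in_proj = 2 * d_inner + 2 * n_groups * d_state + n_heads  # 328
    conv_dim = d_inner + 2 * n_groups * d_state                 # 192

    proj = mx.zeros((1, 5, d_in_proj))  # stand-in for in_proj(hidden_states)
    z, xBC, dt = mx.split(proj, [d_inner, d_inner + conv_dim], axis=-1)
    print(z.shape[-1], xBC.shape[-1], dt.shape[-1])  # 128 192 8
    # xBC (width == conv_dim) is what feeds conv1d, hence the matching formulas.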