From 264ba43707000760862b6e39e6d5077c4a470754 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 2 Oct 2024 19:19:32 +0200
Subject: [PATCH] update tuner/utils.py and add DepthWiseConv1d because
 nn.Conv1d in mlx 0.18.0 doesn't accept a groups parameter

---
 llms/mlx_lm/models/mamba2.py | 42 +++++++++++++++++++++++++++++++++++-
 llms/mlx_lm/tuner/utils.py   |  5 +++++
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 6e6e268c..f74ae826 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -74,6 +74,46 @@ class MambaRMSNormGated(nn.Module):
         variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
         hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states
+
+
+class DepthWiseConv1d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
+        super().__init__()
+        assert in_channels == out_channels, "For depthwise conv, in_channels must equal out_channels"
+        self.channels = in_channels
+        self.kernel_size = kernel_size
+        self.padding = padding
+
+        # For depthwise conv, we use groups equal to the number of channels
+        self.groups = self.channels if groups is None else groups
+        assert self.groups == self.channels, "For depthwise conv, groups must equal the number of channels"
+
+        # Weight shape: (channels, 1, kernel_size) for depthwise conv
+        self.weight = mx.random.normal((self.channels, 1, kernel_size))
+        self.bias = mx.zeros((self.channels,)) if bias else None
+
+    def __call__(self, x, cache=None):
+        B, L, C = x.shape
+        K = self.kernel_size
+
+        if cache is not None:
+            x = mx.concatenate([cache, x], axis=1)
+        else:
+            x = mx.pad(x, [(0, 0), (self.padding, 0), (0, 0)])
+
+        # Reshape for depthwise convolution
+        x = x.transpose(0, 2, 1)  # (B, C, L)
+
+        # Perform depthwise convolution
+        y = mx.conv(x, self.weight, groups=self.groups)
+
+        # Reshape back
+        y = y.transpose(0, 2, 1)  # (B, L, C)
+
+        if self.bias is not None:
+            y = y + self.bias
+
+        return y, x.transpose(0, 2, 1)[:, -K:, :]
 
 
 class Mamba2Mixer(nn.Module):
@@ -97,7 +137,7 @@ class Mamba2Mixer(nn.Module):
         )
 
         self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
-        self.conv1d = nn.Conv1d(
+        self.conv1d = DepthWiseConv1d(
             in_channels=self.conv_dim,
             out_channels=self.conv_dim,
             bias=args.use_conv_bias,
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 7c78ee91..1d223ce5 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -143,6 +143,11 @@ def linear_to_lora_layers(
                 "mixer.out_proj",
             ]
         )
+    elif model.model_type == "mamba2":
+        keys = set(
+            [
+            ]
+        )
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
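
Usage sketch: the snippet below illustrates how the new DepthWiseConv1d is
expected to be driven for prefill and step-by-step decoding. It assumes mlx is
installed and that the class is importable from mlx_lm.models.mamba2 as in the
diff above; the sizes and the names conv, prefill_out, and conv_state are
illustrative only and are not part of the change.

    import mlx.core as mx
    from mlx_lm.models.mamba2 import DepthWiseConv1d  # module path taken from the diff

    channels, kernel_size = 8, 4
    conv = DepthWiseConv1d(
        in_channels=channels,        # must equal out_channels (depthwise)
        out_channels=channels,
        kernel_size=kernel_size,
        padding=kernel_size - 1,     # causal left-padding on the first call
    )

    # Prefill: run the whole prompt at once. The second return value is the
    # trailing window of inputs (last kernel_size steps) to reuse as the cache.
    x = mx.random.normal((1, 16, channels))      # (batch, length, channels)
    prefill_out, conv_state = conv(x)

    # Decode: feed one token at a time, passing the cached window back in so
    # the convolution sees the preceding context without re-running the prompt.
    step = mx.random.normal((1, 1, channels))
    step_out, conv_state = conv(step, cache=conv_state)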