update trainer/lora.py and add DepthWiseConv1d because mlx 0.18.0 doesn't accept the groups parameter

This commit is contained in:
Goekdeniz-Guelmez 2024-10-02 19:19:32 +02:00
parent 409ddc427e
commit 264ba43707
2 changed files with 46 additions and 1 deletion
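For context: in mlx 0.18.0 the nn.Conv1d layer does not accept a groups argument, while the Mamba2 mixer needs a depthwise convolution (groups equal to the number of channels) over the sequence dimension. A minimal sketch of that operation using mlx's functional conv_general, with made-up sizes, just to illustrate what the new DepthWiseConv1d layer has to reproduce:

import mlx.core as mx

B, L, C, K = 2, 16, 8, 4                            # batch, length, channels, kernel size (illustrative)
x = mx.random.normal((B, L, C))                     # mlx convolutions are channels-last: (B, L, C)
w = mx.random.normal((C, K, 1))                     # one length-K filter per channel
x_padded = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])  # left-pad so the convolution is causal
y = mx.conv_general(x_padded, w, groups=C)          # depthwise: groups == number of channels
print(y.shape)                                      # (2, 16, 8)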


@@ -74,6 +74,46 @@ class MambaRMSNormGated(nn.Module):
        variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states
class DepthWiseConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
        super().__init__()
        assert in_channels == out_channels, "For depthwise conv, in_channels must equal out_channels"
        self.channels = in_channels
        self.kernel_size = kernel_size
        self.padding = padding
        # For depthwise conv, groups equals the number of channels
        self.groups = self.channels if groups is None else groups
        assert self.groups == self.channels, "For depthwise conv, groups must equal the number of channels"
        # Weight shape: (channels, kernel_size, 1); mlx convolutions are channels-last
        self.weight = mx.random.normal((self.channels, kernel_size, 1))
        self.bias = mx.zeros((self.channels,)) if bias else None

    def __call__(self, x, cache=None):
        B, L, C = x.shape
        K = self.kernel_size
        if cache is not None:
            # Prepend the cached input steps so the convolution stays causal across calls
            x = mx.concatenate([cache, x], axis=1)
        else:
            # Left-pad the sequence on the first call
            x = mx.pad(x, [(0, 0), (self.padding, 0), (0, 0)])
        # Depthwise convolution: one filter per channel, input stays (B, L, C)
        y = mx.conv_general(x, self.weight, groups=self.groups)
        if self.bias is not None:
            y = y + self.bias
        # Return the output and the last K input steps as the new convolution cache
        return y, x[:, -K:, :]
class Mamba2Mixer(nn.Module):
@@ -97,7 +137,7 @@ class Mamba2Mixer(nn.Module):
        )
        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
        self.conv1d = nn.Conv1d(
        self.conv1d = DepthWiseConv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=args.use_conv_bias,

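A rough usage sketch of the new layer (not part of the diff; it assumes the DepthWiseConv1d definition above and uses made-up sizes). The second return value is the convolution cache to pass back in on subsequent calls:

import mlx.core as mx

channels, kernel_size = 8, 4
conv = DepthWiseConv1d(channels, channels, kernel_size, padding=kernel_size - 1)

x = mx.random.normal((1, 16, channels))  # (B, L, C), channels-last
y, conv_cache = conv(x)                  # y: (1, 16, channels); conv_cache: last kernel_size input steps
print(y.shape, conv_cache.shape)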

@@ -143,6 +143,11 @@ def linear_to_lora_layers(
                "mixer.out_proj",
            ]
        )
    elif model.model_type == "mamba2":
        keys = set(
            [
            ]
        )
    else:
        raise ValueError(f"Lora does not support {model.model_type}")
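Note that the new mamba2 branch registers an empty key set, so the practical effect of this commit is only that "mamba2" no longer hits the ValueError above; no modules are selected for LoRA yet. A generic illustration of that pattern (hypothetical helper, not the library's actual code):

def select_lora_targets(named_modules, keys):
    # named_modules: iterable of (path, module) pairs, e.g. [("mixer.out_proj", m), ...]
    # Only paths listed in `keys` would be wrapped with LoRA adapters.
    return [(path, module) for path, module in named_modules if path in keys]

# With the empty "mamba2" key set nothing matches:
print(select_lora_targets([("mixer.out_proj", object())], set()))  # -> []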