mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-30 02:31:13 +08:00
update trainer/lora.py and adding DepthWiseConv1d because mlx 0.18.0 doesnt axepts groups parameter
This commit is contained in:
parent
409ddc427e
commit
264ba43707
@ -74,6 +74,46 @@ class MambaRMSNormGated(nn.Module):
|
||||
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
|
||||
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class DepthWiseConv1d(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
|
||||
super().__init__()
|
||||
assert in_channels == out_channels, "For depthwise conv, in_channels must equal out_channels"
|
||||
self.channels = in_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.padding = padding
|
||||
|
||||
# For depthwise conv, we use groups equal to the number of channels
|
||||
self.groups = self.channels if groups is None else groups
|
||||
assert self.groups == self.channels, "For depthwise conv, groups must equal the number of channels"
|
||||
|
||||
# Weight shape: (channels, 1, kernel_size) for depthwise conv
|
||||
self.weight = mx.random.normal((self.channels, 1, kernel_size))
|
||||
self.bias = mx.zeros((self.channels,)) if bias else None
|
||||
|
||||
def __call__(self, x, cache=None):
|
||||
B, L, C = x.shape
|
||||
K = self.kernel_size
|
||||
|
||||
if cache is not None:
|
||||
x = mx.concatenate([cache, x], axis=1)
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (self.padding, 0), (0, 0)])
|
||||
|
||||
# Reshape for depthwise convolution
|
||||
x = x.transpose(0, 2, 1) # (B, C, L)
|
||||
|
||||
# Perform depthwise convolution
|
||||
y = mx.conv(x, self.weight, groups=self.groups)
|
||||
|
||||
# Reshape back
|
||||
y = y.transpose(0, 2, 1) # (B, L, C)
|
||||
|
||||
if self.bias is not None:
|
||||
y = y + self.bias
|
||||
|
||||
return y, x.transpose(0, 2, 1)[:, -K:, :]
|
||||
|
||||
|
||||
class Mamba2Mixer(nn.Module):
|
||||
@ -97,7 +137,7 @@ class Mamba2Mixer(nn.Module):
|
||||
)
|
||||
|
||||
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
|
||||
self.conv1d = nn.Conv1d(
|
||||
self.conv1d = DepthWiseConv1d(
|
||||
in_channels=self.conv_dim,
|
||||
out_channels=self.conv_dim,
|
||||
bias=args.use_conv_bias,
|
||||
|
@ -143,6 +143,11 @@ def linear_to_lora_layers(
|
||||
"mixer.out_proj",
|
||||
]
|
||||
)
|
||||
elif model.model_type == "mamba2":
|
||||
keys = set(
|
||||
[
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Lora does not support {model.model_type}")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user