Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-03 21:36:37 +08:00
Update trainer/lora.py and add DepthWiseConv1d because mlx 0.18.0 doesn't accept the groups parameter
This commit is contained in:
parent 409ddc427e
commit 264ba43707
@@ -76,6 +76,46 @@ class MambaRMSNormGated(nn.Module):
         return self.weight * hidden_states
 
 
+class DepthWiseConv1d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
+        super().__init__()
+        assert in_channels == out_channels, "For depthwise conv, in_channels must equal out_channels"
+        self.channels = in_channels
+        self.kernel_size = kernel_size
+        self.padding = padding
+
+        # For depthwise conv, we use groups equal to the number of channels
+        self.groups = self.channels if groups is None else groups
+        assert self.groups == self.channels, "For depthwise conv, groups must equal the number of channels"
+
+        # Weight shape: (channels, 1, kernel_size) for depthwise conv
+        self.weight = mx.random.normal((self.channels, 1, kernel_size))
+        self.bias = mx.zeros((self.channels,)) if bias else None
+
+    def __call__(self, x, cache=None):
+        B, L, C = x.shape
+        K = self.kernel_size
+
+        if cache is not None:
+            x = mx.concatenate([cache, x], axis=1)
+        else:
+            x = mx.pad(x, [(0, 0), (self.padding, 0), (0, 0)])
+
+        # Reshape for depthwise convolution
+        x = x.transpose(0, 2, 1)  # (B, C, L)
+
+        # Perform depthwise convolution
+        y = mx.conv(x, self.weight, groups=self.groups)
+
+        # Reshape back
+        y = y.transpose(0, 2, 1)  # (B, L, C)
+
+        if self.bias is not None:
+            y = y + self.bias
+
+        return y, x.transpose(0, 2, 1)[:, -K:, :]
+
+
 class Mamba2Mixer(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
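The added layer follows a PyTorch-style channels-first layout: the weight is stored as (channels, 1, kernel_size), the input is transposed to (B, C, L) for the op-level call, and __call__ returns both the convolution output and the trailing K timesteps of its padded input, which the caller can feed back as cache on the next decoding step. For comparison only (not part of the commit), MLX's op-level convolutions work on channels-last input and accept a groups argument through mx.conv_general; below is a minimal sketch in that convention, with all names and sizes chosen purely for illustration:

import mlx.core as mx
import mlx.nn as nn


class DepthWiseConv1dRef(nn.Module):
    # Illustrative sketch only: depthwise conv over channels-last (B, L, C) input,
    # using the op-level groups argument of mx.conv_general.
    def __init__(self, channels, kernel_size, bias=True, padding=0):
        super().__init__()
        self.padding = padding
        # One filter per channel: weight shape (C_out, K, C_in / groups) = (C, K, 1)
        self.weight = mx.random.normal((channels, kernel_size, 1))
        self.bias = mx.zeros((channels,)) if bias else None

    def __call__(self, x):
        # Left-pad along the length axis for a causal receptive field
        x = mx.pad(x, [(0, 0), (self.padding, 0), (0, 0)])
        y = mx.conv_general(x, self.weight, groups=self.weight.shape[0])
        if self.bias is not None:
            y = y + self.bias
        return y


# Quick shape check with made-up sizes
x = mx.random.normal((2, 16, 8))                      # (batch, length, channels)
conv = DepthWiseConv1dRef(channels=8, kernel_size=4, padding=3)
print(conv(x).shape)                                  # (2, 16, 8)

Either way, a depthwise layer keeps the parameter count at channels * kernel_size (plus one bias per channel), versus channels * channels * kernel_size for a full Conv1d of the same width.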
@@ -97,7 +137,7 @@ class Mamba2Mixer(nn.Module):
         )
 
         self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
-        self.conv1d = nn.Conv1d(
+        self.conv1d = DepthWiseConv1d(
             in_channels=self.conv_dim,
             out_channels=self.conv_dim,
             bias=args.use_conv_bias,
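For a sense of the sizes involved: with hypothetical Mamba2 settings of intermediate_size = 4096, n_groups = 8 and state_size = 128 (illustrative values, not taken from the commit), conv_dim = 4096 + 2 * 8 * 128 = 6144, and that single value is passed as both in_channels and out_channels, so the new class's in_channels == out_channels assertion holds by construction.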
@@ -143,6 +143,11 @@ def linear_to_lora_layers(
                 "mixer.out_proj",
             ]
         )
+    elif model.model_type == "mamba2":
+        keys = set(
+            [
+            ]
+        )
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
 
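The new mamba2 branch lands with an empty keys set, so at this point no submodules of a Mamba2 block are selected for LoRA; the branch mainly keeps linear_to_lora_layers from raising for model_type == "mamba2". A hedged sketch of what a populated set might look like, reusing the naming pattern of the neighbouring branch (the exact Mamba2Mixer projection names are an assumption, not taken from the commit):

keys = set(
    [
        "mixer.in_proj",   # hypothetical input projection of Mamba2Mixer
        "mixer.out_proj",  # hypothetical output projection of Mamba2Mixer
    ]
)

Submodule paths matching these keys are what linear_to_lora_layers would then wrap with LoRA adapters.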