Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-27 19:31:20 +08:00)
inference fixed
This commit is contained in:
parent 117ffd3909
commit 57b1717cf5
@@ -89,63 +89,29 @@ def ssd(x, A, B, C, chunk_size):
 class DepthWiseConv1d(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
+    def __init__(self, channels, kernel_size, bias=True, padding=0):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
+        self.channels = channels
         self.kernel_size = kernel_size
         self.padding = padding
-        self.groups = groups if groups is not None else in_channels
-
-        assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
-        assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
-
-        self.weight = mx.random.normal((in_channels, 1, kernel_size))
-        self.bias = mx.zeros((out_channels,)) if bias else None
+        self.weight = mx.random.normal((self.channels, kernel_size, 1))
+        self.bias = mx.zeros((channels,)) if bias else None
 
-    def __call__(self, x: mx.array, cache=None) -> mx.array:
+    def __call__(self, x, cache=None):
         B, L, C = x.shape
-        K = self.kernel_size
+        groups, K, _ = self.weight.shape
 
-        assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
-
         if cache is not None:
-            # Access conv_state directly from cache[0]
-            if cache[0] is None:
-                cache[0] = mx.zeros((B, K-1, C))
-
-            x = mx.concatenate([cache[0], x], axis=1)
-
-        outputs = []
-        for c in range(C):
-            x_c = x[:, :, c]
-            x_c = mx.expand_dims(x_c, axis=1)
-
-            w_c = self.weight[c]
-            if w_c.ndim == 2:
-                w_c = mx.expand_dims(w_c, axis=0)
-            elif w_c.ndim == 1:
-                w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
-
-            y_c = mx.conv_general(
-                x_c,
-                w_c,
-                stride=1,
-                padding=0
-            )
-            if self.bias is not None:
-                y_c = y_c + self.bias[c]
-
-            y_c = mx.squeeze(y_c, axis=1)
-            outputs.append(y_c)
+            x = mx.concatenate([cache, x], axis=1)
+        else:
+            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
 
-        y = mx.stack(outputs, axis=-1)
+        y = mx.conv_general(x, self.weight, groups=groups)
+
         if self.bias is not None:
             y = y + self.bias
 
-        # Update cache directly using cache[0]
-        if cache is not None:
-            cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
-
-        return y
+        return y, x[:, -K + 1:, :]
 
 
 class Mamba2Block(nn.Module):
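The hunk above replaces the per-channel Python loop with a single grouped mx.conv_general call over a (channels, kernel_size, 1) weight, and the layer now returns the trailing K-1 timesteps alongside its output instead of mutating the cache itself. A minimal usage sketch, not part of the commit; the sizes and variable names below are illustrative:

# Hedged sketch: exercising the rewritten DepthWiseConv1d defined above.
import mlx.core as mx

channels, kernel_size, seq_len = 8, 4, 16           # illustrative sizes
layer = DepthWiseConv1d(channels, kernel_size)

x = mx.random.normal((1, seq_len, channels))         # (batch, length, channels)

# Full-sequence pass: no cache, so the input is left-padded with K-1 zeros.
y_full, conv_state = layer(x)
print(y_full.shape)      # (1, 16, 8) -- same length as the input
print(conv_state.shape)  # (1, 3, 8)  -- last K-1 timesteps, reused as the cache

# Incremental pass: feed one new token with the returned state as the cache.
x_next = mx.random.normal((1, 1, channels))
y_step, conv_state = layer(x_next, cache=conv_state)
print(y_step.shape)      # (1, 1, 8)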
@@ -169,11 +135,9 @@ class Mamba2Block(nn.Module):
         # Convolution
         conv_dim = self.d_inner + 2 * self.d_state
         self.conv1d = DepthWiseConv1d(
-            in_channels=conv_dim,
-            out_channels=conv_dim,
+            channels=conv_dim,
             kernel_size=self.d_conv,
-            bias=args.use_conv_bias,
-            groups=conv_dim
+            bias=args.use_conv_bias
         )
 
         # SSM parameters
@@ -206,7 +170,9 @@ class Mamba2Block(nn.Module):
         dt = mx.maximum(dt, self.args.time_step_floor)
 
         # Convolution and activation
-        x_conv = self.conv1d(x_conv, cache=[cache[0] if cache else None])
+        x_conv, conv_state = self.conv1d(x_conv, cache[0] if cache else None)
+        if cache is not None:
+            cache[0] = conv_state
         x_conv = silu(x_conv)
 
         # Split conv output
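Since the layer no longer touches the cache, the block now writes the returned convolution state back into slot 0 itself. A minimal sketch of that contract, assuming the cache behaves like a two-slot container (as MambaCache does); the helper name step_conv and the plain-list cache are illustrative stand-ins:

# Hedged sketch: how the conv state round-trips through the cache at the call site.
import mlx.core as mx

def step_conv(conv1d, x_conv, cache):
    # New contract: the layer returns (output, conv_state) instead of mutating the cache.
    x_out, conv_state = conv1d(x_conv, cache[0] if cache else None)
    if cache is not None:
        cache[0] = conv_state        # caller owns the cache update
    return x_out

layer = DepthWiseConv1d(channels=8, kernel_size=4)
cache = [None, None]                 # stand-in for MambaCache: [conv_state, ssm_state]

prompt = mx.random.normal((1, 10, 8))
_ = step_conv(layer, prompt, cache)  # prefill; cache[0] now holds the last K-1 timesteps

token = mx.random.normal((1, 1, 8))
_ = step_conv(layer, token, cache)   # single-token step reuses and refreshes cache[0]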
@@ -328,6 +294,12 @@ class Model(nn.Module):
         logits = self.lm_head(hidden)
 
         return logits
 
+    def sanitize(self, weights):
+        for k, v in weights.items():
+            if "conv1d.weight" in k and v.shape[-1] != 1:
+                weights[k] = v.moveaxis(2, 1)
+        return weights
+
     def make_cache(self):
         return [MambaCache() for _ in range(len(self.layers))]
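The new sanitize hook handles checkpoints whose conv1d.weight is stored with the kernel dimension last, such as the previous (channels, 1, kernel_size) layout, while the rewritten layer expects (channels, kernel_size, 1). A hedged sketch of the transformation on an illustrative weight shape:

# Hedged sketch: what sanitize does to a stored conv1d.weight.
import mlx.core as mx

w_checkpoint = mx.zeros((768, 1, 4))   # illustrative weight as stored on disk
print(w_checkpoint.shape[-1] != 1)     # True, so sanitize rewrites it
w_loaded = w_checkpoint.moveaxis(2, 1)
print(w_loaded.shape)                  # (768, 4, 1), the layout the new layer expects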