Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-29 04:31:13 +08:00)
inference fixed
commit 57b1717cf5
parent 117ffd3909
@@ -89,63 +89,29 @@ def ssd(x, A, B, C, chunk_size):
 
 
 class DepthWiseConv1d(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
+    def __init__(self, channels, kernel_size, bias=True, padding=0):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
+        self.channels = channels
         self.kernel_size = kernel_size
         self.padding = padding
-        self.groups = groups if groups is not None else in_channels
-
-        assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
-        assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
-
-        self.weight = mx.random.normal((in_channels, 1, kernel_size))
-        self.bias = mx.zeros((out_channels,)) if bias else None
-
-    def __call__(self, x: mx.array, cache=None) -> mx.array:
+        self.weight = mx.random.normal((self.channels, kernel_size, 1))
+        self.bias = mx.zeros((channels,)) if bias else None
+
+    def __call__(self, x, cache=None):
         B, L, C = x.shape
-        K = self.kernel_size
-
-        assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
+        groups, K, _ = self.weight.shape
 
         if cache is not None:
-            # Access conv_state directly from cache[0]
-            if cache[0] is None:
-                cache[0] = mx.zeros((B, K-1, C))
-            x = mx.concatenate([cache[0], x], axis=1)
-
-        outputs = []
-        for c in range(C):
-            x_c = x[:, :, c]
-            x_c = mx.expand_dims(x_c, axis=1)
-
-            w_c = self.weight[c]
-            if w_c.ndim == 2:
-                w_c = mx.expand_dims(w_c, axis=0)
-            elif w_c.ndim == 1:
-                w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
-
-            y_c = mx.conv_general(
-                x_c,
-                w_c,
-                stride=1,
-                padding=0
-            )
-            if self.bias is not None:
-                y_c = y_c + self.bias[c]
-
-            y_c = mx.squeeze(y_c, axis=1)
-            outputs.append(y_c)
-
-        y = mx.stack(outputs, axis=-1)
-
-        # Update cache directly using cache[0]
-        if cache is not None:
-            cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
-
-        return y
+            x = mx.concatenate([cache, x], axis=1)
+        else:
+            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
+
+        y = mx.conv_general(x, self.weight, groups=groups)
+        if self.bias is not None:
+            y = y + self.bias
+        return y, x[:, -K + 1:, :]
 
 
 class Mamba2Block(nn.Module):
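The rewritten convolution replaces the per-channel Python loop with a single grouped mx.conv_general call and hands the conv state back to the caller instead of mutating the cache itself. A minimal usage sketch, assuming the DepthWiseConv1d from the hunk above is in scope and that inputs are laid out as (batch, seq_len, channels):

import mlx.core as mx

B, L, C, K = 1, 10, 8, 4                    # batch, length, channels, kernel size
conv = DepthWiseConv1d(channels=C, kernel_size=K)

# Prefill: no cache, so the input is left-padded with K - 1 zeros and the
# output keeps the original sequence length.
x = mx.random.normal((B, L, C))
y, conv_state = conv(x)
assert y.shape == (B, L, C)
assert conv_state.shape == (B, K - 1, C)

# Decode step: feed the returned state back in so the single new token is
# convolved against the last K - 1 cached positions.
x_step = mx.random.normal((B, 1, C))
y_step, conv_state = conv(x_step, cache=conv_state)
assert y_step.shape == (B, 1, C)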
@@ -169,11 +135,9 @@ class Mamba2Block(nn.Module):
         # Convolution
         conv_dim = self.d_inner + 2 * self.d_state
         self.conv1d = DepthWiseConv1d(
-            in_channels=conv_dim,
-            out_channels=conv_dim,
+            channels=conv_dim,
             kernel_size=self.d_conv,
-            bias=args.use_conv_bias,
-            groups=conv_dim
+            bias=args.use_conv_bias
         )
 
         # SSM parameters
@@ -206,7 +170,9 @@ class Mamba2Block(nn.Module):
         dt = mx.maximum(dt, self.args.time_step_floor)
 
         # Convolution and activation
-        x_conv = self.conv1d(x_conv, cache=[cache[0] if cache else None])
+        x_conv, conv_state = self.conv1d(x_conv, cache[0] if cache else None)
+        if cache is not None:
+            cache[0] = conv_state
         x_conv = silu(x_conv)
 
         # Split conv output
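With the convolution now returning its updated state, the cache write happens in the block rather than inside DepthWiseConv1d. A sketch of that caller-owned flow over two decode steps, using a plain two-slot list in place of MambaCache (an assumption) and an arbitrary conv_dim:

import mlx.core as mx

conv_dim, K = 16, 4
conv1d = DepthWiseConv1d(channels=conv_dim, kernel_size=K)
cache = [None, None]                        # stands in for MambaCache: [conv_state, ssm_state]

for step in range(2):
    x_conv = mx.random.normal((1, 1, conv_dim))
    x_conv, conv_state = conv1d(x_conv, cache[0] if cache else None)
    if cache is not None:
        cache[0] = conv_state               # the block, not the conv, owns the cache update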
@@ -329,6 +295,12 @@ class Model(nn.Module):
 
         return logits
 
+    def sanitize(self, weights):
+        for k, v in weights.items():
+            if "conv1d.weight" in k and v.shape[-1] != 1:
+                weights[k] = v.moveaxis(2, 1)
+        return weights
+
     def make_cache(self):
         return [MambaCache() for _ in range(len(self.layers))]
 
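The new sanitize hook handles checkpoints that store the depthwise conv weight as (channels, 1, kernel_size), moving it to the (channels, kernel_size, 1) layout the rewritten DepthWiseConv1d expects. A small sketch of the effect, with an illustrative key name and shape:

import mlx.core as mx

# Illustrative key and shape; real checkpoints use the model's own parameter paths.
weights = {"layers.0.mixer.conv1d.weight": mx.zeros((768, 1, 4))}

for k, v in weights.items():
    if "conv1d.weight" in k and v.shape[-1] != 1:
        weights[k] = v.moveaxis(2, 1)       # (768, 1, 4) -> (768, 4, 1)

assert weights["layers.0.mixer.conv1d.weight"].shape == (768, 4, 1)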