inference fixed

This commit is contained in:
Goekdeniz-Guelmez 2024-11-21 22:25:58 +01:00
parent 117ffd3909
commit 57b1717cf5

View File

@ -89,63 +89,29 @@ def ssd(x, A, B, C, chunk_size):
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.channels = channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
self.weight = mx.random.normal((in_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
self.weight = mx.random.normal((self.channels, kernel_size, 1))
self.bias = mx.zeros((channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
def __call__(self, x, cache=None):
B, L, C = x.shape
K = self.kernel_size
groups, K, _ = self.weight.shape
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
if cache is not None:
# Access conv_state directly from cache[0]
if cache[0] is None:
cache[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
x = mx.concatenate([cache[0], x], axis=1)
outputs = []
for c in range(C):
x_c = x[:, :, c]
x_c = mx.expand_dims(x_c, axis=1)
w_c = self.weight[c]
if w_c.ndim == 2:
w_c = mx.expand_dims(w_c, axis=0)
elif w_c.ndim == 1:
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=0
)
if self.bias is not None:
y_c = y_c + self.bias[c]
y_c = mx.squeeze(y_c, axis=1)
outputs.append(y_c)
y = mx.conv_general(x, self.weight, groups=groups)
y = mx.stack(outputs, axis=-1)
if self.bias is not None:
y = y + self.bias
# Update cache directly using cache[0]
if cache is not None:
cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
return y, x[:, -K + 1:, :]
class Mamba2Block(nn.Module):
@ -169,11 +135,9 @@ class Mamba2Block(nn.Module):
# Convolution
conv_dim = self.d_inner + 2 * self.d_state
self.conv1d = DepthWiseConv1d(
in_channels=conv_dim,
out_channels=conv_dim,
channels=conv_dim,
kernel_size=self.d_conv,
bias=args.use_conv_bias,
groups=conv_dim
bias=args.use_conv_bias
)
# SSM parameters
@ -206,7 +170,9 @@ class Mamba2Block(nn.Module):
dt = mx.maximum(dt, self.args.time_step_floor)
# Convolution and activation
x_conv = self.conv1d(x_conv, cache=[cache[0] if cache else None])
x_conv, conv_state = self.conv1d(x_conv, cache[0] if cache else None)
if cache is not None:
cache[0] = conv_state
x_conv = silu(x_conv)
# Split conv output
@ -328,6 +294,12 @@ class Model(nn.Module):
logits = self.lm_head(hidden)
return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]