generation works! trying training now

This commit is contained in:
Goekdeniz-Guelmez
2024-10-22 18:25:59 +02:00
parent c1634ce81b
commit b9c57cd429
3 changed files with 537 additions and 327 deletions

View File

@@ -106,14 +106,16 @@ class Mamba2Block(nn.Module):
self.head_dim = args.hidden_size // args.num_heads
self.n_groups = args.n_groups
projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
# projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
projection_size = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(
args.hidden_size,
projection_size,
bias=args.use_bias
)
self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
# self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
self.conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
@@ -130,62 +132,125 @@ class Mamba2Block(nn.Module):
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
def ssm_step(self, x, state, dt):
def _ssd(self, x, A, B, C, chunk_size):
batch, seq_len, nheads, head_dim = x.shape
n_state = B.shape[-1]
h = mx.zeros((batch, nheads, head_dim, n_state))
ys = []
for i in range(0, seq_len, chunk_size):
chunk_size_i = min(chunk_size, seq_len - i)
xi = x[:, i:i + chunk_size_i]
Bi = B[:, i:i + chunk_size_i]
Ci = C[:, i:i + chunk_size_i]
for t in range(chunk_size_i):
h = h * mx.exp(A)[:, None, None]
h = h + mx.expand_dims(Bi[:, t], -2) * mx.expand_dims(xi[:, t], -1)
y = mx.sum(h * mx.expand_dims(Ci[:, t], -2), axis=-1)
ys.append(y)
y = mx.stack(ys, axis=1)
return y, h
def __call__(self, x: mx.array, cache) -> mx.array:
if cache is not None:
return self.step(x, cache)
A = -mx.exp(self.A_log)
D = self.D
dt = nn.softplus(dt + self.dt_bias)
zxbcdt = self.in_proj(u)
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
z, xBC, dt = mx.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
axis=-1,
)
batch_size = B.shape[0]
B = B.reshape(batch_size, self.n_groups, self.state_size)
C = C.reshape(batch_size, -1, self.state_size)
dt = mx.softplus(dt + self.dt_bias)
dt = dt.reshape(batch_size, self.num_heads, 1)
A = A.reshape(1, self.num_heads, 1)
# Use the custom DepthWiseConv1d with cache
xBC = self.conv1d(xBC, cache, cache_idx=0)
xBC = mx.sigmoid(xBC) * xBC # SiLU activation
if state is None:
new_state = dt * B
else:
new_state = dt * (B + state * mx.exp(dt * A))
x, B, C = mx.split(
xBC,
[self.args.d_inner, self.args.d_state, self.args.d_state],
axis=-1
)
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
y = y + D * x[:, :self.num_heads]
return y, new_state
x = self._reshape_heads(x, True)
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
y, ssm_state = self._ssd(
x * mx.expand_dims(dt, -1),
A * dt,
B,
C,
self.args.chunk_size
)
y = y + x * mx.expand_dims(self.D, -1)
y = self._reshape_heads(y, False)
y = self.norm(y, z)
y = self.out_proj(y)
def __call__(self, x, cache):
B, T, D = x.shape
if cache is None:
cache = [None, None]
if cache is not None:
cache[1] = ssm_state
outputs = []
for t in range(T):
xt = x[:, t, :]
zxbcdt = self.in_proj(xt)
z, xBC, dt = mx.split(
zxbcdt,
indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
axis=-1
)
return y
# Use the new DepthWiseConv1d with caching
conv_out, cache[0] = self.conv1d(mx.expand_dims(z, 1), cache[0])
z = conv_out.squeeze(1)
z = nn.silu(z)
y_t, cache[1] = self.ssm_step(z, cache[1], dt)
xBC = nn.silu(xBC)
# Element-wise multiplication
output_t = y_t[:, :, None] * xBC[:, None, :]
output_t = self.norm(output_t)
output_t = output_t.sum(axis=1)
output_t = self.out_proj(output_t)
outputs.append(output_t)
def step(self, x: mx.array, cache) -> mx.array:
"""Single inference step"""
assert x.shape[1] == 1, "Only one token can be decoded per inference step"
output = mx.stack(outputs, axis=1)
return output
zxbcdt = self.in_proj(mx.squeeze(x, 1))
z, xBC, dt = mx.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
axis=-1,
)
# Use the custom DepthWiseConv1d with cache
xBC = self.conv1d(xBC, cache, cache_idx=0)
xBC = mx.sigmoid(xBC) * xBC # SiLU activation
x, B, C = mx.split(
xBC,
[self.args.d_inner, self.args.d_state, self.args.d_state],
axis=-1
)
A = -mx.exp(self.A_log)
dt = mx.softplus(dt + self.dt_bias)
dA = mx.exp(dt * A)
x = mx.reshape(x, (-1, self.args.nheads, self.args.headdim))
ssm_state = cache[1]
dBx = mx.expand_dims(dt, -1) * mx.expand_dims(B, 1) * mx.expand_dims(x, -1)
ssm_state = ssm_state * mx.expand_dims(mx.expand_dims(dA, -1), -1) + dBx
y = mx.sum(ssm_state * mx.expand_dims(mx.expand_dims(C, 1), 1), axis=-1)
y = y + mx.expand_dims(self.D, -1) * x
y = mx.reshape(y, (-1, self.args.nheads * self.args.headdim))
y = self.norm(y, z)
y = self.out_proj(y)
# Update SSM state in cache
cache[1] = ssm_state
return mx.expand_dims(y, 1)
class ResidualBlock(nn.Module):