quick clean up and fix

This commit is contained in:
Goekdeniz-Guelmez 2024-10-11 21:08:13 +02:00
parent 9c075a71f8
commit 6f88dd59d7
3 changed files with 20 additions and 75 deletions

View File

@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba v1`, `Mamba v2` and support for `full-fine-tuning`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba version 1`, `Mamba version 2` and support for `full-fine-tuning`.

@ -1 +0,0 @@
Subproject commit 05e8773fc4ac1cd067e8a18a5c45372ce5178405

View File

@ -149,79 +149,25 @@ class Mamba2Mixer(nn.Module):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
# def ssm_step(self, x, state=None):
# A = -mx.exp(self.A_log)
# D = self.D
# deltaBC = self.x_proj(x)
# delta, B, C = mx.split(
# deltaBC,
# indices_or_sections=[
# self.time_step_rank,
# self.time_step_rank + self.ssm_state_size,
# ],
# axis=-1,
# )
# delta = nn.softplus(self.dt_proj(delta))
# new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
# if state is not None:
# new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
# y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
# y = y + D * x
# return y, new_state
def ssm_step(self, x, dt, state):
B, L, C = x.shape
print(f"x shape: {x.shape}")
projected_states = self.in_proj(x)
print(f"deltaBC shape: {projected_states.shape}")
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size - self.num_heads) // 2
gate = projected_states[:, :, 2*d_mlp:2*d_mlp+self.intermediate_size]
conv_state = projected_states[:, :, 2*d_mlp+self.intermediate_size:2*d_mlp+self.intermediate_size+self.conv_dim]
time_step = projected_states[:, :, -self.num_heads:]
print(f"conv_state shape before reshape: {conv_state.shape}")
print(f"self.conv_dim: {self.conv_dim}")
# Reshape and handle the case where L=1
conv_state = conv_state.reshape(B, self.conv_dim, L)
if L == 1:
# If sequence length is 1, we need to pad to apply convolution
conv_state = mx.pad(conv_state, ((0, 0), (0, 0), (0, self.conv_kernel_size - 1)))
conv_out = self.conv1d(conv_state)
# If we padded, we need to remove the padding
if L == 1:
conv_out = conv_out[:, :, :L]
# Reshape back to (B, L, C)
conv_out = conv_out.transpose(0, 2, 1)
x_and_conv_out, B, C = mx.split(
conv_out,
[self.intermediate_size, self.n_groups * self.state_size],
axis=-1
def ssm_step(self, x, state=None):
A = -mx.exp(self.A_log)
D = self.D
deltaBC = self.x_proj(x)
delta, B, C = mx.split(
deltaBC,
indices_or_sections=[
self.time_step_rank,
self.time_step_rank + self.ssm_state_size,
],
axis=-1,
)
dt = nn.softplus(time_step + self.dt_bias)
dt = mx.clip(dt, self.args.time_step_min, self.args.time_step_max)
B = B.reshape(-1, self.num_heads, self.head_dim, self.state_size)
C = C.reshape(-1, self.num_heads, self.head_dim, self.state_size)
dA = mx.exp(dt[:, :, None, None] * A[None, :, None, None])
dB = dt[:, :, None, None] * B
new_state = state * dA + x_and_conv_out[:, :, None, None] * dB
y = mx.sum(new_state * C, axis=-1)
y = y + C[None, :, None] * x_and_conv_out
y = self.norm(y.reshape(-1, self.intermediate_size), gate)
output = self.out_proj(y)
return output, new_state
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None:
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
y = y + D * x
return y, new_state
def __call__(self, x, cache):
B, T, D = x.shape
@ -232,7 +178,7 @@ class Mamba2Mixer(nn.Module):
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
x_t, z_t = xz.split(indices_or_sections=2, axis=-1)
if x_t.shape[-1] != self.conv_dim:
raise ValueError(f"Expected conv input dim {self.conv_dim}, got {x_t.shape[-1]}")