mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
quick clean up and fix
This commit is contained in:
parent
9c075a71f8
commit
6f88dd59d7
@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
|
||||
- Markus Enzweiler: Added the `cvae` examples.
|
||||
- Prince Canuma: Helped add support for `Starcoder2` models.
|
||||
- Shiyu Li: Added the `Segment Anything Model`.
|
||||
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba v1`, `Mamba v2` and support for `full-fine-tuning`.
|
||||
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba version 1`, `Mamba version 2` and support for `full-fine-tuning`.
|
||||
|
@ -1 +0,0 @@
|
||||
Subproject commit 05e8773fc4ac1cd067e8a18a5c45372ce5178405
|
@ -149,79 +149,25 @@ class Mamba2Mixer(nn.Module):
|
||||
|
||||
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
|
||||
|
||||
# def ssm_step(self, x, state=None):
|
||||
# A = -mx.exp(self.A_log)
|
||||
# D = self.D
|
||||
# deltaBC = self.x_proj(x)
|
||||
# delta, B, C = mx.split(
|
||||
# deltaBC,
|
||||
# indices_or_sections=[
|
||||
# self.time_step_rank,
|
||||
# self.time_step_rank + self.ssm_state_size,
|
||||
# ],
|
||||
# axis=-1,
|
||||
# )
|
||||
# delta = nn.softplus(self.dt_proj(delta))
|
||||
# new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
||||
# if state is not None:
|
||||
# new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
|
||||
# y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
|
||||
# y = y + D * x
|
||||
# return y, new_state
|
||||
|
||||
def ssm_step(self, x, dt, state):
|
||||
B, L, C = x.shape
|
||||
print(f"x shape: {x.shape}")
|
||||
projected_states = self.in_proj(x)
|
||||
print(f"deltaBC shape: {projected_states.shape}")
|
||||
|
||||
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size - self.num_heads) // 2
|
||||
|
||||
gate = projected_states[:, :, 2*d_mlp:2*d_mlp+self.intermediate_size]
|
||||
conv_state = projected_states[:, :, 2*d_mlp+self.intermediate_size:2*d_mlp+self.intermediate_size+self.conv_dim]
|
||||
time_step = projected_states[:, :, -self.num_heads:]
|
||||
|
||||
print(f"conv_state shape before reshape: {conv_state.shape}")
|
||||
print(f"self.conv_dim: {self.conv_dim}")
|
||||
|
||||
# Reshape and handle the case where L=1
|
||||
conv_state = conv_state.reshape(B, self.conv_dim, L)
|
||||
if L == 1:
|
||||
# If sequence length is 1, we need to pad to apply convolution
|
||||
conv_state = mx.pad(conv_state, ((0, 0), (0, 0), (0, self.conv_kernel_size - 1)))
|
||||
|
||||
conv_out = self.conv1d(conv_state)
|
||||
|
||||
# If we padded, we need to remove the padding
|
||||
if L == 1:
|
||||
conv_out = conv_out[:, :, :L]
|
||||
|
||||
# Reshape back to (B, L, C)
|
||||
conv_out = conv_out.transpose(0, 2, 1)
|
||||
|
||||
x_and_conv_out, B, C = mx.split(
|
||||
conv_out,
|
||||
[self.intermediate_size, self.n_groups * self.state_size],
|
||||
axis=-1
|
||||
def ssm_step(self, x, state=None):
|
||||
A = -mx.exp(self.A_log)
|
||||
D = self.D
|
||||
deltaBC = self.x_proj(x)
|
||||
delta, B, C = mx.split(
|
||||
deltaBC,
|
||||
indices_or_sections=[
|
||||
self.time_step_rank,
|
||||
self.time_step_rank + self.ssm_state_size,
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
dt = nn.softplus(time_step + self.dt_bias)
|
||||
dt = mx.clip(dt, self.args.time_step_min, self.args.time_step_max)
|
||||
|
||||
B = B.reshape(-1, self.num_heads, self.head_dim, self.state_size)
|
||||
C = C.reshape(-1, self.num_heads, self.head_dim, self.state_size)
|
||||
|
||||
dA = mx.exp(dt[:, :, None, None] * A[None, :, None, None])
|
||||
dB = dt[:, :, None, None] * B
|
||||
|
||||
new_state = state * dA + x_and_conv_out[:, :, None, None] * dB
|
||||
y = mx.sum(new_state * C, axis=-1)
|
||||
y = y + C[None, :, None] * x_and_conv_out
|
||||
|
||||
y = self.norm(y.reshape(-1, self.intermediate_size), gate)
|
||||
output = self.out_proj(y)
|
||||
|
||||
return output, new_state
|
||||
delta = nn.softplus(self.dt_proj(delta))
|
||||
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
||||
if state is not None:
|
||||
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
|
||||
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
|
||||
y = y + D * x
|
||||
return y, new_state
|
||||
|
||||
def __call__(self, x, cache):
|
||||
B, T, D = x.shape
|
||||
@ -232,7 +178,7 @@ class Mamba2Mixer(nn.Module):
|
||||
for t in range(T):
|
||||
xt = x[:, t, :]
|
||||
xz = self.in_proj(xt)
|
||||
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
|
||||
x_t, z_t = xz.split(indices_or_sections=2, axis=-1)
|
||||
|
||||
if x_t.shape[-1] != self.conv_dim:
|
||||
raise ValueError(f"Expected conv input dim {self.conv_dim}, got {x_t.shape[-1]}")
|
||||
|
Loading…
Reference in New Issue
Block a user