Patch summary (free text before the first "diff --git" header):
  * ACKNOWLEDGMENTS.md: spell out "Mamba version 1" / "Mamba version 2".
  * llms/mamba2-130m-hf: drop the submodule link.
  * llms/mlx_lm/models/mamba2.py: replace the debug-print ssm_step (and its
    commented-out predecessor) with the un-commented predecessor, and split
    the in_proj output on the last axis (axis=-1) instead of axis=1.
NOTE(review): line structure was rebuilt from a whitespace-collapsed copy.
Hunk line counts match the @@ headers, but Python indentation and the
placement of whitespace-only context lines are inferred; confirm against the
target tree (e.g. a dry-run apply) before relying on this patch.

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 727c88dc..1f1fe7d6 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba v1`, `Mamba v2` and support for `full-fine-tuning`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba version 1`, `Mamba version 2` and support for `full-fine-tuning`.
diff --git a/llms/mamba2-130m-hf b/llms/mamba2-130m-hf
deleted file mode 160000
index 05e8773f..00000000
--- a/llms/mamba2-130m-hf
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 05e8773fc4ac1cd067e8a18a5c45372ce5178405
diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 8b0ee59e..790f3756 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -149,79 +149,25 @@ class Mamba2Mixer(nn.Module):
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
 
-    # def ssm_step(self, x, state=None):
-    #     A = -mx.exp(self.A_log)
-    #     D = self.D
-    #     deltaBC = self.x_proj(x)
-    #     delta, B, C = mx.split(
-    #         deltaBC,
-    #         indices_or_sections=[
-    #             self.time_step_rank,
-    #             self.time_step_rank + self.ssm_state_size,
-    #         ],
-    #         axis=-1,
-    #     )
-    #     delta = nn.softplus(self.dt_proj(delta))
-    #     new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
-    #     if state is not None:
-    #         new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
-    #     y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
-    #     y = y + D * x
-    #     return y, new_state
-
-    def ssm_step(self, x, dt, state):
-        B, L, C = x.shape
-        print(f"x shape: {x.shape}")
-        projected_states = self.in_proj(x)
-        print(f"deltaBC shape: {projected_states.shape}")
-
-        d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size - self.num_heads) // 2
-
-        gate = projected_states[:, :, 2*d_mlp:2*d_mlp+self.intermediate_size]
-        conv_state = projected_states[:, :, 2*d_mlp+self.intermediate_size:2*d_mlp+self.intermediate_size+self.conv_dim]
-        time_step = projected_states[:, :, -self.num_heads:]
-
-        print(f"conv_state shape before reshape: {conv_state.shape}")
-        print(f"self.conv_dim: {self.conv_dim}")
-
-        # Reshape and handle the case where L=1
-        conv_state = conv_state.reshape(B, self.conv_dim, L)
-        if L == 1:
-            # If sequence length is 1, we need to pad to apply convolution
-            conv_state = mx.pad(conv_state, ((0, 0), (0, 0), (0, self.conv_kernel_size - 1)))
-
-        conv_out = self.conv1d(conv_state)
-
-        # If we padded, we need to remove the padding
-        if L == 1:
-            conv_out = conv_out[:, :, :L]
-
-        # Reshape back to (B, L, C)
-        conv_out = conv_out.transpose(0, 2, 1)
-
-        x_and_conv_out, B, C = mx.split(
-            conv_out,
-            [self.intermediate_size, self.n_groups * self.state_size],
-            axis=-1
+    def ssm_step(self, x, state=None):
+        A = -mx.exp(self.A_log)
+        D = self.D
+        deltaBC = self.x_proj(x)
+        delta, B, C = mx.split(
+            deltaBC,
+            indices_or_sections=[
+                self.time_step_rank,
+                self.time_step_rank + self.ssm_state_size,
+            ],
+            axis=-1,
         )
-
-        dt = nn.softplus(time_step + self.dt_bias)
-        dt = mx.clip(dt, self.args.time_step_min, self.args.time_step_max)
-
-        B = B.reshape(-1, self.num_heads, self.head_dim, self.state_size)
-        C = C.reshape(-1, self.num_heads, self.head_dim, self.state_size)
-
-        dA = mx.exp(dt[:, :, None, None] * A[None, :, None, None])
-        dB = dt[:, :, None, None] * B
-
-        new_state = state * dA + x_and_conv_out[:, :, None, None] * dB
-        y = mx.sum(new_state * C, axis=-1)
-        y = y + C[None, :, None] * x_and_conv_out
-
-        y = self.norm(y.reshape(-1, self.intermediate_size), gate)
-        output = self.out_proj(y)
-
-        return output, new_state
+        delta = nn.softplus(self.dt_proj(delta))
+        new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
+        if state is not None:
+            new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
+        y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
+        y = y + D * x
+        return y, new_state
 
     def __call__(self, x, cache):
         B, T, D = x.shape
@@ -232,7 +178,7 @@ class Mamba2Mixer(nn.Module):
         for t in range(T):
             xt = x[:, t, :]
             xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
+            x_t, z_t = xz.split(indices_or_sections=2, axis=-1)
             if x_t.shape[-1] != self.conv_dim:
                 raise ValueError(f"Expected conv input dim {self.conv_dim}, got {x_t.shape[-1]}")
 