diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 7c044dba..521d3e8d 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -42,6 +42,10 @@ class ModelArgs(BaseModelArgs):
         self.time_step_rank = math.ceil(self.hidden_size / 16)
 
 
+def segsum(x):
+    return mx.cumsum(x, axis=-1).reshape(*x.shape[:-1], 1, x.shape[-1])
+
+
 class DepthWiseConv1d(nn.Module):
     def __init__(self, channels, kernel_size, bias=True, padding=0):
         super().__init__()
@@ -133,19 +137,6 @@ def ssd_forward_attn(
     return y, next_state
 
 
-def segsum(x):
-    # x shape: [b, h, l]
-    b, h, l = x.shape
-    indices = mx.arange(l)
-    mask = indices[:, None] >= indices[None, :]  # [l, l] lower triangular mask
-    # Expand x for broadcasting
-    x_expanded = x.reshape(b, h, l, 1)  # [b, h, l, 1]
-    # Apply mask and sum
-    masked_x = x_expanded * mask.reshape(1, 1, l, l)  # [b, h, l, l]
-    x_segsum = mx.sum(masked_x, axis=2, keepdims=True)  # [b, h, 1, l]
-    return x_segsum
-
-
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()