From 313d4a2ac974c697bdccc98f2ba1e20a51140c1e Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Fri, 28 Feb 2025 15:04:03 +0100
Subject: [PATCH] summarize segsum

---
Review note: the added segsum is a prefix cumulative sum over the last axis,
whereas the removed implementation summed x[..., i] over i >= j (a suffix
sum). The two are not numerically equivalent, so this patch changes results
rather than only condensing the code. Please confirm the change is intended.

 llms/mlx_lm/models/mamba2.py | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 7c044dba..521d3e8d 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -42,6 +42,10 @@ class ModelArgs(BaseModelArgs):
             self.time_step_rank = math.ceil(self.hidden_size / 16)
 
 
+def segsum(x):
+    return mx.cumsum(x, axis=-1).reshape(*x.shape[:-1], 1, x.shape[-1])
+
+
 class DepthWiseConv1d(nn.Module):
     def __init__(self, channels, kernel_size, bias=True, padding=0):
         super().__init__()
@@ -133,19 +137,6 @@ def ssd_forward_attn(
     return y, next_state
 
 
-def segsum(x):
-    # x shape: [b, h, l]
-    b, h, l = x.shape
-    indices = mx.arange(l)
-    mask = indices[:, None] >= indices[None, :]  # [l, l] lower triangular mask
-    # Expand x for broadcasting
-    x_expanded = x.reshape(b, h, l, 1)  # [b, h, l, 1]
-    # Apply mask and sum
-    masked_x = x_expanded * mask.reshape(1, 1, l, l)  # [b, h, l, l]
-    x_segsum = mx.sum(masked_x, axis=2, keepdims=True)  # [b, h, 1, l]
-    return x_segsum
-
-
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()