summarize segsum

This commit is contained in:
Goekdeniz-Guelmez 2025-02-28 15:04:03 +01:00
parent 932b196b48
commit 313d4a2ac9

View File

@ -42,6 +42,10 @@ class ModelArgs(BaseModelArgs):
self.time_step_rank = math.ceil(self.hidden_size / 16)
def segsum(x):
return mx.cumsum(x, axis=-1).reshape(*x.shape[:-1], 1, x.shape[-1])
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
@ -133,19 +137,6 @@ def ssd_forward_attn(
return y, next_state
def segsum(x):
# x shape: [b, h, l]
b, h, l = x.shape
indices = mx.arange(l)
mask = indices[:, None] >= indices[None, :] # [l, l] lower triangular mask
# Expand x for broadcasting
x_expanded = x.reshape(b, h, l, 1) # [b, h, l, 1]
# Apply mask and sum
masked_x = x_expanded * mask.reshape(1, 1, l, l) # [b, h, l, l]
x_segsum = mx.sum(masked_x, axis=2, keepdims=True) # [b, h, 1, l]
return x_segsum
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()