mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-11 03:36:42 +08:00
summarize segsum
This commit is contained in:
parent
932b196b48
commit
313d4a2ac9
@ -42,6 +42,10 @@ class ModelArgs(BaseModelArgs):
|
||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||
|
||||
|
||||
def segsum(x):
|
||||
return mx.cumsum(x, axis=-1).reshape(*x.shape[:-1], 1, x.shape[-1])
|
||||
|
||||
|
||||
class DepthWiseConv1d(nn.Module):
|
||||
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
||||
super().__init__()
|
||||
@ -133,19 +137,6 @@ def ssd_forward_attn(
|
||||
return y, next_state
|
||||
|
||||
|
||||
def segsum(x):
|
||||
# x shape: [b, h, l]
|
||||
b, h, l = x.shape
|
||||
indices = mx.arange(l)
|
||||
mask = indices[:, None] >= indices[None, :] # [l, l] lower triangular mask
|
||||
# Expand x for broadcasting
|
||||
x_expanded = x.reshape(b, h, l, 1) # [b, h, l, 1]
|
||||
# Apply mask and sum
|
||||
masked_x = x_expanded * mask.reshape(1, 1, l, l) # [b, h, l, l]
|
||||
x_segsum = mx.sum(masked_x, axis=2, keepdims=True) # [b, h, 1, l]
|
||||
return x_segsum
|
||||
|
||||
|
||||
class Mamba2Block(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
Loading…
Reference in New Issue
Block a user