This commit is contained in:
Goekdeniz-Guelmez 2024-12-10 18:18:59 +01:00
parent b10afe3662
commit 80e88b4f4d

View File

@ -65,6 +65,7 @@ def silu(x):
def ssd(x, A, B, C, chunk_size):
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)