mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-13 12:46:43 +08:00
nits
This commit is contained in:
parent
b10afe3662
commit
80e88b4f4d
@ -65,6 +65,7 @@ def silu(x):
|
|||||||
|
|
||||||
def ssd(x, A, B, C, chunk_size):
|
def ssd(x, A, B, C, chunk_size):
|
||||||
batch, seqlen, nheads, dim = x.shape
|
batch, seqlen, nheads, dim = x.shape
|
||||||
|
|
||||||
B = mx.expand_dims(B, axis=2)
|
B = mx.expand_dims(B, axis=2)
|
||||||
C = mx.expand_dims(C, axis=2)
|
C = mx.expand_dims(C, axis=2)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user