mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 18:51:18 +08:00
nits
This commit is contained in:
parent
1a6688384d
commit
1d851069ea
@ -62,7 +62,6 @@ def silu(x):
|
|||||||
return x * mx.sigmoid(x)
|
return x * mx.sigmoid(x)
|
||||||
|
|
||||||
def ssd(x, A, B, C, chunk_size):
|
def ssd(x, A, B, C, chunk_size):
|
||||||
# Replace einsum operations with explicit reshape and matrix multiply
|
|
||||||
batch, seqlen, nheads, dim = x.shape
|
batch, seqlen, nheads, dim = x.shape
|
||||||
B = mx.expand_dims(B, axis=2)
|
B = mx.expand_dims(B, axis=2)
|
||||||
C = mx.expand_dims(C, axis=2)
|
C = mx.expand_dims(C, axis=2)
|
||||||
@ -74,7 +73,6 @@ def ssd(x, A, B, C, chunk_size):
|
|||||||
chunk = slice(i, min(i + chunk_size, seqlen))
|
chunk = slice(i, min(i + chunk_size, seqlen))
|
||||||
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
|
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
|
||||||
|
|
||||||
# Replace einsum with explicit operations
|
|
||||||
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
|
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
|
||||||
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
|
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
|
||||||
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
|
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
|
||||||
@ -82,7 +80,6 @@ def ssd(x, A, B, C, chunk_size):
|
|||||||
|
|
||||||
state = state * mx.expand_dims(dA, axis=-1) + dBx
|
state = state * mx.expand_dims(dA, axis=-1) + dBx
|
||||||
|
|
||||||
# Replace einsum with explicit operations
|
|
||||||
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
|
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
|
||||||
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
|
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
|
||||||
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
|
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
|
||||||
|
Loading…
Reference in New Issue
Block a user