mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
clean up
This commit is contained in:
@@ -106,8 +106,7 @@ def pscan_f(A, X):
|
||||
# X : (B, D, L, N)
|
||||
|
||||
# modifies X in place by doing a parallel scan.
|
||||
# more formally, X will be populated by these values :
|
||||
# H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
|
||||
# more formally, X will be populated by these values: H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
|
||||
# which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
|
||||
|
||||
Aa = A
|
||||
@@ -164,11 +163,11 @@ def pscan(A_in, X_in):
|
||||
Applies the parallel scan operation, as defined above. Returns a new tensor.
|
||||
|
||||
Args:
|
||||
A_in : (B, L, ED, N)
|
||||
X_in : (B, L, ED, N)
|
||||
A_in: mx.array -> Shape (B, L, ED, N)
|
||||
X_in: mx.array -> Shape (B, L, ED, N)
|
||||
|
||||
Returns:
|
||||
H : (B, L, ED, N)
|
||||
H: mx.array -> Shape (B, L, ED, N)
|
||||
"""
|
||||
|
||||
A = A_in[:].transpose(0, 2, 1, 3)
|
||||
@@ -277,14 +276,11 @@ class MambaBlock(nn.Module):
|
||||
|
||||
def selective_scan(self, x, delta, A, B, C, D):
|
||||
# x : (B, L, ED)
|
||||
# Δ : (B, L, ED)
|
||||
# A : (ED, N)
|
||||
# B : (B, L, N)
|
||||
# C : (B, L, N)
|
||||
# D : (ED)
|
||||
|
||||
# y : (B, L, ED)
|
||||
|
||||
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N)
|
||||
deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N)
|
||||
|
||||
@@ -296,18 +292,15 @@ class MambaBlock(nn.Module):
|
||||
|
||||
y = y + D * x
|
||||
|
||||
return y
|
||||
return y # (B, L, ED)
|
||||
|
||||
def selective_scan_seq(self, x, delta, A, B, C, D):
|
||||
# x : (B, L, ED)
|
||||
# Δ : (B, L, ED)
|
||||
# A : (ED, N)
|
||||
# B : (B, L, N)
|
||||
# C : (B, L, N)
|
||||
# D : (ED)
|
||||
|
||||
# y : (B, L, ED)
|
||||
|
||||
_, L, _ = x.shape
|
||||
|
||||
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N)
|
||||
@@ -328,7 +321,7 @@ class MambaBlock(nn.Module):
|
||||
|
||||
y = y + D * x
|
||||
|
||||
return y
|
||||
return y # (B, L, ED)
|
||||
|
||||
|
||||
def __call__(self, x, cache: MambaCache, layer_idx: int):
|
||||
|
Reference in New Issue
Block a user