This commit is contained in:
Goekdeniz-Guelmez
2024-09-18 11:13:22 +02:00
parent 511cdf89b1
commit 602c9f18bd

View File

@@ -106,8 +106,7 @@ def pscan_f(A, X):
# X : (B, D, L, N)
# modifies X in place by doing a parallel scan.
# more formally, X will be populated by these values :
# H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
# more formally, X will be populated by these values: H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
# which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
Aa = A
@@ -164,11 +163,11 @@ def pscan(A_in, X_in):
Applies the parallel scan operation, as defined above. Returns a new tensor.
Args:
A_in : (B, L, ED, N)
X_in : (B, L, ED, N)
A_in: mx.array -> Shape (B, L, ED, N)
X_in: mx.array -> Shape (B, L, ED, N)
Returns:
H : (B, L, ED, N)
H: mx.array -> Shape (B, L, ED, N)
"""
A = A_in[:].transpose(0, 2, 1, 3)
@@ -277,14 +276,11 @@ class MambaBlock(nn.Module):
def selective_scan(self, x, delta, A, B, C, D):
# x : (B, L, ED)
# Δ : (B, L, ED)
# A : (ED, N)
# B : (B, L, N)
# C : (B, L, N)
# D : (ED)
# y : (B, L, ED)
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N)
deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N)
@@ -296,18 +292,15 @@ class MambaBlock(nn.Module):
y = y + D * x
return y
return y # (B, L, ED)
def selective_scan_seq(self, x, delta, A, B, C, D):
# x : (B, L, ED)
# Δ : (B, L, ED)
# A : (ED, N)
# B : (B, L, N)
# C : (B, L, N)
# D : (ED)
# y : (B, L, ED)
_, L, _ = x.shape
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N)
@@ -328,7 +321,7 @@ class MambaBlock(nn.Module):
y = y + D * x
return y
return y # (B, L, ED)
def __call__(self, x, cache: MambaCache, layer_idx: int):