clean up, reformating will come later

This commit is contained in:
Goekdeniz-Guelmez
2024-09-18 14:44:49 +02:00
parent 399de78f51
commit 13af75d88a

View File

@@ -30,8 +30,7 @@ class ModelArgs(BaseModelArgs):
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
pscan: bool = False # use parallel scan mode or sequential mode when training
use_mambapy: bool = False
pscan: bool = False
tie_word_embeddings: bool = True
@@ -102,13 +101,6 @@ def clamp(x, min=None, max=None):
def pscan_f(A, X):
# A : (B, D, L, N)
# X : (B, D, L, N)
# modifies X in place by doing a parallel scan.
# more formally, X will be populated by these values: H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
# which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
Aa = A
Xa = X
@@ -157,10 +149,10 @@ def pscan_f(A, X):
A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2)
X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2)
# main function, used in the Mamba model (mamba_mlx.py)
def pscan(A_in, X_in):
"""
Applies the parallel scan operation, as defined above. Returns a new tensor.
Applies the parallel scan operation, as defined above. Returns a new array.
Args:
A_in: mx.array =-> Shape(B, L, ED, N)
@@ -169,12 +161,9 @@ def pscan(A_in, X_in):
Returns:
H: mx.array -> Shape (B, L, ED, N)
"""
A = A_in[:].transpose(0, 2, 1, 3)
X = X_in[:].transpose(0, 2, 1, 3)
pscan_f(A, X)
return X.transpose(0, 2, 1, 3)
@@ -223,7 +212,6 @@ class MambaBlock(nn.Module):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
def ssm_step(self, x, ssm_state=None):
# Modify this method to work without state during training
A = -mx.exp(self.A_log) # (ED, N)
D = self.D # (ED,)
@@ -237,7 +225,6 @@ class MambaBlock(nn.Module):
BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N)
if self.training:
# During training, we don't use or update the state
new_ssm_state = BX
else:
if ssm_state is None:
@@ -255,7 +242,6 @@ class MambaBlock(nn.Module):
def ssm(self, x):
# x : (B, L, ED)
# y : (B, L, ED)
A = -mx.exp(self.A_log) # (ED, N)
@@ -280,7 +266,6 @@ class MambaBlock(nn.Module):
# B : (B, L, N)
# C : (B, L, N)
# D : (ED)
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N)
deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N)
@@ -300,7 +285,6 @@ class MambaBlock(nn.Module):
# B : (B, L, N)
# C : (B, L, N)
# D : (ED)
_, L, _ = x.shape
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N)
@@ -366,7 +350,6 @@ class ResidualBlock(nn.Module):
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache: MambaCache, layer_idx: int):
# Ensure x is 3D before passing to mixer
if x.ndim == 2:
x = mx.expand_dims(x, 1) # Make it (B, 1, D)